@fugood/llama.node 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/examples/parallel/parallel.cpp

@@ -1,7 +1,10 @@
 // A basic application simulating a server with multiple clients.
 // The clients submit requests to the server and they are processed in parallel.
 
+#include "arg.h"
 #include "common.h"
+#include "sampling.h"
+#include "log.h"
 #include "llama.h"
 
 #include <cmath>
@@ -50,8 +53,8 @@ static std::vector<std::string> k_prompts = {
 
 struct client {
     ~client() {
-        if (
-
+        if (smpl) {
+            common_sampler_free(smpl);
         }
     }
 
@@ -72,7 +75,7 @@ struct client {
     std::string prompt;
     std::string response;
 
-    struct
+    struct common_sampler * smpl = nullptr;
 };
 
 static void print_date_time() {
@@ -81,7 +84,9 @@ static void print_date_time() {
     char buffer[80];
     strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", local_time);
 
-
+    LOG_INF("\n");
+    LOG_INF("\033[35mrun parameters as of %s\033[0m\n", buffer);
+    LOG_INF("\n");
 }
 
 // Define a split string function to ...
@@ -98,13 +103,14 @@ static std::vector<std::string> split_string(const std::string& input, char deli
 int main(int argc, char ** argv) {
     srand(1234);
 
-
+    common_params params;
 
-    if (!
-        gpt_params_print_usage(argc, argv, params);
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
         return 1;
     }
 
+    common_init();
+
     // number of simultaneous "clients" to simulate
     const int32_t n_clients = params.n_parallel;
 
@@ -119,41 +125,34 @@ int main(int argc, char ** argv) {
 
     const bool dump_kv_cache = params.dump_kv_cache;
 
-#ifndef LOG_DISABLE_LOGS
-    log_set_target(log_filename_generator("parallel", "log"));
-    LOG_TEE("Log start\n");
-    log_dump_cmdline(argc, argv);
-#endif // LOG_DISABLE_LOGS
-
     // init llama.cpp
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    llama_model * model = NULL;
-    llama_context * ctx = NULL;
-
     // load the target model
-
+    common_init_result llama_init = common_init_from_params(params);
+
+    llama_model * model = llama_init.model;
+    llama_context * ctx = llama_init.context;
 
     // load the prompts from an external file if there are any
     if (params.prompt.empty()) {
-
+        LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
     } else {
         // Output each line of the input params.prompts vector and copy to k_prompts
         int index = 0;
-
+        LOG_INF("\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str());
 
         std::vector<std::string> prompts = split_string(params.prompt, '\n');
         for (const auto& prompt : prompts) {
             k_prompts.resize(index + 1);
             k_prompts[index] = prompt;
             index++;
-
+            LOG_INF("%3d prompt: %s\n", index, prompt.c_str());
         }
     }
 
-
-    fflush(stderr);
+    LOG_INF("\n\n");
 
     const int n_ctx = llama_n_ctx(ctx);
 
@@ -161,11 +160,11 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.
+        client.smpl = common_sampler_init(model, params.sparams);
     }
 
     std::vector<llama_token> tokens_system;
-    tokens_system =
+    tokens_system = common_tokenize(ctx, k_system, true);
     const int32_t n_tokens_system = tokens_system.size();
 
     llama_seq_id g_seq_id = 0;
@@ -182,19 +181,19 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
-
-
-
+    LOG_INF("%s: Simulating parallel requests from clients:\n", __func__);
+    LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
+    LOG_INF("\n");
 
     {
-
+        LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
 
         for (int32_t i = 0; i < n_tokens_system; ++i) {
-
+            common_batch_add(batch, tokens_system[i], i, { 0 }, false);
         }
 
         if (llama_decode(ctx, batch) != 0) {
-
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
 
@@ -203,18 +202,18 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
         }
 
-
+        LOG_INF("\n");
     }
 
-
+    LOG_INF("Processing requests ...\n\n");
 
     while (true) {
         if (dump_kv_cache) {
            llama_kv_cache_view_update(ctx, &kvc_view);
-
+            common_kv_cache_dump_view_seqs(kvc_view, 40);
         }
 
-
+        common_batch_clear(batch);
 
         // decode any currently ongoing sequences
         for (auto & client : clients) {
@@ -224,7 +223,7 @@ int main(int argc, char ** argv) {
 
             client.i_batch = batch.n_tokens;
 
-
+            common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
 
             client.n_decoded += 1;
         }
@@ -237,7 +236,7 @@ int main(int argc, char ** argv) {
                 llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
             }
 
-
+            LOG_INF("%s: clearing the KV cache\n", __func__);
         }
 
         // insert new sequences for decoding
@@ -253,14 +252,14 @@ int main(int argc, char ** argv) {
                 client.prompt   = client.input + "\nAssistant:";
                 client.response = "";
 
-
+                common_sampler_reset(client.smpl);
 
                 // do not prepend BOS because we have a system prompt!
                 std::vector<llama_token> tokens_prompt;
-                tokens_prompt =
+                tokens_prompt = common_tokenize(ctx, client.prompt, false);
 
                 for (size_t i = 0; i < tokens_prompt.size(); ++i) {
-
+                    common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
                 }
 
                 // extract the logits only for the last token
@@ -272,7 +271,7 @@ int main(int argc, char ** argv) {
                 client.n_decoded = 0;
                 client.i_batch   = batch.n_tokens - 1;
 
-
+                LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
 
                 g_seq_id += 1;
 
@@ -309,18 +308,17 @@ int main(int argc, char ** argv) {
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
                 batch.logits   + i,
-                0, 0, 0, // unused
             };
 
             const int ret = llama_decode(ctx, batch_view);
             if (ret != 0) {
                 if (n_batch == 1 || ret < 0) {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
-
+                    LOG_ERR("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
                     return 1;
                 }
 
-
+                LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
 
                 n_cache_miss += 1;
 
@@ -331,7 +329,7 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-
+            LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
 
             for (auto & client : clients) {
                 if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
@@ -341,9 +339,9 @@ int main(int argc, char ** argv) {
                 //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
                 //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-                const llama_token id =
+                const llama_token id = common_sampler_sample(client.smpl, ctx, client.i_batch - i);
 
-
+                common_sampler_accept(client.smpl, id, true);
 
                 if (client.n_decoded == 1) {
                     // start measuring generation time after the first token to make sure all concurrent clients
@@ -351,7 +349,7 @@ int main(int argc, char ** argv) {
                     client.t_start_gen = ggml_time_us();
                 }
 
-                const std::string token_str =
+                const std::string token_str = common_token_to_piece(ctx, id);
 
                 client.response += token_str;
                 client.sampled = id;
@@ -371,12 +369,12 @@ int main(int argc, char ** argv) {
                 }
 
                 // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                llama_kv_cache_seq_rm(ctx,
+                llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
                 llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
                 const auto t_main_end = ggml_time_us();
 
-
+                LOG_INF("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput:    %s\n\033[35mResponse: %s\033[0m\n\n",
                         client.id, client.seq_id, n_seq, client.n_prompt, client.n_decoded,
                         (t_main_end - client.t_start_prompt) / 1e6,
                         (double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6,
@@ -399,21 +397,22 @@ int main(int argc, char ** argv) {
 
     print_date_time();
 
-
+    LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
    if (params.prompt_file.empty()) {
        params.prompt_file = "used built-in defaults";
    }
-
-
+    LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str());
+    LOG_INF("Model and path used:  \033[32m%s\033[0m\n\n", params.model.c_str());
 
-
-
-
-
+    LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt) / (t_main_end - t_main_start) * 1e6);
+    LOG_INF("Total gen tokens:    %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen) / (t_main_end - t_main_start) * 1e6);
+    LOG_INF("Total speed (AVG):   %6s  speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
+    LOG_INF("Cache misses:        %6d\n", n_cache_miss);
 
-
+    LOG_INF("\n");
 
-
+    // TODO: print sampling/grammar timings for all clients
+    llama_perf_context_print(ctx);
 
     llama_batch_free(batch);
 
@@ -422,7 +421,7 @@ int main(int argc, char ** argv) {
 
     llama_backend_free();
 
-
+    LOG("\n\n");
 
     return 0;
 }
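The parallel.cpp hunks above are representative of most example changes in this llama.cpp bump: the older gpt_*/LOG_TEE-style helpers visible on the removed side are replaced by the common_* helpers declared in arg.h, common.h, sampling.h, and log.h. As a rough orientation for downstream code, the sketch below condenses that new pattern into one sequence (argument parsing, model/context init, tokenization, one decode, one sampled token). It is not shipped code; it only uses calls that appear in the diff, and the assumption that the prompt fits in a single batch is mine.

// Sketch of the post-refactor common_* flow seen in the parallel.cpp diff (not part of the package).
#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"

#include <vector>

int main(int argc, char ** argv) {
    common_params params;
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
        return 1;
    }
    common_init(); // initializes the new common logging/runtime helpers

    llama_backend_init();

    // model and context now come from a single helper instead of separate param conversions
    common_init_result llama_init = common_init_from_params(params);
    llama_model   * model = llama_init.model;
    llama_context * ctx   = llama_init.context;

    // one sampler per independent sequence (cf. struct client::smpl above)
    common_sampler * smpl = common_sampler_init(model, params.sparams);

    // tokenize and decode the prompt in a single batch (assumes it fits within n_batch)
    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);

    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
    common_batch_clear(batch);
    for (size_t i = 0; i < tokens.size(); ++i) {
        // request logits only for the last prompt token
        common_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, i == tokens.size() - 1);
    }
    if (llama_decode(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
        return 1;
    }

    // sample from the logits of the last batch position and accept the token
    const llama_token id = common_sampler_sample(smpl, ctx, batch.n_tokens - 1);
    common_sampler_accept(smpl, id, true);
    LOG("%s\n", common_token_to_piece(ctx, id).c_str());

    common_sampler_free(smpl);
    llama_batch_free(batch);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();

    return 0;
}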
package/src/llama.cpp/examples/passkey/passkey.cpp

@@ -1,4 +1,6 @@
+#include "arg.h"
 #include "common.h"
+#include "log.h"
 #include "llama.h"
 
 #include <cmath>
@@ -6,27 +8,24 @@
 #include <string>
 #include <vector>
 
-static void print_usage(int
-
-
-
-    LOG_TEE("\n    %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]);
-    LOG_TEE("\n");
+static void print_usage(int, char ** argv) {
+    LOG("\nexample usage:\n");
+    LOG("\n    %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]);
+    LOG("\n");
 }
 
 int main(int argc, char ** argv) {
-
+    common_params params;
 
     params.n_junk = 250;
     params.n_keep = 32;
     params.i_pos  = -1;
 
-    if (!
-        print_usage(argc, argv, params);
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) {
         return 1;
     }
 
-
+    common_init();
 
     int n_junk = params.n_junk;
     int n_keep = params.n_keep;
@@ -62,36 +61,41 @@ int main(int argc, char ** argv) {
 
     // initialize the model
 
-    llama_model_params model_params =
+    llama_model_params model_params = common_model_params_to_llama(params);
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
     if (model == NULL) {
-
+        LOG_ERR("%s: unable to load model\n" , __func__);
         return 1;
     }
 
     // initialize the context
 
-    llama_context_params ctx_params =
+    llama_context_params ctx_params = common_context_params_to_llama(params);
 
     ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
 
     GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-
     if (ctx == NULL) {
-
+        LOG_ERR("%s: failed to create the llama_context\n" , __func__);
         return 1;
     }
 
+    auto sparams = llama_sampler_chain_default_params();
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+
     // tokenize the prompt
     std::vector<llama_token> tokens_list;
-    tokens_list =
+    tokens_list = common_tokenize(ctx, params.prompt, true);
 
     // tokenize the prefix and use it as a sink
-    const int n_tokens_prefix =
+    const int n_tokens_prefix = common_tokenize(ctx, prompt_prefix, true).size();
 
     const int n_tokens_all = tokens_list.size();
 
@@ -106,14 +110,14 @@ int main(int argc, char ** argv) {
     const int n_batch = ctx_params.n_batch;
     const int n_batch_grp = ctx_params.n_batch/n_grp;
 
-
+    LOG_INF("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);
 
     // print the prompt token-by-token
 
-
-
-
-    //
+    LOG_INF("\n");
+    LOG_INF("prefix tokens: %d\n", n_tokens_prefix);
+    LOG_INF("prompt tokens: %d\n", n_tokens_all);
+    //LOG_INF("prompt: %s\n", params.prompt.c_str());
 
     llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
 
@@ -133,10 +137,10 @@ int main(int argc, char ** argv) {
             n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }
 
-
+        common_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-
+            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
@@ -144,11 +148,11 @@ int main(int argc, char ** argv) {
         }
 
         if (llama_decode(ctx, batch) != 0) {
-
+            LOG_INF("%s: llama_decode() failed\n", __func__);
             return 1;
         }
 
-
+        LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
 
         if (i + n_batch >= n_tokens_all) {
             break;
@@ -158,7 +162,7 @@ int main(int argc, char ** argv) {
     for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
         const int n_discard = n_batch;
 
-
+        LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
 
         llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
         llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
@@ -167,10 +171,10 @@ int main(int argc, char ** argv) {
 
         n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
 
-
+        common_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-
+            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
@@ -178,18 +182,18 @@ int main(int argc, char ** argv) {
         }
 
         if (llama_decode(ctx, batch) != 0) {
-
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }
 
-
+        LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
     }
 
     {
         const int n_discard = n_past - n_ctx + n_predict;
 
         if (n_discard > 0) {
-
+            LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
 
             llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
             llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
@@ -200,76 +204,64 @@ int main(int argc, char ** argv) {
         }
     }
 
-
-
-
+    LOG_INF("\n");
+    LOG_INF("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
+    LOG_INF("\n");
 
     // main loop
 
     int n_cur    = n_tokens_all;
     int n_decode = 0;
 
-
-    fflush(stdout);
+    LOG_INF("%s", prompt_suffix.c_str());
 
     const auto t_main_start = ggml_time_us();
 
     while (n_cur <= n_len) {
         // sample the next token
         {
-
-            auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
-
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-            }
-
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-            // sample the most likely token
-            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
-
+                LOG("\n");
 
                 break;
             }
 
-
-            fflush(stdout);
+            LOG("%s", common_token_to_piece(ctx, new_token_id).c_str());
 
             n_decode += 1;
 
             // prepare the next batch
-
+            common_batch_clear(batch);
 
             // push this new token for next evaluation
-
+            common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
         }
 
         n_cur += 1;
 
         // evaluate the current batch with the transformer model
         if (llama_decode(ctx, batch)) {
-
+            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
    }
 
-
+    LOG("\n");
 
    const auto t_main_end = ggml_time_us();
 
-
+    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-
+    LOG("\n");
+    llama_perf_context_print(ctx);
+
+    LOG("\n");
 
-
+    llama_sampler_free(smpl);
 
    llama_batch_free(batch);
 
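The passkey.cpp hunk above also replaces the hand-rolled greedy sampling loop (building a llama_token_data array and calling llama_sample_token_greedy) with the new llama_sampler chain API. Below is a minimal sketch of that replacement, assuming a context that has already decoded a batch; the helper name greedy_sample and the idx parameter are illustrative, not part of llama.cpp.

// Sketch: greedy sampling via the llama_sampler chain API used in the hunk above (not shipped code).
#include "llama.h"

// idx is the batch position whose logits should be sampled (e.g. batch.n_tokens - 1)
static llama_token greedy_sample(llama_context * ctx, int32_t idx) {
    llama_sampler_chain_params sparams = llama_sampler_chain_default_params();

    // build a sampler chain containing only the greedy sampler
    llama_sampler * smpl = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    // replaces the removed manual candidates/llama_sample_token_greedy code
    const llama_token id = llama_sampler_sample(smpl, ctx, idx);

    llama_sampler_free(smpl);
    return id;
}

Note that passkey.cpp builds the chain once near startup and frees it at the end rather than per call, which is the cheaper pattern; the per-call construction here is only to keep the sketch self-contained.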