@fugood/llama.node 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/tests/test-sampling.cpp

@@ -10,181 +10,208 @@
 #include <string>
 #include <vector>
 
-
-
-
+extern struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers);
+
+static void dump(const llama_token_data_array * cur_p) {
+    for (size_t i = 0; i < cur_p->size; i++) {
+        printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
 }
 
-#define DUMP(
+#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
 
-
-
-
-
-
-
-
+struct sampler_tester {
+    sampler_tester(size_t n_vocab) {
+        cur.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+            const float logit = logf(token_id);
+            cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        }
+
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
     }
 
-
-
-
-
-
+    sampler_tester(const std::vector<float> & probs, const std::vector<float> & probs_expected) : probs_expected(probs_expected) {
+        cur.reserve(probs.size());
+        for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
+            const float logit = logf(probs[token_id]);
+            cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]});
+        }
 
-
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
     }
-}
 
-
-
-
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    void apply(llama_sampler * sampler) {
+        llama_sampler_apply(sampler, &cur_p);
+        llama_sampler_free(sampler);
     }
 
-
-
-
-
-
-
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    void check() {
+        GGML_ASSERT(cur_p.size == probs_expected.size());
+        for (size_t i = 0; i < cur_p.size; i++) {
+            GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
+        }
     }
+
+    llama_token_data_array cur_p;
+
+private:
+    const std::vector<float> probs_expected;
+
+    std::vector<llama_token_data> cur;
+};
+
+static void test_temp(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp(temp));
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
-static void
-
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
+    sampler_tester tester(probs, probs_expected);
 
-
-
-
-    DUMP(&
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void
-
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
+    sampler_tester tester(probs, probs_expected);
 
-
-
-
-    DUMP(&
-    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_k(k));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void
-
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+static void test_top_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
 
-
-
-
-    DUMP(&
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_p(p, 1));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-
-
-
-
+    tester.check();
+}
+
+static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_min_p(p, 1));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_xtc(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p, float t) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
-static void
+static void test_typical(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_typical(p, 1));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_penalties(
     const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
-    const std::vector<float> &
+    const std::vector<float> & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
 ) {
-    GGML_ASSERT(probs.size() ==
+    GGML_ASSERT(probs.size() == probs_expected.size());
+
+    sampler_tester tester(probs, probs_expected);
 
     const size_t n_vocab = probs.size();
-
-
-    for (
-
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        llama_sampler_accept(sampler, last_tokens[i]);
     }
 
-
-
-
-
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
 
-
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void
-    const
+static void test_dry(
+    const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
+    const std::vector<float> & expected_probs, float dry_multiplier, float dry_base,
+    int dry_allowed_length, int dry_penalty_last_n,
+    const std::vector<std::vector<llama_token>> & seq_breakers
 ) {
-
-
-
-
-
+    GGML_ASSERT(probs.size() == expected_probs.size());
+
+    sampler_tester tester(probs, expected_probs);
+
+    auto * sampler = llama_sampler_init_dry_testing(1024, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        llama_sampler_accept(sampler, last_tokens[i]);
     }
 
-
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+    tester.check();
+}
+
+static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
+) {
+    sampler_tester tester(n_vocab);
 
     llama_token min_token_id = 0;
     const llama_token max_token_id = n_vocab-1;
 
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k':
-            case '
-            case '
-            case '
-            case '
-
-            default : GGML_ABORT("Unknown sampler"); break;
+            case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
+            case 'y': GGML_ABORT("typical test not implemented");
+            case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
+            case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
+            case 't': GGML_ABORT("temperature test not implemented");
+            default : GGML_ABORT("Unknown sampler");
         }
 
-
+        tester.apply(llama_sampler_init_dist(0));
+
+        auto & cur_p = tester.cur_p;
 
-        const int size =
+        const int size = cur_p.size;
 
         if (s == 'k') {
            const int expected_size = std::min(size, top_k);
            min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k));
 
            GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(
-            GGML_ASSERT(
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else if (s == 'p') {
            const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2;
            const int softmax_numerator_target = ceilf(top_p * softmax_divisor);
@@ -206,8 +233,8 @@ static void test_sampler_queue(
            }
 
            GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(
-            GGML_ASSERT(
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else if (s == 'm') {
            int expected_size = ceilf((1.0f-min_p) * n_vocab);
            expected_size = std::max(expected_size, 1);
@@ -219,29 +246,73 @@ static void test_sampler_queue(
            min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
 
            GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(
-            GGML_ASSERT(
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else {
            GGML_ABORT("fatal error");
         }
     }
 
-    printf("Sampler queue %3s OK with n_vocab=%
+    printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n",
           samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
+static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
+    std::vector<llama_token_data> cur(data.size());
+    std::copy(data.begin(), data.end(), cur.begin());
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_reset(cnstr);
+    const int64_t t_start = ggml_time_us();
+    for (int i = 0; i < n_iter; i++) {
+        std::copy(data.begin(), data.end(), cur.begin());
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_reset(cnstr);
+    }
+    const int64_t t_end = ggml_time_us();
+    llama_sampler_free(cnstr);
+    printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+}
+
+#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
+
+static void test_perf() {
+    const int n_vocab = 1 << 17;
+
+    std::vector<llama_token_data> data;
+
+    data.reserve(n_vocab);
+    for (int i = 0; i < n_vocab; i++) {
+        const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);
+        data.emplace_back(llama_token_data{i, logit, 0.0f});
+    }
+
+    BENCH(llama_sampler_init_top_k (40), data, 32);
+    BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
+    BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
+    BENCH(llama_sampler_init_typical(0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
+}
+
 int main(void) {
     ggml_time_init();
 
-
-
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
+
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
+
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
 
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
 
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
@@ -252,20 +323,31 @@ int main(void) {
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
 
-
-
-
+    printf("XTC should:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.09f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.19f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.29f);
+
+    printf("XTC should not:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
 
     test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
     test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
 
-
-
-
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
-
-
-
+
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {});
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
 
     test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
     test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
@@ -297,5 +379,7 @@ int main(void) {
 
     printf("OK\n");
 
+    test_perf();
+
     return 0;
 }
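
The hunks above migrate the test from the removed llama_sample_* helpers to the composable llama_sampler objects. As a rough, self-contained sketch (our illustration, not code shipped in this package) built only from calls that appear in this diff, the new usage pattern looks like:

// sketch: apply a top-k filter, then sample with a seeded dist sampler
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // candidate array over a toy 4-token vocabulary
    std::vector<llama_token_data> cur;
    for (llama_token id = 0; id < 4; id++) {
        cur.emplace_back(llama_token_data{id, logf(0.1f*(id + 1)), 0.0f});
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), /*selected*/ -1, /*sorted*/ false };

    llama_sampler * top_k = llama_sampler_init_top_k(3); // keep the 3 most likely tokens
    llama_sampler * dist  = llama_sampler_init_dist(0);  // pick one, fixed seed 0

    llama_sampler_apply(top_k, &cur_p);
    llama_sampler_apply(dist,  &cur_p); // sets cur_p.selected

    printf("selected token: %d\n", cur_p.data[cur_p.selected].id);

    llama_sampler_free(top_k);
    llama_sampler_free(dist);
    return 0;
}

The sampler_tester struct in the diff wraps exactly this apply/free cycle, plus a comparison of the resulting probabilities against expected values.
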
package/src/llama.cpp/tests/test-tokenizer-0.cpp

@@ -7,6 +7,7 @@
 #include <map>
 #include <vector>
 #include <fstream>
+#include <thread>
 
 //static const std::map<std::string, std::vector<llama_token>> & k_tests() {
 //    static std::map<std::string, std::vector<llama_token>> _k_tests = {
@@ -194,45 +195,64 @@ int main(int argc, char **argv) {
 
     const bool add_special = false;
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    // multi-threaded tokenization
+    const int nthread = std::thread::hardware_concurrency();
+    std::vector<std::thread> threads(nthread);
+
+    for (int i = 0; i < nthread; i++) {
+        threads[i] = std::thread([&, i]() {
+            for (const auto & test_kv : k_tests) {
+                const std::vector<llama_token> res = common_tokenize(ctx, test_kv.first, add_special, false);
+
+                // here only print the result of the first thread
+                // because the other threads are running the same tests
+                if (i != 0) {
+                    continue;
+                }
+
+                printf("\n");
+                printf("src: '%s'\n", test_kv.first.c_str());
+                printf("res: '%s'\n", common_detokenize(ctx, res).c_str());
+                printf("tok: ");
+                for (const auto & tok : res) {
+                    printf("%d ", tok);
+                }
+                printf("\n");
+
+                bool correct = res.size() == test_kv.second.size();
+                for (int i = 0; i < (int) res.size() && correct; ++i) {
+                    if (test_kv.second[i] != res[i]) {
+                        correct = false;
+                    }
+                }
+
+                if (!correct) {
+                    fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
+                    fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                        common_detokenize(ctx, res).c_str(),
+                        common_detokenize(ctx, test_kv.second).c_str());
+                    fprintf(stderr, "%s : expected tokens: ", __func__);
+                    for (const auto & t : test_kv.second) {
+                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+                    fprintf(stderr, "%s : got tokens: ", __func__);
+                    for (const auto & t : res) {
+                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+
+                    success = false;
+                }
             }
-        }
-
-        if (!correct) {
-            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
-            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
-                llama_detokenize(ctx, res).c_str(),
-                llama_detokenize(ctx, test_kv.second).c_str());
-            fprintf(stderr, "%s : expected tokens: ", __func__);
-            for (const auto & t : test_kv.second) {
-                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
-            }
-            fprintf(stderr, "\n");
-            fprintf(stderr, "%s : got tokens: ", __func__);
-            for (const auto & t : res) {
-                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
-            }
-            fprintf(stderr, "\n");
+        });
+    }
 
-
-
+    for (int i = 0; i < nthread; i++) {
+        threads[i].join();
     }
 
+    // single threaded tokenization
     if (!fname_text.empty()) {
         fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
 
@@ -253,7 +273,7 @@ int main(int argc, char **argv) {
         {
            const auto t_start = ggml_time_us();
 
-            res =
+            res = common_tokenize(ctx, text, add_special, false);
 
            const auto t_end = ggml_time_us();
 
package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp

@@ -78,10 +78,10 @@ int main(int argc, char **argv) {
     const int n_vocab = llama_n_vocab(model);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string str =
+        std::string str = common_detokenize(ctx, std::vector<int>(1, i));
         try {
            auto cps = unicode_cpts_from_utf8(str);
-            std::vector<llama_token> tokens =
+            std::vector<llama_token> tokens = common_tokenize(ctx, str, false, true);
            if (ignore_merges && tokens.size() > 1) {
                fprintf(stderr,
                    "%s : error: token %d detokenizes to '%s'(%zu) but "
@@ -94,7 +94,7 @@ int main(int argc, char **argv) {
                fprintf(stderr, "]\n");
                return 2;
            }
-            std::string check =
+            std::string check = common_detokenize(ctx, tokens);
            if (check != str) {
                fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                    __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -123,8 +123,8 @@ int main(int argc, char **argv) {
            }
 
            std::string str = unicode_cpt_to_utf8(cp);
-            std::vector<llama_token> tokens =
-            std::string check =
+            std::vector<llama_token> tokens = common_tokenize(ctx, str, false);
+            std::string check = common_detokenize(ctx, tokens);
            if (cp != 9601 && str != check) {
                fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                    cp, check.c_str(), check.length(), str.c_str(), str.length());
package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp

@@ -66,9 +66,9 @@ int main(int argc, char ** argv) {
     const int n_vocab = llama_n_vocab(model);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string str =
-        std::vector<llama_token> tokens =
-        std::string check =
+        std::string str = common_detokenize(ctx, std::vector<int>(1, i), true);
+        std::vector<llama_token> tokens = common_tokenize(ctx, str, false, true);
+        std::string check = common_detokenize(ctx, tokens);
         if (check != str) {
            fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -93,8 +93,8 @@ int main(int argc, char ** argv) {
         }
 
         std::string str = unicode_cpt_to_utf8(cp);
-        std::vector<llama_token> tokens =
-        std::string check =
+        std::vector<llama_token> tokens = common_tokenize(ctx, str, false, true);
+        std::string check = common_detokenize(ctx, tokens);
         if (cp != 9601 && str != check) {
            fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                cp, check.c_str(), check.length(), str.c_str(), str.length());