@fugood/llama.node 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
|
@@ -1,460 +1,466 @@
|
|
|
1
|
-
#define LLAMA_API_INTERNAL
|
|
2
1
|
#include "sampling.h"
|
|
3
|
-
#include <random>
|
|
4
2
|
|
|
5
|
-
|
|
6
|
-
struct llama_sampling_context * result = new llama_sampling_context();
|
|
3
|
+
#include "common.h"
|
|
7
4
|
|
|
8
|
-
|
|
9
|
-
|
|
5
|
+
#include <cmath>
|
|
6
|
+
#include <unordered_map>
|
|
10
7
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
8
|
+
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
|
9
|
+
// TODO: deduplicate with llama-impl.h
|
|
10
|
+
template<typename T>
|
|
11
|
+
struct ring_buffer {
|
|
12
|
+
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
|
14
13
|
|
|
15
|
-
|
|
16
|
-
if (
|
|
17
|
-
|
|
18
|
-
delete result;
|
|
19
|
-
return nullptr;
|
|
14
|
+
T & front() {
|
|
15
|
+
if (sz == 0) {
|
|
16
|
+
throw std::runtime_error("ring buffer is empty");
|
|
20
17
|
}
|
|
18
|
+
return data[first];
|
|
19
|
+
}
|
|
21
20
|
|
|
22
|
-
|
|
23
|
-
if (
|
|
24
|
-
|
|
25
|
-
delete result;
|
|
26
|
-
return nullptr;
|
|
21
|
+
const T & front() const {
|
|
22
|
+
if (sz == 0) {
|
|
23
|
+
throw std::runtime_error("ring buffer is empty");
|
|
27
24
|
}
|
|
25
|
+
return data[first];
|
|
26
|
+
}
|
|
28
27
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
grammar_rules.data(),
|
|
33
|
-
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
|
34
|
-
if (grammar == nullptr) {
|
|
35
|
-
throw std::runtime_error("Failed to initialize llama_grammar");
|
|
28
|
+
T & back() {
|
|
29
|
+
if (sz == 0) {
|
|
30
|
+
throw std::runtime_error("ring buffer is empty");
|
|
36
31
|
}
|
|
37
|
-
|
|
32
|
+
return data[pos];
|
|
38
33
|
}
|
|
39
34
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
return result;
|
|
47
|
-
}
|
|
48
|
-
|
|
49
|
-
void llama_sampling_free(struct llama_sampling_context * ctx) {
|
|
50
|
-
if (ctx->grammar != NULL) {
|
|
51
|
-
llama_grammar_free(ctx->grammar);
|
|
35
|
+
const T & back() const {
|
|
36
|
+
if (sz == 0) {
|
|
37
|
+
throw std::runtime_error("ring buffer is empty");
|
|
38
|
+
}
|
|
39
|
+
return data[pos];
|
|
52
40
|
}
|
|
53
41
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
42
|
+
void push_back(const T & value) {
|
|
43
|
+
if (sz == capacity) {
|
|
44
|
+
// advance the start when buffer is full
|
|
45
|
+
first = (first + 1) % capacity;
|
|
46
|
+
} else {
|
|
47
|
+
sz++;
|
|
48
|
+
}
|
|
49
|
+
data[pos] = value;
|
|
50
|
+
pos = (pos + 1) % capacity;
|
|
61
51
|
}
|
|
62
52
|
|
|
63
|
-
|
|
64
|
-
|
|
53
|
+
T pop_front() {
|
|
54
|
+
if (sz == 0) {
|
|
55
|
+
throw std::runtime_error("ring buffer is empty");
|
|
56
|
+
}
|
|
57
|
+
T value = data[first];
|
|
58
|
+
first = (first + 1) % capacity;
|
|
59
|
+
sz--;
|
|
60
|
+
return value;
|
|
61
|
+
}
|
|
65
62
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
if (grammar == nullptr) {
|
|
70
|
-
throw std::runtime_error("Failed to initialize llama_grammar");
|
|
63
|
+
const T & rat(size_t i) const {
|
|
64
|
+
if (i >= sz) {
|
|
65
|
+
throw std::runtime_error("ring buffer: index out of bounds");
|
|
71
66
|
}
|
|
72
|
-
|
|
67
|
+
return data[(first + sz - i - 1) % capacity];
|
|
73
68
|
}
|
|
74
69
|
|
|
75
|
-
std::
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
70
|
+
std::vector<T> to_vector() const {
|
|
71
|
+
std::vector<T> result;
|
|
72
|
+
result.reserve(sz);
|
|
73
|
+
for (size_t i = 0; i < sz; i++) {
|
|
74
|
+
result.push_back(data[(first + i) % capacity]);
|
|
75
|
+
}
|
|
76
|
+
return result;
|
|
77
|
+
}
|
|
79
78
|
|
|
80
|
-
void
|
|
81
|
-
|
|
82
|
-
|
|
79
|
+
void clear() {
|
|
80
|
+
// here only reset the status of the buffer
|
|
81
|
+
sz = 0;
|
|
82
|
+
first = 0;
|
|
83
|
+
pos = 0;
|
|
83
84
|
}
|
|
84
|
-
ctx->rng.seed(seed);
|
|
85
|
-
}
|
|
86
85
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
llama_grammar_free(dst->grammar);
|
|
90
|
-
dst->grammar = nullptr;
|
|
86
|
+
bool empty() const {
|
|
87
|
+
return sz == 0;
|
|
91
88
|
}
|
|
92
89
|
|
|
93
|
-
|
|
94
|
-
|
|
90
|
+
size_t size() const {
|
|
91
|
+
return sz;
|
|
95
92
|
}
|
|
96
93
|
|
|
97
|
-
|
|
98
|
-
|
|
94
|
+
size_t capacity = 0;
|
|
95
|
+
size_t sz = 0;
|
|
96
|
+
size_t first = 0;
|
|
97
|
+
size_t pos = 0;
|
|
98
|
+
std::vector<T> data;
|
|
99
|
+
};
|
|
99
100
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
}
|
|
101
|
+
struct common_sampler {
|
|
102
|
+
common_sampler_params params;
|
|
103
103
|
|
|
104
|
-
|
|
105
|
-
|
|
104
|
+
struct llama_sampler * grmr;
|
|
105
|
+
struct llama_sampler * chain;
|
|
106
106
|
|
|
107
|
-
|
|
107
|
+
ring_buffer<llama_token> prev;
|
|
108
108
|
|
|
109
|
-
std::
|
|
109
|
+
std::vector<llama_token_data> cur;
|
|
110
110
|
|
|
111
|
-
|
|
112
|
-
result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
|
|
113
|
-
}
|
|
111
|
+
llama_token_data_array cur_p;
|
|
114
112
|
|
|
115
|
-
|
|
116
|
-
|
|
113
|
+
void set_logits(struct llama_context * ctx, int idx) {
|
|
114
|
+
const auto * logits = llama_get_logits_ith(ctx, idx);
|
|
115
|
+
|
|
116
|
+
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
117
|
+
|
|
118
|
+
cur.resize(n_vocab);
|
|
117
119
|
|
|
118
|
-
|
|
120
|
+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
121
|
+
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
cur_p = { cur.data(), cur.size(), -1, false };
|
|
125
|
+
}
|
|
126
|
+
};
|
|
127
|
+
|
|
128
|
+
std::string common_sampler_params::print() const {
|
|
119
129
|
char result[1024];
|
|
120
130
|
|
|
121
131
|
snprintf(result, sizeof(result),
|
|
122
132
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
|
123
|
-
"\
|
|
133
|
+
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
|
134
|
+
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
|
124
135
|
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
136
|
+
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
|
137
|
+
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
|
138
|
+
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
|
139
|
+
mirostat, mirostat_eta, mirostat_tau);
|
|
128
140
|
|
|
129
141
|
return std::string(result);
|
|
130
142
|
}
|
|
131
143
|
|
|
132
|
-
|
|
133
|
-
|
|
144
|
+
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
|
|
145
|
+
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
|
146
|
+
|
|
147
|
+
lparams.no_perf = params.no_perf;
|
|
148
|
+
|
|
149
|
+
auto * result = new common_sampler {
|
|
150
|
+
/* .params = */ params,
|
|
151
|
+
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
|
|
152
|
+
/* .chain = */ llama_sampler_chain_init(lparams),
|
|
153
|
+
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
|
154
|
+
/* .cur = */ {},
|
|
155
|
+
/* .cur_p = */ {},
|
|
156
|
+
};
|
|
157
|
+
|
|
158
|
+
llama_sampler_chain_add(result->chain,
|
|
159
|
+
llama_sampler_init_logit_bias(
|
|
160
|
+
llama_n_vocab(model),
|
|
161
|
+
params.logit_bias.size(),
|
|
162
|
+
params.logit_bias.data()));
|
|
163
|
+
|
|
164
|
+
llama_sampler_chain_add(result->chain,
|
|
165
|
+
llama_sampler_init_penalties(
|
|
166
|
+
llama_n_vocab (model),
|
|
167
|
+
llama_token_eos(model),
|
|
168
|
+
llama_token_nl (model),
|
|
169
|
+
params.penalty_last_n,
|
|
170
|
+
params.penalty_repeat,
|
|
171
|
+
params.penalty_freq,
|
|
172
|
+
params.penalty_present,
|
|
173
|
+
params.penalize_nl,
|
|
174
|
+
params.ignore_eos));
|
|
175
|
+
|
|
134
176
|
if (params.mirostat == 0) {
|
|
135
|
-
for (auto
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
177
|
+
for (const auto & cnstr : params.samplers) {
|
|
178
|
+
switch (cnstr) {
|
|
179
|
+
case COMMON_SAMPLER_TYPE_DRY:
|
|
180
|
+
{
|
|
181
|
+
std::vector<const char*> c_breakers;
|
|
182
|
+
c_breakers.reserve(params.dry_sequence_breakers.size());
|
|
183
|
+
for (const auto& str : params.dry_sequence_breakers) {
|
|
184
|
+
c_breakers.push_back(str.c_str());
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
|
188
|
+
}
|
|
189
|
+
break;
|
|
190
|
+
case COMMON_SAMPLER_TYPE_TOP_K:
|
|
191
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
|
192
|
+
break;
|
|
193
|
+
case COMMON_SAMPLER_TYPE_TOP_P:
|
|
194
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
|
195
|
+
break;
|
|
196
|
+
case COMMON_SAMPLER_TYPE_MIN_P:
|
|
197
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
|
198
|
+
break;
|
|
199
|
+
case COMMON_SAMPLER_TYPE_XTC:
|
|
200
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
|
201
|
+
break;
|
|
202
|
+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
|
203
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
|
204
|
+
break;
|
|
205
|
+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
206
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
|
207
|
+
break;
|
|
208
|
+
case COMMON_SAMPLER_TYPE_INFILL:
|
|
209
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
|
210
|
+
break;
|
|
211
|
+
default:
|
|
212
|
+
GGML_ASSERT(false && "unknown sampler type");
|
|
139
213
|
}
|
|
140
214
|
}
|
|
215
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
|
216
|
+
} else if (params.mirostat == 1) {
|
|
217
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
|
218
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
|
219
|
+
} else if (params.mirostat == 2) {
|
|
220
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
|
221
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
|
141
222
|
} else {
|
|
142
|
-
|
|
223
|
+
GGML_ASSERT(false && "unknown mirostat version");
|
|
143
224
|
}
|
|
144
225
|
|
|
145
226
|
return result;
|
|
146
227
|
}
|
|
147
228
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
case llama_sampler_type::TEMPERATURE: return "temperature";
|
|
156
|
-
default : return "";
|
|
229
|
+
void common_sampler_free(struct common_sampler * gsmpl) {
|
|
230
|
+
if (gsmpl) {
|
|
231
|
+
llama_sampler_free(gsmpl->grmr);
|
|
232
|
+
|
|
233
|
+
llama_sampler_free(gsmpl->chain);
|
|
234
|
+
|
|
235
|
+
delete gsmpl;
|
|
157
236
|
}
|
|
158
237
|
}
|
|
159
238
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
{"typical_p", llama_sampler_type::TYPICAL_P},
|
|
165
|
-
{"min_p", llama_sampler_type::MIN_P},
|
|
166
|
-
{"tfs_z", llama_sampler_type::TFS_Z},
|
|
167
|
-
{"temperature", llama_sampler_type::TEMPERATURE}
|
|
168
|
-
};
|
|
239
|
+
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
|
240
|
+
if (accept_grammar) {
|
|
241
|
+
llama_sampler_accept(gsmpl->grmr, token);
|
|
242
|
+
}
|
|
169
243
|
|
|
170
|
-
|
|
171
|
-
// make it ready for both system names and input names
|
|
172
|
-
std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
|
|
173
|
-
{"top-k", llama_sampler_type::TOP_K},
|
|
174
|
-
{"top-p", llama_sampler_type::TOP_P},
|
|
175
|
-
{"nucleus", llama_sampler_type::TOP_P},
|
|
176
|
-
{"typical-p", llama_sampler_type::TYPICAL_P},
|
|
177
|
-
{"typical", llama_sampler_type::TYPICAL_P},
|
|
178
|
-
{"min-p", llama_sampler_type::MIN_P},
|
|
179
|
-
{"tfs-z", llama_sampler_type::TFS_Z},
|
|
180
|
-
{"tfs", llama_sampler_type::TFS_Z},
|
|
181
|
-
{"temp", llama_sampler_type::TEMPERATURE}
|
|
182
|
-
};
|
|
244
|
+
llama_sampler_accept(gsmpl->chain, token);
|
|
183
245
|
|
|
184
|
-
|
|
185
|
-
sampler_types.reserve(names.size());
|
|
186
|
-
for (const auto & name : names)
|
|
187
|
-
{
|
|
188
|
-
auto sampler_item = sampler_canonical_name_map.find(name);
|
|
189
|
-
if (sampler_item != sampler_canonical_name_map.end())
|
|
190
|
-
{
|
|
191
|
-
sampler_types.push_back(sampler_item->second);
|
|
192
|
-
}
|
|
193
|
-
else
|
|
194
|
-
{
|
|
195
|
-
if (allow_alt_names)
|
|
196
|
-
{
|
|
197
|
-
sampler_item = sampler_alt_name_map.find(name);
|
|
198
|
-
if (sampler_item != sampler_alt_name_map.end())
|
|
199
|
-
{
|
|
200
|
-
sampler_types.push_back(sampler_item->second);
|
|
201
|
-
}
|
|
202
|
-
}
|
|
203
|
-
}
|
|
204
|
-
}
|
|
205
|
-
return sampler_types;
|
|
246
|
+
gsmpl->prev.push_back(token);
|
|
206
247
|
}
|
|
207
248
|
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
{'k', llama_sampler_type::TOP_K},
|
|
211
|
-
{'p', llama_sampler_type::TOP_P},
|
|
212
|
-
{'y', llama_sampler_type::TYPICAL_P},
|
|
213
|
-
{'m', llama_sampler_type::MIN_P},
|
|
214
|
-
{'f', llama_sampler_type::TFS_Z},
|
|
215
|
-
{'t', llama_sampler_type::TEMPERATURE}
|
|
216
|
-
};
|
|
249
|
+
void common_sampler_reset(struct common_sampler * gsmpl) {
|
|
250
|
+
llama_sampler_reset(gsmpl->grmr);
|
|
217
251
|
|
|
218
|
-
|
|
219
|
-
sampler_types.reserve(names_string.size());
|
|
220
|
-
for (const auto & c : names_string) {
|
|
221
|
-
const auto sampler_item = sampler_name_map.find(c);
|
|
222
|
-
if (sampler_item != sampler_name_map.end()) {
|
|
223
|
-
sampler_types.push_back(sampler_item->second);
|
|
224
|
-
}
|
|
225
|
-
}
|
|
226
|
-
return sampler_types;
|
|
252
|
+
llama_sampler_reset(gsmpl->chain);
|
|
227
253
|
}
|
|
228
254
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
const int32_t top_k = params.top_k;
|
|
239
|
-
const float top_p = params.top_p;
|
|
240
|
-
const float min_p = params.min_p;
|
|
241
|
-
const float tfs_z = params.tfs_z;
|
|
242
|
-
const float typical_p = params.typical_p;
|
|
243
|
-
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
|
|
244
|
-
|
|
245
|
-
for (auto sampler_type : samplers_sequence) {
|
|
246
|
-
switch (sampler_type) {
|
|
247
|
-
case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
|
|
248
|
-
case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
|
|
249
|
-
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
|
|
250
|
-
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
|
|
251
|
-
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
|
|
252
|
-
case llama_sampler_type::TEMPERATURE:
|
|
253
|
-
if (dynatemp_range > 0) {
|
|
254
|
-
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
|
|
255
|
-
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
|
|
256
|
-
llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
|
|
257
|
-
} else {
|
|
258
|
-
llama_sample_temp(ctx_main, &cur_p, temp);
|
|
259
|
-
}
|
|
260
|
-
break;
|
|
261
|
-
default : break;
|
|
262
|
-
}
|
|
263
|
-
}
|
|
255
|
+
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
|
256
|
+
return new common_sampler {
|
|
257
|
+
/* .params = */ gsmpl->params,
|
|
258
|
+
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
|
259
|
+
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
|
260
|
+
/* .prev = */ gsmpl->prev,
|
|
261
|
+
/* .cur = */ gsmpl->cur,
|
|
262
|
+
/* .cur_p = */ gsmpl->cur_p,
|
|
263
|
+
};
|
|
264
264
|
}
|
|
265
265
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
struct llama_context * ctx_main,
|
|
269
|
-
struct llama_context * ctx_cfg,
|
|
270
|
-
const int idx,
|
|
271
|
-
bool is_resampling) {
|
|
272
|
-
const llama_sampling_params & params = ctx_sampling->params;
|
|
273
|
-
|
|
274
|
-
const float temp = params.temp;
|
|
275
|
-
const int mirostat = params.mirostat;
|
|
276
|
-
const float mirostat_tau = params.mirostat_tau;
|
|
277
|
-
const float mirostat_eta = params.mirostat_eta;
|
|
278
|
-
|
|
279
|
-
std::vector<float> original_logits;
|
|
280
|
-
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
|
|
281
|
-
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
|
282
|
-
GGML_ASSERT(!original_logits.empty());
|
|
283
|
-
}
|
|
284
|
-
llama_token id = 0;
|
|
285
|
-
|
|
286
|
-
if (temp < 0.0) {
|
|
287
|
-
// greedy sampling, with probs
|
|
288
|
-
llama_sample_softmax(ctx_main, &cur_p);
|
|
289
|
-
id = cur_p.data[0].id;
|
|
290
|
-
} else if (temp == 0.0) {
|
|
291
|
-
// greedy sampling, no probs
|
|
292
|
-
id = llama_sample_token_greedy(ctx_main, &cur_p);
|
|
293
|
-
} else {
|
|
294
|
-
if (mirostat == 1) {
|
|
295
|
-
const int mirostat_m = 100;
|
|
296
|
-
llama_sample_temp(ctx_main, &cur_p, temp);
|
|
297
|
-
id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
|
|
298
|
-
} else if (mirostat == 2) {
|
|
299
|
-
llama_sample_temp(ctx_main, &cur_p, temp);
|
|
300
|
-
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
|
301
|
-
} else {
|
|
302
|
-
// temperature sampling
|
|
303
|
-
size_t min_keep = std::max(1, params.min_keep);
|
|
304
|
-
|
|
305
|
-
sampler_queue(ctx_main, params, cur_p, min_keep);
|
|
266
|
+
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
|
267
|
+
// TODO: measure grammar performance
|
|
306
268
|
|
|
307
|
-
|
|
269
|
+
if (gsmpl) {
|
|
270
|
+
llama_perf_sampler_print(gsmpl->chain);
|
|
271
|
+
}
|
|
272
|
+
if (ctx) {
|
|
273
|
+
llama_perf_context_print(ctx);
|
|
274
|
+
}
|
|
275
|
+
}
|
|
308
276
|
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
// LOG("top %d candidates:\n", n_top);
|
|
277
|
+
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
|
278
|
+
gsmpl->set_logits(ctx, idx);
|
|
312
279
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
|
|
317
|
-
// }
|
|
318
|
-
//}
|
|
280
|
+
auto & grmr = gsmpl->grmr;
|
|
281
|
+
auto & chain = gsmpl->chain;
|
|
282
|
+
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
|
319
283
|
|
|
320
|
-
|
|
321
|
-
|
|
284
|
+
if (grammar_first) {
|
|
285
|
+
llama_sampler_apply(grmr, &cur_p);
|
|
322
286
|
}
|
|
323
287
|
|
|
324
|
-
|
|
325
|
-
// Get a pointer to the logits
|
|
326
|
-
float * logits = llama_get_logits_ith(ctx_main, idx);
|
|
288
|
+
llama_sampler_apply(chain, &cur_p);
|
|
327
289
|
|
|
328
|
-
|
|
329
|
-
llama_token_data single_token_data = {id, logits[id], 0.0f};
|
|
330
|
-
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
|
|
290
|
+
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
|
331
291
|
|
|
332
|
-
|
|
333
|
-
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
|
|
292
|
+
const llama_token id = cur_p.data[cur_p.selected].id;
|
|
334
293
|
|
|
335
|
-
|
|
336
|
-
|
|
294
|
+
if (grammar_first) {
|
|
295
|
+
return id;
|
|
296
|
+
}
|
|
337
297
|
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
298
|
+
// check if it the sampled token fits the grammar
|
|
299
|
+
{
|
|
300
|
+
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
|
301
|
+
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
|
|
341
302
|
|
|
342
|
-
|
|
343
|
-
std::copy(original_logits.begin(), original_logits.end(), logits);
|
|
303
|
+
llama_sampler_apply(grmr, &single_token_data_array);
|
|
344
304
|
|
|
345
|
-
|
|
305
|
+
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
|
306
|
+
if (is_valid) {
|
|
307
|
+
return id;
|
|
346
308
|
}
|
|
347
309
|
}
|
|
348
310
|
|
|
349
|
-
|
|
311
|
+
// resampling:
|
|
312
|
+
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
|
313
|
+
gsmpl->set_logits(ctx, idx);
|
|
350
314
|
|
|
351
|
-
|
|
352
|
-
|
|
315
|
+
llama_sampler_apply(grmr, &cur_p);
|
|
316
|
+
llama_sampler_apply(chain, &cur_p);
|
|
353
317
|
|
|
354
|
-
|
|
355
|
-
struct llama_sampling_context * ctx_sampling,
|
|
356
|
-
struct llama_context * ctx_main,
|
|
357
|
-
struct llama_context * ctx_cfg,
|
|
358
|
-
const int idx,
|
|
359
|
-
bool apply_grammar,
|
|
360
|
-
std::vector<float> * original_logits) {
|
|
361
|
-
const llama_sampling_params & params = ctx_sampling->params;
|
|
318
|
+
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
|
|
362
319
|
|
|
363
|
-
|
|
320
|
+
return cur_p.data[cur_p.selected].id;
|
|
321
|
+
}
|
|
364
322
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
const float penalty_present = params.penalty_present;
|
|
323
|
+
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
|
324
|
+
return llama_sampler_get_seed(gsmpl->chain);
|
|
325
|
+
}
|
|
369
326
|
|
|
370
|
-
|
|
327
|
+
// helpers
|
|
371
328
|
|
|
372
|
-
|
|
373
|
-
|
|
329
|
+
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
|
|
330
|
+
return &gsmpl->cur_p;
|
|
331
|
+
}
|
|
374
332
|
|
|
375
|
-
|
|
376
|
-
|
|
333
|
+
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
|
|
334
|
+
return gsmpl->prev.rat(0);
|
|
335
|
+
}
|
|
377
336
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
337
|
+
std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
|
338
|
+
std::string result = "logits ";
|
|
339
|
+
|
|
340
|
+
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
|
341
|
+
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
|
342
|
+
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
|
|
382
343
|
}
|
|
383
344
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
345
|
+
return result;
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
|
|
349
|
+
n = std::min(n, (int) gsmpl->prev.size());
|
|
350
|
+
|
|
351
|
+
if (n <= 0) {
|
|
352
|
+
return "";
|
|
387
353
|
}
|
|
388
354
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
355
|
+
std::string result;
|
|
356
|
+
result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
|
|
357
|
+
|
|
358
|
+
for (int i = n - 1; i >= 0; i--) {
|
|
359
|
+
const llama_token id = gsmpl->prev.rat(i);
|
|
360
|
+
|
|
361
|
+
GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
|
|
362
|
+
|
|
363
|
+
result += common_token_to_piece(ctx_main, id);
|
|
392
364
|
}
|
|
393
365
|
|
|
394
|
-
|
|
366
|
+
return result;
|
|
367
|
+
}
|
|
395
368
|
|
|
396
|
-
|
|
397
|
-
|
|
369
|
+
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|
370
|
+
switch (cnstr) {
|
|
371
|
+
case COMMON_SAMPLER_TYPE_DRY: return 'd';
|
|
372
|
+
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
|
|
373
|
+
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
|
374
|
+
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
|
375
|
+
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
|
376
|
+
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
|
377
|
+
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
|
378
|
+
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
|
379
|
+
default : return '?';
|
|
398
380
|
}
|
|
381
|
+
}
|
|
399
382
|
|
|
400
|
-
|
|
383
|
+
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|
384
|
+
switch (cnstr) {
|
|
385
|
+
case COMMON_SAMPLER_TYPE_DRY: return "dry";
|
|
386
|
+
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
|
|
387
|
+
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
|
388
|
+
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
|
389
|
+
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
|
390
|
+
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
|
391
|
+
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
|
392
|
+
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
|
393
|
+
default : return "";
|
|
394
|
+
}
|
|
395
|
+
}
|
|
401
396
|
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
397
|
+
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
|
398
|
+
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
|
|
399
|
+
{ "dry", COMMON_SAMPLER_TYPE_DRY },
|
|
400
|
+
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
|
|
401
|
+
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
|
|
402
|
+
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
403
|
+
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
404
|
+
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
405
|
+
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
|
406
|
+
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
|
407
|
+
};
|
|
407
408
|
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
409
|
+
// since samplers names are written multiple ways
|
|
410
|
+
// make it ready for both system names and input names
|
|
411
|
+
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
|
|
412
|
+
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
|
|
413
|
+
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
|
|
414
|
+
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
|
|
415
|
+
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
416
|
+
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
417
|
+
{ "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
418
|
+
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
419
|
+
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
420
|
+
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
421
|
+
};
|
|
411
422
|
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
423
|
+
std::vector<common_sampler_type> samplers;
|
|
424
|
+
samplers.reserve(names.size());
|
|
425
|
+
|
|
426
|
+
for (const auto & name : names) {
|
|
427
|
+
auto sampler = sampler_canonical_name_map.find(name);
|
|
428
|
+
if (sampler != sampler_canonical_name_map.end()) {
|
|
429
|
+
samplers.push_back(sampler->second);
|
|
430
|
+
} else {
|
|
431
|
+
if (allow_alt_names) {
|
|
432
|
+
sampler = sampler_alt_name_map.find(name);
|
|
433
|
+
if (sampler != sampler_alt_name_map.end()) {
|
|
434
|
+
samplers.push_back(sampler->second);
|
|
417
435
|
}
|
|
418
436
|
}
|
|
419
437
|
}
|
|
420
438
|
}
|
|
421
439
|
|
|
422
|
-
|
|
423
|
-
if (apply_grammar && ctx_sampling->grammar != NULL) {
|
|
424
|
-
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
|
|
425
|
-
}
|
|
426
|
-
|
|
427
|
-
return cur_p;
|
|
440
|
+
return samplers;
|
|
428
441
|
}
|
|
429
442
|
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
}
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
struct llama_context * ctx_main,
|
|
442
|
-
struct llama_context * ctx_cfg,
|
|
443
|
-
const int idx,
|
|
444
|
-
bool apply_grammar,
|
|
445
|
-
std::vector<float> * original_logits) {
|
|
446
|
-
return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
|
|
447
|
-
}
|
|
443
|
+
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
|
|
444
|
+
std::unordered_map<char, common_sampler_type> sampler_name_map = {
|
|
445
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
|
|
446
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
|
|
447
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
448
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
|
449
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
|
450
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
451
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
|
452
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
|
453
|
+
};
|
|
448
454
|
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
struct llama_context * ctx_main,
|
|
452
|
-
llama_token id,
|
|
453
|
-
bool apply_grammar) {
|
|
454
|
-
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
|
455
|
-
ctx_sampling->prev.push_back(id);
|
|
455
|
+
std::vector<common_sampler_type> samplers;
|
|
456
|
+
samplers.reserve(chars.size());
|
|
456
457
|
|
|
457
|
-
|
|
458
|
-
|
|
458
|
+
for (const auto & c : chars) {
|
|
459
|
+
const auto sampler = sampler_name_map.find(c);
|
|
460
|
+
if (sampler != sampler_name_map.end()) {
|
|
461
|
+
samplers.push_back(sampler->second);
|
|
462
|
+
}
|
|
459
463
|
}
|
|
464
|
+
|
|
465
|
+
return samplers;
|
|
460
466
|
}
|