@fugood/llama.node 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -10
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +6 -4
- package/src/LlamaCompletionWorker.cpp +6 -6
- package/src/LlamaContext.cpp +7 -9
- package/src/common.hpp +2 -1
- package/src/llama.cpp/.github/workflows/build.yml +98 -24
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +43 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +20 -8
- package/src/llama.cpp/common/CMakeLists.txt +12 -10
- package/src/llama.cpp/common/arg.cpp +2006 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +496 -1632
- package/src/llama.cpp/common/common.h +161 -63
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +3 -0
- package/src/llama.cpp/common/sampling.cpp +348 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/common/train.cpp +2 -0
- package/src/llama.cpp/docs/build.md +36 -1
- package/src/llama.cpp/examples/CMakeLists.txt +0 -1
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +39 -55
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
- package/src/llama.cpp/examples/infill/infill.cpp +117 -132
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +685 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
- package/src/llama.cpp/examples/llava/llava.cpp +110 -24
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
- package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
- package/src/llama.cpp/examples/main/main.cpp +210 -262
- package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
- package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
- package/src/llama.cpp/examples/server/server.cpp +1027 -1073
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +107 -105
- package/src/llama.cpp/examples/simple/simple.cpp +35 -41
- package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
- package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
- package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
- package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
- package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
- package/src/llama.cpp/ggml/include/ggml.h +293 -186
- package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
- package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
- package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
- package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
- package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
- package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
- package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
- package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
- package/src/llama.cpp/include/llama.h +241 -264
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
- package/src/llama.cpp/src/llama-sampling.h +20 -47
- package/src/llama.cpp/src/llama-vocab.cpp +343 -120
- package/src/llama.cpp/src/llama-vocab.h +33 -17
- package/src/llama.cpp/src/llama.cpp +4247 -1525
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +3 -0
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
- package/src/llama.cpp/tests/test-barrier.cpp +93 -0
- package/src/llama.cpp/tests/test-grad0.cpp +187 -70
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
- package/src/llama.cpp/tests/test-rope.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +157 -98
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
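The diff reproduced below covers package/src/llama.cpp/common/sampling.cpp, which replaces the old llama_sampling_* context helpers with the new gpt_sampler chain API. As an illustration only (not part of the package contents), a minimal usage sketch of that new API follows, built from the functions visible in the diff (gpt_sampler_init, gpt_sampler_sample, gpt_sampler_accept, gpt_sampler_free); the model/context setup and the llama_decode feedback loop are assumed from the surrounding llama.cpp API and are not shown here.

// Hypothetical sketch, assuming `model` and `ctx` were created with the usual
// llama.cpp loading calls and that the prompt has already been decoded so that
// logits are available at output index -1.
#include "sampling.h"

static void sample_some_tokens(llama_model * model, llama_context * ctx, int n_tokens) {
    gpt_sampler_params sparams;      // default sampling parameters (temp, top_k, top_p, ...)

    struct gpt_sampler * smpl = gpt_sampler_init(model, sparams);

    for (int i = 0; i < n_tokens; ++i) {
        // sample from the logits of the last evaluated position
        const llama_token id = gpt_sampler_sample(smpl, ctx, /* idx */ -1, /* grammar_first */ false);

        // record the token in the sampler history (and in the grammar, if one is active)
        gpt_sampler_accept(smpl, id, /* accept_grammar */ true);

        // ... feed `id` back through llama_decode() before the next iteration ...
    }

    gpt_sampler_free(smpl);
}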
package/src/llama.cpp/common/sampling.cpp
@@ -1,460 +1,458 @@
-#define LLAMA_API_INTERNAL
 #include "sampling.h"
-#include <random>

-
-    struct llama_sampling_context * result = new llama_sampling_context();
+#include "common.h"

-
-
+#include <cmath>
+#include <unordered_map>

-
-
-
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+// TODO: deduplicate with llama-impl.h
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

-
-        if (
-
-            delete result;
-            return nullptr;
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
+        return data[first];
+    }

-
-        if (
-
-            delete result;
-            return nullptr;
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
+        return data[first];
+    }

-
-
-
-                grammar_rules.data(),
-                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
-
+        return data[pos];
     }

-
-
-
-
-
-
-    return result;
-}
-
-void llama_sampling_free(struct llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
     }

-
-
-
-
-
-
-
+    void push_back(const T & value) {
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
     }

-
-
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }

-
-
-
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
         }
-
+        return data[(first + sz - i - 1) % capacity];
     }

-    std::
-
-
-
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }

-void
-
-
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
     }
-    ctx->rng.seed(seed);
-}

-
-
-        llama_grammar_free(dst->grammar);
-        dst->grammar = nullptr;
+    bool empty() const {
+        return sz == 0;
     }

-
-
+    size_t size() const {
+        return sz;
     }

-
-
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+    std::vector<T> data;
+};

-
-
-}
+struct gpt_sampler {
+    gpt_sampler_params params;

-
-
+    struct llama_sampler * grmr;
+    struct llama_sampler * chain;

-
+    ring_buffer<llama_token> prev;

-    std::
+    std::vector<llama_token_data> cur;

-
-        result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
-    }
+    llama_token_data_array cur_p;

-
-
+    void set_logits(struct llama_context * ctx, int idx) {
+        const auto * logits = llama_get_logits_ith(ctx, idx);

-
+        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        cur.resize(n_vocab);
+
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
+
+        cur_p = { cur.data(), cur.size(), -1, false };
+    }
+};
+
+std::string gpt_sampler_params::print() const {
     char result[1024];

     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
-
-
-
+            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+            top_k, tfs_z, top_p, min_p, typ_p, temp,
+            mirostat, mirostat_eta, mirostat_tau);

     return std::string(result);
 }

-
-
-
-
-
-
-
+struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
+    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
+
+    lparams.no_perf = params.no_perf;
+
+    auto * result = new gpt_sampler {
+        /* .params = */ params,
+        /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .chain = */ llama_sampler_chain_init(lparams),
+        /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur = */ {},
+        /* .cur_p = */ {},
+    };
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_logit_bias(
+                llama_n_vocab(model),
+                params.logit_bias.size(),
+                params.logit_bias.data()));
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_penalties(
+                llama_n_vocab (model),
+                llama_token_eos(model),
+                llama_token_nl (model),
+                params.penalty_last_n,
+                params.penalty_repeat,
+                params.penalty_freq,
+                params.penalty_present,
+                params.penalize_nl,
+                params.ignore_eos));
+
+    if (params.temp > 0.0f) {
+        if (params.mirostat == 0) {
+            for (const auto & cnstr : params.samplers) {
+                switch (cnstr) {
+                    case GPT_SAMPLER_TYPE_TOP_K:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+                        break;
+                    case GPT_SAMPLER_TYPE_TOP_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+                        break;
+                    case GPT_SAMPLER_TYPE_MIN_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+                        break;
+                    case GPT_SAMPLER_TYPE_TFS_Z:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
+                        break;
+                    case GPT_SAMPLER_TYPE_TYPICAL_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+                        break;
+                    case GPT_SAMPLER_TYPE_TEMPERATURE:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                        break;
+                    default:
+                        GGML_ASSERT(false && "unknown sampler type");
+                }
             }
+            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+        } else if (params.mirostat == 1) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        } else if (params.mirostat == 2) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        } else {
+            GGML_ASSERT(false && "unknown mirostat version");
         }
     } else {
-
+        if (params.n_probs > 0) {
+            // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
+            // ref: https://github.com/ggerganov/llama.cpp/pull/9605
+            //
+            // the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
+            // it is much faster, since we avoid sorting all tokens and should give a good approximation
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        }
+        llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }

     return result;
 }

-
-
-
-
-
-
-
-        case llama_sampler_type::TEMPERATURE: return "temperature";
-        default : return "";
+void gpt_sampler_free(struct gpt_sampler * gsmpl) {
+    if (gsmpl) {
+        llama_sampler_free(gsmpl->grmr);
+
+        llama_sampler_free(gsmpl->chain);
+
+        delete gsmpl;
     }
 }

-
-
-
-
-        {"typical_p", llama_sampler_type::TYPICAL_P},
-        {"min_p", llama_sampler_type::MIN_P},
-        {"tfs_z", llama_sampler_type::TFS_Z},
-        {"temperature", llama_sampler_type::TEMPERATURE}
-    };
+void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
+    if (accept_grammar) {
+        llama_sampler_accept(gsmpl->grmr, token);
+    }

-
-    // make it ready for both system names and input names
-    std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
-        {"top-k", llama_sampler_type::TOP_K},
-        {"top-p", llama_sampler_type::TOP_P},
-        {"nucleus", llama_sampler_type::TOP_P},
-        {"typical-p", llama_sampler_type::TYPICAL_P},
-        {"typical", llama_sampler_type::TYPICAL_P},
-        {"min-p", llama_sampler_type::MIN_P},
-        {"tfs-z", llama_sampler_type::TFS_Z},
-        {"tfs", llama_sampler_type::TFS_Z},
-        {"temp", llama_sampler_type::TEMPERATURE}
-    };
+    llama_sampler_accept(gsmpl->chain, token);

-
-    sampler_types.reserve(names.size());
-    for (const auto & name : names)
-    {
-        auto sampler_item = sampler_canonical_name_map.find(name);
-        if (sampler_item != sampler_canonical_name_map.end())
-        {
-            sampler_types.push_back(sampler_item->second);
-        }
-        else
-        {
-            if (allow_alt_names)
-            {
-                sampler_item = sampler_alt_name_map.find(name);
-                if (sampler_item != sampler_alt_name_map.end())
-                {
-                    sampler_types.push_back(sampler_item->second);
-                }
-            }
-        }
-    }
-    return sampler_types;
+    gsmpl->prev.push_back(token);
 }

-
-
-        {'k', llama_sampler_type::TOP_K},
-        {'p', llama_sampler_type::TOP_P},
-        {'y', llama_sampler_type::TYPICAL_P},
-        {'m', llama_sampler_type::MIN_P},
-        {'f', llama_sampler_type::TFS_Z},
-        {'t', llama_sampler_type::TEMPERATURE}
-    };
+void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
+    llama_sampler_reset(gsmpl->grmr);

-
-    sampler_types.reserve(names_string.size());
-    for (const auto & c : names_string) {
-        const auto sampler_item = sampler_name_map.find(c);
-        if (sampler_item != sampler_name_map.end()) {
-            sampler_types.push_back(sampler_item->second);
-        }
-    }
-    return sampler_types;
+    llama_sampler_reset(gsmpl->chain);
 }

-
-
-
-
-
-
-
-
-
-    const int32_t top_k = params.top_k;
-    const float top_p = params.top_p;
-    const float min_p = params.min_p;
-    const float tfs_z = params.tfs_z;
-    const float typical_p = params.typical_p;
-    const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
-
-    for (auto sampler_type : samplers_sequence) {
-        switch (sampler_type) {
-            case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
-            case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
-            case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
-            case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
-            case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
-            case llama_sampler_type::TEMPERATURE:
-                if (dynatemp_range > 0) {
-                    float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
-                    float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
-                    llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
-                } else {
-                    llama_sample_temp(ctx_main, &cur_p, temp);
-                }
-                break;
-            default : break;
-        }
-    }
+struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
+    return new gpt_sampler {
+        /* .params = */ gsmpl->params,
+        /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain = */ llama_sampler_clone(gsmpl->chain),
+        /* .prev = */ gsmpl->prev,
+        /* .cur = */ gsmpl->cur,
+        /* .cur_p = */ gsmpl->cur_p,
+    };
 }

-
-
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool is_resampling) {
-    const llama_sampling_params & params = ctx_sampling->params;
-
-    const float temp = params.temp;
-    const int mirostat = params.mirostat;
-    const float mirostat_tau = params.mirostat_tau;
-    const float mirostat_eta = params.mirostat_eta;
-
-    std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        GGML_ASSERT(!original_logits.empty());
-    }
-    llama_token id = 0;
-
-    if (temp < 0.0) {
-        // greedy sampling, with probs
-        llama_sample_softmax(ctx_main, &cur_p);
-        id = cur_p.data[0].id;
-    } else if (temp == 0.0) {
-        // greedy sampling, no probs
-        id = llama_sample_token_greedy(ctx_main, &cur_p);
-    } else {
-        if (mirostat == 1) {
-            const int mirostat_m = 100;
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
-        } else if (mirostat == 2) {
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
-        } else {
-            // temperature sampling
-            size_t min_keep = std::max(1, params.min_keep);
-
-            sampler_queue(ctx_main, params, cur_p, min_keep);
+void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
+    // TODO: measure grammar performance

-
+    if (gsmpl) {
+        llama_perf_sampler_print(gsmpl->chain);
+    }
+    if (ctx) {
+        llama_perf_context_print(ctx);
+    }
+}

-
-
-        // LOG("top %d candidates:\n", n_top);
+llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+    gsmpl->set_logits(ctx, idx);

-
-
-
-            // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
-        // }
-        //}
+    auto & grmr = gsmpl->grmr;
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

-
-
+    if (grammar_first) {
+        llama_sampler_apply(grmr, &cur_p);
     }

-
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
+    llama_sampler_apply(chain, &cur_p);

-
-        llama_token_data single_token_data = {id, logits[id], 0.0f};
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

-
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
+    const llama_token id = cur_p.data[cur_p.selected].id;

-
-
+    if (grammar_first) {
+        return id;
+    }

-
-
-
+    // check if it the sampled token fits the grammar
+    {
+        llama_token_data single_token_data = { id, 1.0f, 0.0f };
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

-
-            std::copy(original_logits.begin(), original_logits.end(), logits);
+        llama_sampler_apply(grmr, &single_token_data_array);

-
+        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+        if (is_valid) {
+            return id;
         }
     }

-
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+    gsmpl->set_logits(ctx, idx);

-
-
+    llama_sampler_apply(grmr, &cur_p);
+    llama_sampler_apply(chain, &cur_p);

-
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool apply_grammar,
-        std::vector<float> * original_logits) {
-    const llama_sampling_params & params = ctx_sampling->params;
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

-
+    return cur_p.data[cur_p.selected].id;
+}

-
-
-
-
+uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
+    return llama_sampler_get_seed(gsmpl->chain);
+}
+
+// helpers

-
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
+    return &gsmpl->cur_p;
+}

-
-
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
+    return gsmpl->prev.rat(0);
+}

-
-
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
+    std::string result = "logits ";

-
-
-
-        *original_logits = {logits, logits + n_vocab};
+    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
     }

-
-
-
+    return result;
+}
+
+std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
+    n = std::min(n, (int) gsmpl->prev.size());
+
+    if (n <= 0) {
+        return "";
     }

-
-
-
+    std::string result;
+    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
+
+    for (int i = n - 1; i >= 0; i--) {
+        const llama_token id = gsmpl->prev.rat(i);
+
+        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
+
+        result += llama_token_to_piece(ctx_main, id);
     }

-
+    return result;
+}
+
+char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
+    switch (cnstr) {
+        case GPT_SAMPLER_TYPE_TOP_K: return 'k';
+        case GPT_SAMPLER_TYPE_TFS_Z: return 'f';
+        case GPT_SAMPLER_TYPE_TYPICAL_P: return 'y';
+        case GPT_SAMPLER_TYPE_TOP_P: return 'p';
+        case GPT_SAMPLER_TYPE_MIN_P: return 'm';
+        case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
+        default : return '?';
+    }
+}

-
-
+std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
+    switch (cnstr) {
+        case GPT_SAMPLER_TYPE_TOP_K: return "top_k";
+        case GPT_SAMPLER_TYPE_TFS_Z: return "tfs_z";
+        case GPT_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
+        case GPT_SAMPLER_TYPE_TOP_P: return "top_p";
+        case GPT_SAMPLER_TYPE_MIN_P: return "min_p";
+        case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        default : return "";
     }
+}

-
+std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+    std::unordered_map<std::string, gpt_sampler_type> sampler_canonical_name_map {
+        { "top_k", GPT_SAMPLER_TYPE_TOP_K },
+        { "top_p", GPT_SAMPLER_TYPE_TOP_P },
+        { "typ_p", GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "min_p", GPT_SAMPLER_TYPE_MIN_P },
+        { "tfs_z", GPT_SAMPLER_TYPE_TFS_Z },
+        { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE },
+    };

-    //
-
-
-
-
+    // since samplers names are written multiple ways
+    // make it ready for both system names and input names
+    std::unordered_map<std::string, gpt_sampler_type> sampler_alt_name_map {
+        { "top-k", GPT_SAMPLER_TYPE_TOP_K },
+        { "top-p", GPT_SAMPLER_TYPE_TOP_P },
+        { "nucleus", GPT_SAMPLER_TYPE_TOP_P },
+        { "typical-p", GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "typical", GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "typ-p", GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "typ", GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "min-p", GPT_SAMPLER_TYPE_MIN_P },
+        { "tfs-z", GPT_SAMPLER_TYPE_TFS_Z },
+        { "tfs", GPT_SAMPLER_TYPE_TFS_Z },
+        { "temp", GPT_SAMPLER_TYPE_TEMPERATURE },
+    };

-
-
-            penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+    std::vector<gpt_sampler_type> samplers;
+    samplers.reserve(names.size());

-
-
-
-
-
+    for (const auto & name : names) {
+        auto sampler = sampler_canonical_name_map.find(name);
+        if (sampler != sampler_canonical_name_map.end()) {
+            samplers.push_back(sampler->second);
+        } else {
+            if (allow_alt_names) {
+                sampler = sampler_alt_name_map.find(name);
+                if (sampler != sampler_alt_name_map.end()) {
+                    samplers.push_back(sampler->second);
                 }
             }
         }
     }

-
-    if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
-    }
-
-    return cur_p;
+    return samplers;
 }

-
-
-
-
-
-
-
-}
-
-llama_token_data_array llama_sampling_prepare(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool apply_grammar,
-        std::vector<float> * original_logits) {
-    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
-}
+std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars) {
+    std::unordered_map<char, gpt_sampler_type> sampler_name_map = {
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_K), GPT_SAMPLER_TYPE_TOP_K },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TFS_Z), GPT_SAMPLER_TYPE_TFS_Z },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P), GPT_SAMPLER_TYPE_TYPICAL_P },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P), GPT_SAMPLER_TYPE_TOP_P },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P), GPT_SAMPLER_TYPE_MIN_P },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
+    };

-
-
-        struct llama_context * ctx_main,
-        llama_token id,
-        bool apply_grammar) {
-    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-    ctx_sampling->prev.push_back(id);
+    std::vector<gpt_sampler_type> samplers;
+    samplers.reserve(chars.size());

-
-
+    for (const auto & c : chars) {
+        const auto sampler = sampler_name_map.find(c);
+        if (sampler != sampler_name_map.end()) {
+            samplers.push_back(sampler->second);
+        }
     }
+
+    return samplers;
 }