@fugood/llama.node 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -10
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +6 -4
- package/src/LlamaCompletionWorker.cpp +6 -6
- package/src/LlamaContext.cpp +7 -9
- package/src/common.hpp +2 -1
- package/src/llama.cpp/.github/workflows/build.yml +98 -24
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +43 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +20 -8
- package/src/llama.cpp/common/CMakeLists.txt +12 -10
- package/src/llama.cpp/common/arg.cpp +2006 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +496 -1632
- package/src/llama.cpp/common/common.h +161 -63
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +3 -0
- package/src/llama.cpp/common/sampling.cpp +348 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/common/train.cpp +2 -0
- package/src/llama.cpp/docs/build.md +36 -1
- package/src/llama.cpp/examples/CMakeLists.txt +0 -1
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +39 -55
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
- package/src/llama.cpp/examples/infill/infill.cpp +117 -132
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +685 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
- package/src/llama.cpp/examples/llava/llava.cpp +110 -24
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
- package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
- package/src/llama.cpp/examples/main/main.cpp +210 -262
- package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
- package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
- package/src/llama.cpp/examples/server/server.cpp +1027 -1073
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +107 -105
- package/src/llama.cpp/examples/simple/simple.cpp +35 -41
- package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
- package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
- package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
- package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
- package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
- package/src/llama.cpp/ggml/include/ggml.h +293 -186
- package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
- package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
- package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
- package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
- package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
- package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
- package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
- package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
- package/src/llama.cpp/include/llama.h +241 -264
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
- package/src/llama.cpp/src/llama-sampling.h +20 -47
- package/src/llama.cpp/src/llama-vocab.cpp +343 -120
- package/src/llama.cpp/src/llama-vocab.h +33 -17
- package/src/llama.cpp/src/llama.cpp +4247 -1525
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +3 -0
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
- package/src/llama.cpp/tests/test-barrier.cpp +93 -0
- package/src/llama.cpp/tests/test-grad0.cpp +187 -70
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
- package/src/llama.cpp/tests/test-rope.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +157 -98
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
|
@@ -1,12 +1,53 @@
|
|
|
1
1
|
#include "llama-sampling.h"
|
|
2
2
|
|
|
3
|
+
#include "llama-vocab.h"
|
|
4
|
+
#include "llama-grammar.h"
|
|
5
|
+
|
|
3
6
|
#include <algorithm>
|
|
7
|
+
#include <cassert>
|
|
8
|
+
#include <cfloat>
|
|
9
|
+
#include <chrono>
|
|
10
|
+
#include <cmath>
|
|
11
|
+
#include <cstdlib>
|
|
4
12
|
#include <cstring>
|
|
5
13
|
#include <ctime>
|
|
6
|
-
#include <cfloat>
|
|
7
14
|
#include <numeric>
|
|
15
|
+
#include <random>
|
|
8
16
|
#include <unordered_map>
|
|
9
17
|
|
|
18
|
+
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
|
|
19
|
+
// iterator for the probabilities
|
|
20
|
+
#ifdef __GNUC__
|
|
21
|
+
#pragma GCC diagnostic push
|
|
22
|
+
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
|
|
23
|
+
#endif
|
|
24
|
+
|
|
25
|
+
struct probs_iterator {
|
|
26
|
+
typedef std::input_iterator_tag iterator_category;
|
|
27
|
+
typedef float value_type;
|
|
28
|
+
typedef float * pointer;
|
|
29
|
+
typedef float & reference;
|
|
30
|
+
typedef ptrdiff_t difference_type;
|
|
31
|
+
|
|
32
|
+
const llama_token_data * data;
|
|
33
|
+
|
|
34
|
+
bool operator==(const probs_iterator & other) const { return data == other.data; }
|
|
35
|
+
bool operator!=(const probs_iterator & other) const { return data != other.data; }
|
|
36
|
+
const float & operator*() const { return data->p; }
|
|
37
|
+
probs_iterator & operator++() { ++data; return *this; }
|
|
38
|
+
probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
|
|
39
|
+
};
|
|
40
|
+
|
|
41
|
+
#ifdef __GNUC__
|
|
42
|
+
#pragma GCC diagnostic pop
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
|
|
46
|
+
|
|
47
|
+
return dist(rng);
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
/*
|
|
10
51
|
static void llama_log_softmax(float * array, size_t size) {
|
|
11
52
|
float max_l = *std::max_element(array, array + size);
|
|
12
53
|
float sum = 0.f;
|
|
@@ -20,79 +61,65 @@ static void llama_log_softmax(float * array, size_t size) {
|
|
|
20
61
|
array[i] = logf(array[i] / sum);
|
|
21
62
|
}
|
|
22
63
|
}
|
|
64
|
+
*/
|
|
23
65
|
|
|
24
|
-
void
|
|
25
|
-
|
|
26
|
-
seed = time(NULL);
|
|
27
|
-
}
|
|
28
|
-
|
|
29
|
-
smpl->rng.seed(seed);
|
|
30
|
-
}
|
|
31
|
-
|
|
32
|
-
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
|
33
|
-
GGML_ASSERT(candidates->size > 0);
|
|
34
|
-
|
|
35
|
-
const int64_t t_start_sample_us = ggml_time_us();
|
|
66
|
+
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
|
67
|
+
GGML_ASSERT(cur_p->size > 0);
|
|
36
68
|
|
|
37
69
|
// Sort the logits in descending order
|
|
38
|
-
if (!
|
|
39
|
-
std::sort(
|
|
70
|
+
if (!cur_p->sorted) {
|
|
71
|
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
|
40
72
|
return a.logit > b.logit;
|
|
41
73
|
});
|
|
42
|
-
|
|
74
|
+
cur_p->sorted = true;
|
|
43
75
|
}
|
|
44
76
|
|
|
45
|
-
float max_l =
|
|
77
|
+
float max_l = cur_p->data[0].logit;
|
|
46
78
|
float cum_sum = 0.0f;
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
79
|
+
|
|
80
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
81
|
+
float p = expf(cur_p->data[i].logit - max_l);
|
|
82
|
+
cur_p->data[i].p = p;
|
|
50
83
|
cum_sum += p;
|
|
51
84
|
}
|
|
52
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
|
53
|
-
candidates->data[i].p /= cum_sum;
|
|
54
|
-
}
|
|
55
85
|
|
|
56
|
-
|
|
57
|
-
|
|
86
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
87
|
+
cur_p->data[i].p /= cum_sum;
|
|
58
88
|
}
|
|
59
89
|
}
|
|
60
90
|
|
|
61
|
-
void
|
|
91
|
+
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
|
|
62
92
|
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
|
63
|
-
// if (k >= (int32_t)
|
|
93
|
+
// if (k >= (int32_t)cur_p->size) {
|
|
64
94
|
// return;
|
|
65
95
|
// }
|
|
66
96
|
|
|
67
|
-
const int64_t t_start_sample_us = ggml_time_us();
|
|
68
|
-
|
|
69
97
|
if (k <= 0) {
|
|
70
|
-
k =
|
|
98
|
+
k = cur_p->size;
|
|
71
99
|
}
|
|
72
100
|
|
|
73
|
-
k = std::
|
|
74
|
-
k = std::min(k, (int) candidates->size);
|
|
101
|
+
k = std::min(k, (int) cur_p->size);
|
|
75
102
|
|
|
76
103
|
// Sort scores in descending order
|
|
77
|
-
if (!
|
|
104
|
+
if (!cur_p->sorted) {
|
|
78
105
|
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
|
79
106
|
return a.logit > b.logit;
|
|
80
107
|
};
|
|
81
108
|
if (k <= 128) {
|
|
82
|
-
std::partial_sort(
|
|
109
|
+
std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
|
|
83
110
|
} else {
|
|
84
111
|
constexpr int nbuckets = 128;
|
|
85
112
|
constexpr float bucket_low = -10.0f;
|
|
86
113
|
constexpr float bucket_high = 10.0f;
|
|
87
114
|
constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
|
|
88
|
-
constexpr float
|
|
115
|
+
constexpr float bucket_inter = -bucket_low * bucket_scale;
|
|
89
116
|
|
|
90
|
-
std::vector<int> bucket_idx(
|
|
117
|
+
std::vector<int> bucket_idx(cur_p->size);
|
|
91
118
|
std::vector<int> histo(nbuckets, 0);
|
|
92
119
|
|
|
93
|
-
for (int i = 0; i < (int)
|
|
94
|
-
const float val =
|
|
95
|
-
int ib = int(bucket_scale * val +
|
|
120
|
+
for (int i = 0; i < (int)cur_p->size; ++i) {
|
|
121
|
+
const float val = cur_p->data[i].logit;
|
|
122
|
+
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
|
96
123
|
ib = std::max(0, std::min(nbuckets-1, ib));
|
|
97
124
|
bucket_idx[i] = ib;
|
|
98
125
|
++histo[ib];
|
|
@@ -101,20 +128,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|
|
101
128
|
int ib = nbuckets - 1;
|
|
102
129
|
for ( ; ib >= 0; --ib) {
|
|
103
130
|
nhave += histo[ib];
|
|
104
|
-
if (nhave >= k)
|
|
131
|
+
if (nhave >= k) {
|
|
132
|
+
break;
|
|
133
|
+
}
|
|
105
134
|
}
|
|
106
135
|
std::vector<llama_token_data> tmp_tokens(nhave);
|
|
107
|
-
auto ptr = tmp_tokens.data();
|
|
136
|
+
auto * ptr = tmp_tokens.data();
|
|
108
137
|
std::vector<llama_token_data*> bucket_ptrs;
|
|
109
138
|
bucket_ptrs.reserve(nbuckets - ib);
|
|
110
139
|
for (int j = nbuckets - 1; j >= ib; --j) {
|
|
111
140
|
bucket_ptrs.push_back(ptr);
|
|
112
141
|
ptr += histo[j];
|
|
113
142
|
}
|
|
114
|
-
for (int i = 0; i < (int)
|
|
143
|
+
for (int i = 0; i < (int)cur_p->size; ++i) {
|
|
115
144
|
int j = bucket_idx[i];
|
|
116
145
|
if (j >= ib) {
|
|
117
|
-
*bucket_ptrs[nbuckets-1-j]++ =
|
|
146
|
+
*bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
|
|
118
147
|
}
|
|
119
148
|
}
|
|
120
149
|
|
|
@@ -127,125 +156,582 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|
|
127
156
|
}
|
|
128
157
|
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
|
|
129
158
|
|
|
130
|
-
std::memcpy(
|
|
159
|
+
std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
|
|
131
160
|
|
|
132
161
|
}
|
|
133
|
-
|
|
162
|
+
cur_p->sorted = true;
|
|
134
163
|
}
|
|
135
|
-
|
|
164
|
+
cur_p->size = k;
|
|
165
|
+
}
|
|
136
166
|
|
|
137
|
-
|
|
138
|
-
|
|
167
|
+
static uint32_t get_rng_seed(uint32_t seed) {
|
|
168
|
+
if (seed == LLAMA_DEFAULT_SEED) {
|
|
169
|
+
// use system clock if std::random_device is not a true RNG
|
|
170
|
+
static bool is_rd_prng = std::random_device().entropy() == 0;
|
|
171
|
+
if (is_rd_prng) {
|
|
172
|
+
return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
|
|
173
|
+
}
|
|
174
|
+
std::random_device rd;
|
|
175
|
+
return rd();
|
|
139
176
|
}
|
|
177
|
+
return seed;
|
|
140
178
|
}
|
|
141
179
|
|
|
142
|
-
|
|
143
|
-
|
|
180
|
+
// llama_sampler API
|
|
181
|
+
|
|
182
|
+
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
|
183
|
+
if (!smpl->iface) {
|
|
184
|
+
return "(null)";
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
return smpl->iface->name(smpl);
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
|
191
|
+
if (smpl->iface->accept) {
|
|
192
|
+
smpl->iface->accept(smpl, token);
|
|
193
|
+
}
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
|
197
|
+
GGML_ASSERT(smpl->iface->apply);
|
|
198
|
+
smpl->iface->apply(smpl, cur_p);
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
void llama_sampler_reset(struct llama_sampler * smpl) {
|
|
202
|
+
if (smpl->iface->reset) {
|
|
203
|
+
smpl->iface->reset(smpl);
|
|
204
|
+
}
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
|
208
|
+
if (smpl->iface->clone) {
|
|
209
|
+
return smpl->iface->clone(smpl);
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
if (smpl->ctx == nullptr) {
|
|
213
|
+
return new llama_sampler {
|
|
214
|
+
/* .iface = */ smpl->iface,
|
|
215
|
+
/* .ctx = */ nullptr,
|
|
216
|
+
};
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
GGML_ABORT("the sampler does not support cloning");
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
void llama_sampler_free(struct llama_sampler * smpl) {
|
|
223
|
+
if (smpl == nullptr) {
|
|
144
224
|
return;
|
|
145
225
|
}
|
|
146
226
|
|
|
147
|
-
|
|
227
|
+
if (smpl->iface->free) {
|
|
228
|
+
smpl->iface->free(smpl);
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
delete smpl;
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
|
235
|
+
const auto * logits = llama_get_logits_ith(ctx, idx);
|
|
236
|
+
|
|
237
|
+
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
238
|
+
|
|
239
|
+
// TODO: do not allocate each time
|
|
240
|
+
std::vector<llama_token_data> cur;
|
|
241
|
+
cur.reserve(n_vocab);
|
|
242
|
+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
243
|
+
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
llama_token_data_array cur_p = {
|
|
247
|
+
/* .data = */ cur.data(),
|
|
248
|
+
/* .size = */ cur.size(),
|
|
249
|
+
/* .selected = */ -1,
|
|
250
|
+
/* .sorted = */ false,
|
|
251
|
+
};
|
|
252
|
+
|
|
253
|
+
llama_sampler_apply(smpl, &cur_p);
|
|
254
|
+
|
|
255
|
+
GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
|
|
256
|
+
|
|
257
|
+
auto token = cur_p.data[cur_p.selected].id;
|
|
258
|
+
|
|
259
|
+
llama_sampler_accept(smpl, token);
|
|
260
|
+
|
|
261
|
+
return token;
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
// sampler chain
|
|
265
|
+
|
|
266
|
+
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
|
|
267
|
+
return "chain";
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
|
|
271
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
272
|
+
|
|
273
|
+
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
|
274
|
+
|
|
275
|
+
for (auto * smpl : chain->samplers) {
|
|
276
|
+
llama_sampler_accept(smpl, token);
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
chain->n_sample++;
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
283
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
284
|
+
|
|
285
|
+
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
|
286
|
+
|
|
287
|
+
for (auto * smpl : chain->samplers) {
|
|
288
|
+
llama_sampler_apply(smpl, cur_p);
|
|
289
|
+
}
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
|
|
293
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
294
|
+
|
|
295
|
+
for (auto * smpl : chain->samplers) {
|
|
296
|
+
llama_sampler_reset(smpl);
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
chain->t_sample_us = 0;
|
|
300
|
+
chain->n_sample = 0;
|
|
301
|
+
}
|
|
302
|
+
|
|
303
|
+
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
|
|
304
|
+
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
|
|
305
|
+
|
|
306
|
+
auto * result = llama_sampler_chain_init(chain_src->params);
|
|
307
|
+
|
|
308
|
+
for (auto * smpl : chain_src->samplers) {
|
|
309
|
+
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
return result;
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
|
316
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
317
|
+
|
|
318
|
+
for (auto * smpl : chain->samplers) {
|
|
319
|
+
llama_sampler_free(smpl);
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
delete chain;
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
static struct llama_sampler_i llama_sampler_chain_i = {
|
|
326
|
+
/* .name = */ llama_sampler_chain_name,
|
|
327
|
+
/* .accept = */ llama_sampler_chain_accept,
|
|
328
|
+
/* .apply = */ llama_sampler_chain_apply,
|
|
329
|
+
/* .reset = */ llama_sampler_chain_reset,
|
|
330
|
+
/* .clone = */ llama_sampler_chain_clone,
|
|
331
|
+
/* .free = */ llama_sampler_chain_free,
|
|
332
|
+
};
|
|
333
|
+
|
|
334
|
+
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
|
335
|
+
return new llama_sampler {
|
|
336
|
+
/* .iface = */ &llama_sampler_chain_i,
|
|
337
|
+
/* .ctx = */ new llama_sampler_chain {
|
|
338
|
+
/* .params = */ params,
|
|
339
|
+
/* .samplers = */ {},
|
|
340
|
+
/* .t_sample_us = */ 0,
|
|
341
|
+
/* .n_sample = */ 0,
|
|
342
|
+
},
|
|
343
|
+
};
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
|
347
|
+
auto * p = (llama_sampler_chain *) chain->ctx;
|
|
348
|
+
p->samplers.push_back(smpl);
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
|
352
|
+
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
|
353
|
+
|
|
354
|
+
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
|
355
|
+
return nullptr;
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
return p->samplers[i];
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
|
|
362
|
+
auto * p = (llama_sampler_chain *) chain->ctx;
|
|
363
|
+
|
|
364
|
+
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
|
365
|
+
return nullptr;
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
auto * result = p->samplers[i];
|
|
369
|
+
p->samplers.erase(p->samplers.begin() + i);
|
|
370
|
+
|
|
371
|
+
return result;
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
|
375
|
+
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
|
376
|
+
|
|
377
|
+
return p->samplers.size();
|
|
378
|
+
}
|
|
379
|
+
|
|
380
|
+
//
|
|
381
|
+
// samplers
|
|
382
|
+
//
|
|
383
|
+
|
|
384
|
+
// greedy
|
|
385
|
+
|
|
386
|
+
static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
|
|
387
|
+
return "greedy";
|
|
388
|
+
}
|
|
389
|
+
|
|
390
|
+
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
|
391
|
+
cur_p->selected = 0;
|
|
392
|
+
for (size_t i = 1; i < cur_p->size; ++i) {
|
|
393
|
+
if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
|
|
394
|
+
cur_p->selected = i;
|
|
395
|
+
}
|
|
396
|
+
}
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
static struct llama_sampler_i llama_sampler_greedy_i = {
|
|
400
|
+
/* .name = */ llama_sampler_greedy_name,
|
|
401
|
+
/* .accept = */ nullptr,
|
|
402
|
+
/* .apply = */ llama_sampler_greedy_apply,
|
|
403
|
+
/* .reset = */ nullptr,
|
|
404
|
+
/* .clone = */ nullptr,
|
|
405
|
+
/* .free = */ nullptr,
|
|
406
|
+
};
|
|
407
|
+
|
|
408
|
+
struct llama_sampler * llama_sampler_init_greedy() {
|
|
409
|
+
return new llama_sampler {
|
|
410
|
+
/* .iface = */ &llama_sampler_greedy_i,
|
|
411
|
+
/* .ctx = */ nullptr,
|
|
412
|
+
};
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
// dist
|
|
416
|
+
|
|
417
|
+
struct llama_sampler_dist {
|
|
418
|
+
const uint32_t seed;
|
|
419
|
+
uint32_t seed_cur;
|
|
420
|
+
|
|
421
|
+
std::mt19937 rng;
|
|
422
|
+
};
|
|
423
|
+
|
|
424
|
+
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
|
|
425
|
+
return "dist";
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
429
|
+
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
430
|
+
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
|
431
|
+
}
|
|
432
|
+
|
|
433
|
+
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
|
434
|
+
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
|
435
|
+
auto * result = llama_sampler_init_dist(ctx->seed);
|
|
436
|
+
|
|
437
|
+
// copy the state
|
|
438
|
+
{
|
|
439
|
+
auto * result_ctx = (llama_sampler_dist *) result->ctx;
|
|
440
|
+
|
|
441
|
+
result_ctx->rng = ctx->rng;
|
|
442
|
+
}
|
|
443
|
+
|
|
444
|
+
return result;
|
|
445
|
+
}
|
|
446
|
+
|
|
447
|
+
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
|
448
|
+
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
449
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
450
|
+
ctx->rng.seed(ctx->seed_cur);
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
|
454
|
+
delete (llama_sampler_dist *) smpl->ctx;
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
static struct llama_sampler_i llama_sampler_dist_i = {
|
|
458
|
+
/* .name = */ llama_sampler_dist_name,
|
|
459
|
+
/* .accept = */ nullptr,
|
|
460
|
+
/* .apply = */ llama_sampler_dist_apply,
|
|
461
|
+
/* .reset = */ llama_sampler_dist_reset,
|
|
462
|
+
/* .clone = */ llama_sampler_dist_clone,
|
|
463
|
+
/* .free = */ llama_sampler_dist_free,
|
|
464
|
+
};
|
|
465
|
+
|
|
466
|
+
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
|
467
|
+
auto seed_cur = get_rng_seed(seed);
|
|
468
|
+
return new llama_sampler {
|
|
469
|
+
/* .iface = */ &llama_sampler_dist_i,
|
|
470
|
+
/* .ctx = */ new llama_sampler_dist {
|
|
471
|
+
/* .seed = */ seed,
|
|
472
|
+
/* .seed_cur = */ seed_cur,
|
|
473
|
+
/* .rng = */ std::mt19937(seed_cur),
|
|
474
|
+
},
|
|
475
|
+
};
|
|
476
|
+
}
|
|
477
|
+
|
|
478
|
+
// softmax
|
|
479
|
+
|
|
480
|
+
static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
|
|
481
|
+
return "softmax";
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
|
485
|
+
llama_sampler_softmax_impl(cur_p);
|
|
486
|
+
}
|
|
487
|
+
|
|
488
|
+
static struct llama_sampler_i llama_sampler_softmax_i = {
|
|
489
|
+
/* .name = */ llama_sampler_softmax_name,
|
|
490
|
+
/* .accept = */ nullptr,
|
|
491
|
+
/* .apply = */ llama_sampler_softmax_apply,
|
|
492
|
+
/* .reset = */ nullptr,
|
|
493
|
+
/* .clone = */ nullptr,
|
|
494
|
+
/* .free = */ nullptr,
|
|
495
|
+
};
|
|
496
|
+
|
|
497
|
+
struct llama_sampler * llama_sampler_init_softmax() {
|
|
498
|
+
return new llama_sampler {
|
|
499
|
+
/* .iface = */ &llama_sampler_softmax_i,
|
|
500
|
+
/* .ctx = */ nullptr,
|
|
501
|
+
};
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
// top-k
|
|
505
|
+
|
|
506
|
+
struct llama_sampler_top_k {
|
|
507
|
+
const int32_t k;
|
|
508
|
+
};
|
|
509
|
+
|
|
510
|
+
static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
|
|
511
|
+
return "top-k";
|
|
512
|
+
}
|
|
513
|
+
|
|
514
|
+
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
515
|
+
const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
|
|
516
|
+
llama_sampler_top_k_impl(cur_p, ctx->k);
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
|
|
520
|
+
const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
|
|
521
|
+
return llama_sampler_init_top_k(ctx->k);
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
|
525
|
+
delete (llama_sampler_top_k *) smpl->ctx;
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
static struct llama_sampler_i llama_sampler_top_k_i = {
|
|
529
|
+
/* .name = */ llama_sampler_top_k_name,
|
|
530
|
+
/* .accept = */ nullptr,
|
|
531
|
+
/* .apply = */ llama_sampler_top_k_apply,
|
|
532
|
+
/* .reset = */ nullptr,
|
|
533
|
+
/* .clone = */ llama_sampler_top_k_clone,
|
|
534
|
+
/* .free = */ llama_sampler_top_k_free,
|
|
535
|
+
};
|
|
536
|
+
|
|
537
|
+
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
|
538
|
+
return new llama_sampler {
|
|
539
|
+
/* .iface = */ &llama_sampler_top_k_i,
|
|
540
|
+
/* .ctx = */ new llama_sampler_top_k {
|
|
541
|
+
/* .k = */ k,
|
|
542
|
+
},
|
|
543
|
+
};
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
// top-p
|
|
547
|
+
|
|
548
|
+
struct llama_sampler_top_p {
|
|
549
|
+
const float p;
|
|
550
|
+
const size_t min_keep;
|
|
551
|
+
};
|
|
552
|
+
|
|
553
|
+
static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
|
|
554
|
+
return "top-p";
|
|
555
|
+
}
|
|
148
556
|
|
|
149
|
-
|
|
557
|
+
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
558
|
+
const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
|
|
559
|
+
|
|
560
|
+
if (ctx->p >= 1.0f) {
|
|
561
|
+
return;
|
|
562
|
+
}
|
|
563
|
+
|
|
564
|
+
llama_sampler_softmax_impl(cur_p);
|
|
150
565
|
|
|
151
566
|
// Compute the cumulative probabilities
|
|
152
567
|
float cum_sum = 0.0f;
|
|
153
|
-
size_t last_idx =
|
|
568
|
+
size_t last_idx = cur_p->size;
|
|
154
569
|
|
|
155
|
-
for (size_t i = 0; i <
|
|
156
|
-
cum_sum +=
|
|
570
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
571
|
+
cum_sum += cur_p->data[i].p;
|
|
157
572
|
|
|
158
573
|
// Check if the running sum is at least p or if we have kept at least min_keep tokens
|
|
159
574
|
// we set the last index to i+1 to indicate that the current iterate should be included in the set
|
|
160
|
-
if (cum_sum >= p && i + 1 >= min_keep) {
|
|
575
|
+
if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
|
|
161
576
|
last_idx = i + 1;
|
|
162
577
|
break;
|
|
163
578
|
}
|
|
164
579
|
}
|
|
165
580
|
|
|
166
581
|
// Resize the output vector to keep only the top-p tokens
|
|
167
|
-
|
|
582
|
+
cur_p->size = last_idx;
|
|
583
|
+
}
|
|
168
584
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
585
|
+
static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
|
|
586
|
+
const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
|
|
587
|
+
return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
|
591
|
+
delete (llama_sampler_top_p *) smpl->ctx;
|
|
592
|
+
}
|
|
593
|
+
|
|
594
|
+
static struct llama_sampler_i llama_sampler_top_p_i = {
|
|
595
|
+
/* .name = */ llama_sampler_top_p_name,
|
|
596
|
+
/* .accept = */ nullptr,
|
|
597
|
+
/* .apply = */ llama_sampler_top_p_apply,
|
|
598
|
+
/* .reset = */ nullptr,
|
|
599
|
+
/* .clone = */ llama_sampler_top_p_clone,
|
|
600
|
+
/* .free = */ llama_sampler_top_p_free,
|
|
601
|
+
};
|
|
602
|
+
|
|
603
|
+
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
|
604
|
+
return new llama_sampler {
|
|
605
|
+
/* .iface = */ &llama_sampler_top_p_i,
|
|
606
|
+
/* .ctx = */ new llama_sampler_top_p {
|
|
607
|
+
/* .p = */ p,
|
|
608
|
+
/* .min_keep = */ min_keep,
|
|
609
|
+
},
|
|
610
|
+
};
|
|
611
|
+
}
|
|
612
|
+
|
|
613
|
+
// min-p
|
|
614
|
+
|
|
615
|
+
struct llama_sampler_min_p {
|
|
616
|
+
const float p;
|
|
617
|
+
const size_t min_keep;
|
|
618
|
+
};
|
|
619
|
+
|
|
620
|
+
static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
|
|
621
|
+
return "min-p";
|
|
172
622
|
}
|
|
173
623
|
|
|
174
|
-
void
|
|
175
|
-
|
|
624
|
+
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
625
|
+
const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
|
|
626
|
+
|
|
627
|
+
if (ctx->p <= 0.0f || !cur_p->size) {
|
|
176
628
|
return;
|
|
177
629
|
}
|
|
178
630
|
|
|
179
|
-
const int64_t t_start_sample_us = ggml_time_us();
|
|
180
|
-
|
|
181
631
|
bool min_p_applied = false;
|
|
182
632
|
|
|
183
|
-
// if the
|
|
184
|
-
if (!
|
|
633
|
+
// if the cur_p aren't sorted, try the unsorted implementation first
|
|
634
|
+
if (!cur_p->sorted) {
|
|
185
635
|
std::vector<llama_token_data> filtered_tokens;
|
|
186
636
|
|
|
187
637
|
float max_logit = -FLT_MAX;
|
|
188
|
-
for (size_t i = 0; i <
|
|
189
|
-
max_logit = std::max(max_logit,
|
|
638
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
639
|
+
max_logit = std::max(max_logit, cur_p->data[i].logit);
|
|
190
640
|
}
|
|
191
|
-
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
|
641
|
+
const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
|
|
192
642
|
|
|
193
|
-
for (size_t i = 0; i <
|
|
194
|
-
if (
|
|
195
|
-
filtered_tokens.push_back(
|
|
643
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
644
|
+
if (cur_p->data[i].logit >= min_logit) {
|
|
645
|
+
filtered_tokens.push_back(cur_p->data[i]);
|
|
196
646
|
}
|
|
197
647
|
}
|
|
198
648
|
|
|
199
649
|
// if we have enough values the operation was a success
|
|
200
|
-
if (filtered_tokens.size() >= min_keep) {
|
|
201
|
-
memcpy(
|
|
202
|
-
|
|
650
|
+
if (filtered_tokens.size() >= ctx->min_keep) {
|
|
651
|
+
memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
|
652
|
+
cur_p->size = filtered_tokens.size();
|
|
203
653
|
min_p_applied = true;
|
|
204
654
|
}
|
|
205
655
|
}
|
|
206
656
|
|
|
207
|
-
// if the
|
|
657
|
+
// if the cur_p are sorted or the unsorted implementation failed, use this implementation
|
|
208
658
|
if (!min_p_applied) {
|
|
209
659
|
// Sort the logits in descending order
|
|
210
|
-
if (!
|
|
211
|
-
std::sort(
|
|
660
|
+
if (!cur_p->sorted) {
|
|
661
|
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
|
212
662
|
return a.logit > b.logit;
|
|
213
663
|
});
|
|
214
|
-
|
|
664
|
+
cur_p->sorted = true;
|
|
215
665
|
}
|
|
216
666
|
|
|
217
|
-
const float min_logit =
|
|
667
|
+
const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
|
|
218
668
|
size_t i = 1; // first token always matches
|
|
219
669
|
|
|
220
|
-
for (; i <
|
|
221
|
-
if (
|
|
670
|
+
for (; i < cur_p->size; ++i) {
|
|
671
|
+
if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
|
|
222
672
|
break; // prob too small
|
|
223
673
|
}
|
|
224
674
|
}
|
|
225
675
|
|
|
226
676
|
// Resize the output vector to keep only the matching tokens
|
|
227
|
-
|
|
677
|
+
cur_p->size = i;
|
|
228
678
|
}
|
|
679
|
+
}
|
|
229
680
|
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
681
|
+
static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
|
|
682
|
+
const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
|
|
683
|
+
return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
|
|
233
684
|
}
|
|
234
685
|
|
|
235
|
-
void
|
|
236
|
-
|
|
686
|
+
static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
|
687
|
+
delete (llama_sampler_min_p *) smpl->ctx;
|
|
688
|
+
}
|
|
689
|
+
|
|
690
|
+
static struct llama_sampler_i llama_sampler_min_p_i = {
|
|
691
|
+
/* .name = */ llama_sampler_min_p_name,
|
|
692
|
+
/* .accept = */ nullptr,
|
|
693
|
+
/* .apply = */ llama_sampler_min_p_apply,
|
|
694
|
+
/* .reset = */ nullptr,
|
|
695
|
+
/* .clone = */ llama_sampler_min_p_clone,
|
|
696
|
+
/* .free = */ llama_sampler_min_p_free,
|
|
697
|
+
};
|
|
698
|
+
|
|
699
|
+
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
|
700
|
+
return new llama_sampler {
|
|
701
|
+
/* .iface = */ &llama_sampler_min_p_i,
|
|
702
|
+
/* .ctx = */ new llama_sampler_min_p {
|
|
703
|
+
/* .p = */ p,
|
|
704
|
+
/* .min_keep = */ min_keep,
|
|
705
|
+
},
|
|
706
|
+
};
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
// tail-free
|
|
710
|
+
|
|
711
|
+
// Parameters of the tail-free sampler. The apply step compares z against a running sum of
// second derivatives and never cuts before index min_keep.
struct llama_sampler_tail_free {
    const float  z;
    const size_t min_keep;
};

// Name reported for the tail-free sampler.
static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
    return "tail-free";
}
|
|
719
|
+
|
|
720
|
+
static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
721
|
+
const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
|
|
722
|
+
|
|
723
|
+
if (ctx->z >= 1.0f || cur_p->size <= 2) {
|
|
237
724
|
return;
|
|
238
725
|
}
|
|
239
726
|
|
|
240
|
-
|
|
241
|
-
const int64_t t_start_sample_us = ggml_time_us();
|
|
727
|
+
llama_sampler_softmax_impl(cur_p);
|
|
242
728
|
|
|
243
729
|
// Compute the first and second derivatives
|
|
244
|
-
std::vector<float> first_derivatives(
|
|
245
|
-
std::vector<float> second_derivatives(
|
|
730
|
+
std::vector<float> first_derivatives(cur_p->size - 1);
|
|
731
|
+
std::vector<float> second_derivatives(cur_p->size - 2);
|
|
246
732
|
|
|
247
733
|
for (size_t i = 0; i < first_derivatives.size(); ++i) {
|
|
248
|
-
first_derivatives[i] =
|
|
734
|
+
first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
|
|
249
735
|
}
|
|
250
736
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
|
251
737
|
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
|
|
@@ -272,51 +758,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
|
|
|
272
758
|
}
|
|
273
759
|
|
|
274
760
|
float cum_sum = 0.0f;
|
|
275
|
-
size_t last_idx =
|
|
761
|
+
size_t last_idx = cur_p->size;
|
|
276
762
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
|
277
763
|
cum_sum += second_derivatives[i];
|
|
278
764
|
|
|
279
765
|
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
|
280
|
-
if (cum_sum > z && i >= min_keep) {
|
|
766
|
+
if (cum_sum > ctx->z && i >= ctx->min_keep) {
|
|
281
767
|
last_idx = i;
|
|
282
768
|
break;
|
|
283
769
|
}
|
|
284
770
|
}
|
|
285
771
|
|
|
286
772
|
// Resize the output vector to keep only the tokens above the tail location
|
|
287
|
-
|
|
773
|
+
cur_p->size = last_idx;
|
|
774
|
+
}
|
|
288
775
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
776
|
+
static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
|
|
777
|
+
const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
|
|
778
|
+
return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
|
|
779
|
+
}
|
|
780
|
+
|
|
781
|
+
static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
|
|
782
|
+
delete (llama_sampler_tail_free *) smpl->ctx;
|
|
783
|
+
}
|
|
784
|
+
|
|
785
|
+
static struct llama_sampler_i llama_sampler_tail_free_i = {
|
|
786
|
+
/* .name = */ llama_sampler_tail_free_name,
|
|
787
|
+
/* .accept = */ nullptr,
|
|
788
|
+
/* .apply = */ llama_sampler_tail_free_apply,
|
|
789
|
+
/* .reset = */ nullptr,
|
|
790
|
+
/* .clone = */ llama_sampler_tail_free_clone,
|
|
791
|
+
/* .free = */ llama_sampler_tail_free_free,
|
|
792
|
+
};
|
|
793
|
+
|
|
794
|
+
struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
|
|
795
|
+
return new llama_sampler {
|
|
796
|
+
/* .iface = */ &llama_sampler_tail_free_i,
|
|
797
|
+
/* .ctx = */ new llama_sampler_tail_free {
|
|
798
|
+
/* .z = */ z,
|
|
799
|
+
/*. min_keep = */ min_keep,
|
|
800
|
+
},
|
|
801
|
+
};
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
// typical
|
|
805
|
+
|
|
806
|
+
// Parameters of the locally-typical sampler. The apply step accumulates probability over
// candidates ordered by typicality until the mass exceeds p, keeping at least min_keep of them.
struct llama_sampler_typical {
    const float  p;
    const size_t min_keep;
};

// Name reported for the typical sampler.
static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
    return "typical";
}
|
|
293
814
|
|
|
294
|
-
void
|
|
815
|
+
static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
816
|
+
const auto * ctx = (llama_sampler_typical *) smpl->ctx;
|
|
817
|
+
|
|
295
818
|
// Reference implementation:
|
|
296
819
|
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
|
297
|
-
if (p >= 1.0f) {
|
|
820
|
+
if (ctx->p >= 1.0f) {
|
|
298
821
|
return;
|
|
299
822
|
}
|
|
300
823
|
|
|
301
824
|
// Compute the softmax of logits and calculate entropy
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
const int64_t t_start_sample_us = ggml_time_us();
|
|
825
|
+
llama_sampler_softmax_impl(cur_p);
|
|
305
826
|
|
|
306
827
|
float entropy = 0.0f;
|
|
307
|
-
for (size_t i = 0; i <
|
|
308
|
-
entropy += -
|
|
828
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
829
|
+
entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
|
|
309
830
|
}
|
|
310
831
|
|
|
311
832
|
// Compute the absolute difference between negative log probability and entropy for each candidate
|
|
312
833
|
std::vector<float> shifted_scores;
|
|
313
|
-
for (size_t i = 0; i <
|
|
314
|
-
float shifted_score = fabsf(-logf(
|
|
834
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
835
|
+
float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
|
|
315
836
|
shifted_scores.push_back(shifted_score);
|
|
316
837
|
}
|
|
317
838
|
|
|
318
839
|
// Sort tokens based on the shifted_scores and their corresponding indices
|
|
319
|
-
std::vector<size_t> indices(
|
|
840
|
+
std::vector<size_t> indices(cur_p->size);
|
|
320
841
|
std::iota(indices.begin(), indices.end(), 0);
|
|
321
842
|
|
|
322
843
|
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
|
|
@@ -329,134 +850,618 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|
|
329
850
|
|
|
330
851
|
for (size_t i = 0; i < indices.size(); ++i) {
|
|
331
852
|
size_t idx = indices[i];
|
|
332
|
-
cum_sum +=
|
|
853
|
+
cum_sum += cur_p->data[idx].p;
|
|
333
854
|
|
|
334
855
|
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
|
335
|
-
if (cum_sum > p && i >= min_keep - 1) {
|
|
856
|
+
if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
|
|
336
857
|
last_idx = i + 1;
|
|
337
858
|
break;
|
|
338
859
|
}
|
|
339
860
|
}
|
|
340
861
|
|
|
341
862
|
// Resize the output vector to keep only the locally typical tokens
|
|
342
|
-
std::vector<llama_token_data>
|
|
863
|
+
std::vector<llama_token_data> cur_p_new;
|
|
343
864
|
for (size_t i = 0; i < last_idx; ++i) {
|
|
344
865
|
size_t idx = indices[i];
|
|
345
|
-
|
|
866
|
+
cur_p_new.push_back(cur_p->data[idx]);
|
|
346
867
|
}
|
|
347
868
|
|
|
348
|
-
// Replace the data in
|
|
349
|
-
std::copy(
|
|
350
|
-
|
|
351
|
-
|
|
869
|
+
// Replace the data in cur_p with the cur_p_new data
|
|
870
|
+
std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
|
|
871
|
+
cur_p->size = cur_p_new.size();
|
|
872
|
+
cur_p->sorted = false;
|
|
873
|
+
}
|
|
352
874
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
875
|
+
static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
|
|
876
|
+
const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
|
|
877
|
+
return llama_sampler_init_typical(ctx->p, ctx->min_keep);
|
|
356
878
|
}
|
|
357
879
|
|
|
358
|
-
void
|
|
359
|
-
|
|
880
|
+
static void llama_sampler_typical_free(struct llama_sampler * smpl) {
|
|
881
|
+
delete (llama_sampler_typical *) smpl->ctx;
|
|
882
|
+
}
|
|
360
883
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
884
|
+
static struct llama_sampler_i llama_sampler_typical_i = {
|
|
885
|
+
/* .name = */ llama_sampler_typical_name,
|
|
886
|
+
/* .accept = */ nullptr,
|
|
887
|
+
/* .apply = */ llama_sampler_typical_apply,
|
|
888
|
+
/* .reset = */ nullptr,
|
|
889
|
+
/* .clone = */ llama_sampler_typical_clone,
|
|
890
|
+
/* .free = */ llama_sampler_typical_free,
|
|
891
|
+
};
|
|
892
|
+
|
|
893
|
+
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
|
894
|
+
return new llama_sampler {
|
|
895
|
+
/* .iface = */ &llama_sampler_typical_i,
|
|
896
|
+
/* .ctx = */ new llama_sampler_typical {
|
|
897
|
+
/* .p = */ p,
|
|
898
|
+
/* .min_keep = */ min_keep,
|
|
899
|
+
},
|
|
900
|
+
};
|
|
901
|
+
}
|
|
902
|
+
|
|
903
|
+
// temp
|
|
904
|
+
|
|
905
|
+
struct llama_sampler_temp {
|
|
906
|
+
const float temp;
|
|
907
|
+
};
|
|
908
|
+
|
|
909
|
+
static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
|
|
910
|
+
return "temp";
|
|
911
|
+
}
|
|
912
|
+
|
|
913
|
+
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
914
|
+
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
|
915
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
916
|
+
cur_p->data[i].logit /= ctx->temp;
|
|
364
917
|
}
|
|
918
|
+
}
|
|
365
919
|
|
|
366
|
-
|
|
367
|
-
|
|
920
|
+
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
|
921
|
+
const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
|
|
922
|
+
return llama_sampler_init_temp(ctx->temp);
|
|
923
|
+
}
|
|
368
924
|
|
|
369
|
-
|
|
925
|
+
static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
|
926
|
+
delete (llama_sampler_temp *) smpl->ctx;
|
|
927
|
+
}
|
|
370
928
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
929
|
+
static struct llama_sampler_i llama_sampler_temp_i = {
|
|
930
|
+
/* .name = */ llama_sampler_temp_name,
|
|
931
|
+
/* .accept = */ nullptr,
|
|
932
|
+
/* .apply = */ llama_sampler_temp_apply,
|
|
933
|
+
/* .reset = */ nullptr,
|
|
934
|
+
/* .clone = */ llama_sampler_temp_clone,
|
|
935
|
+
/* .free = */ llama_sampler_temp_free,
|
|
936
|
+
};
|
|
937
|
+
|
|
938
|
+
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
|
939
|
+
return new llama_sampler {
|
|
940
|
+
/* .iface = */ &llama_sampler_temp_i,
|
|
941
|
+
/* .ctx = */ new llama_sampler_temp {
|
|
942
|
+
/*.temp = */ temp,
|
|
943
|
+
},
|
|
944
|
+
};
|
|
945
|
+
}
|
|
946
|
+
|
|
947
|
+
// temp-ext
|
|
948
|
+
|
|
949
|
+
struct llama_sampler_temp_ext {
|
|
950
|
+
const float temp;
|
|
951
|
+
const float delta;
|
|
952
|
+
const float exponent;
|
|
953
|
+
};
|
|
954
|
+
|
|
955
|
+
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
|
|
956
|
+
return "temp-ext";
|
|
957
|
+
}
|
|
958
|
+
|
|
959
|
+
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
960
|
+
const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
|
|
961
|
+
if (ctx->delta > 0) {
|
|
962
|
+
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
|
963
|
+
const float max_temp = ctx->temp + ctx->delta;
|
|
964
|
+
float exponent_val = ctx->exponent;
|
|
965
|
+
|
|
966
|
+
// no need to do anything if there is only one (or zero) candidates
|
|
967
|
+
if (cur_p->size <= 1) {
|
|
968
|
+
return;
|
|
969
|
+
}
|
|
970
|
+
|
|
971
|
+
// Calculate maximum possible entropy
|
|
972
|
+
float max_entropy = -logf(1.0f / cur_p->size);
|
|
973
|
+
|
|
974
|
+
llama_sampler_softmax_impl(cur_p);
|
|
975
|
+
|
|
976
|
+
// Calculate entropy of the softmax probabilities
|
|
977
|
+
float entropy = 0.0f;
|
|
978
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
979
|
+
float prob = cur_p->data[i].p;
|
|
980
|
+
if (prob > 0.0f) { // Ensure no log(0)
|
|
981
|
+
entropy -= prob * logf(prob);
|
|
982
|
+
}
|
|
983
|
+
}
|
|
984
|
+
|
|
985
|
+
// Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
|
|
986
|
+
float normalized_entropy = entropy / max_entropy;
|
|
987
|
+
|
|
988
|
+
// Map the normalized entropy to the desired temperature range using the power function
|
|
989
|
+
float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
|
|
990
|
+
|
|
991
|
+
#ifdef DEBUG
|
|
992
|
+
LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
|
|
993
|
+
LLAMA_LOG_INFO("Entropy: %f\n", entropy);
|
|
994
|
+
LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
|
|
995
|
+
LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
|
|
996
|
+
LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
|
|
997
|
+
LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
|
|
998
|
+
#endif
|
|
999
|
+
|
|
1000
|
+
// Apply the dynamically calculated temperature scaling
|
|
1001
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1002
|
+
cur_p->data[i].logit /= dyn_temp;
|
|
1003
|
+
}
|
|
1004
|
+
|
|
1005
|
+
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
|
1006
|
+
const double max_l_double = cur_p->data[0].logit;
|
|
1007
|
+
|
|
1008
|
+
double cum_sum_double = 0.0;
|
|
1009
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1010
|
+
double p = exp(cur_p->data[i].logit - max_l_double);
|
|
1011
|
+
cur_p->data[i].p = p; // Store the scaled probability
|
|
1012
|
+
cum_sum_double += p;
|
|
1013
|
+
}
|
|
1014
|
+
|
|
1015
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1016
|
+
cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
|
1017
|
+
}
|
|
1018
|
+
|
|
1019
|
+
#ifdef DEBUG
|
|
1020
|
+
// Print the updated top 25 probabilities after temperature scaling
|
|
1021
|
+
LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
|
|
1022
|
+
for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
|
|
1023
|
+
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
|
|
1024
|
+
}
|
|
1025
|
+
#endif
|
|
1026
|
+
} else {
|
|
1027
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1028
|
+
cur_p->data[i].logit /= ctx->temp;
|
|
377
1029
|
}
|
|
378
1030
|
}
|
|
1031
|
+
}
|
|
379
1032
|
|
|
380
|
-
|
|
381
|
-
|
|
1033
|
+
static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
|
|
1034
|
+
const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
|
|
1035
|
+
return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
|
|
1036
|
+
}
|
|
382
1037
|
|
|
383
|
-
|
|
384
|
-
|
|
1038
|
+
static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
|
1039
|
+
delete (llama_sampler_temp_ext *) smpl->ctx;
|
|
1040
|
+
}
|
|
385
1041
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
1042
|
+
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
|
1043
|
+
/* .name = */ llama_sampler_temp_ext_name,
|
|
1044
|
+
/* .accept = */ nullptr,
|
|
1045
|
+
/* .apply = */ llama_sampler_temp_ext_apply,
|
|
1046
|
+
/* .reset = */ nullptr,
|
|
1047
|
+
/* .clone = */ llama_sampler_temp_ext_clone,
|
|
1048
|
+
/* .free = */ llama_sampler_temp_ext_free,
|
|
1049
|
+
};
|
|
1050
|
+
|
|
1051
|
+
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
|
1052
|
+
return new llama_sampler {
|
|
1053
|
+
/* .iface = */ &llama_sampler_temp_ext_i,
|
|
1054
|
+
/* .ctx = */ new llama_sampler_temp_ext {
|
|
1055
|
+
/* .temp = */ temp,
|
|
1056
|
+
/* .delta = */ delta,
|
|
1057
|
+
/* .exponent = */ exponent,
|
|
1058
|
+
},
|
|
1059
|
+
};
|
|
1060
|
+
}
|
|
1061
|
+
|
|
1062
|
+
// mirostat
|
|
1063
|
+
|
|
1064
|
+
struct llama_sampler_mirostat {
|
|
1065
|
+
const int32_t n_vocab;
|
|
1066
|
+
|
|
1067
|
+
const uint32_t seed;
|
|
1068
|
+
uint32_t seed_cur;
|
|
1069
|
+
|
|
1070
|
+
const float tau;
|
|
1071
|
+
const float eta;
|
|
1072
|
+
|
|
1073
|
+
const int32_t m;
|
|
1074
|
+
|
|
1075
|
+
float mu;
|
|
394
1076
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
1077
|
+
std::mt19937 rng;
|
|
1078
|
+
};
|
|
1079
|
+
|
|
1080
|
+
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
|
|
1081
|
+
return "mirostat";
|
|
1082
|
+
}
|
|
1083
|
+
|
|
1084
|
+
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
1085
|
+
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
|
|
1086
|
+
|
|
1087
|
+
llama_sampler_softmax_impl(cur_p);
|
|
1088
|
+
|
|
1089
|
+
// Estimate s_hat using the most probable m tokens
|
|
1090
|
+
float s_hat = 0.0;
|
|
1091
|
+
float sum_ti_bi = 0.0;
|
|
1092
|
+
float sum_ti_sq = 0.0;
|
|
1093
|
+
for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
|
|
1094
|
+
float t_i = logf(float(i + 2) / float(i + 1));
|
|
1095
|
+
float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
|
|
1096
|
+
sum_ti_bi += t_i * b_i;
|
|
1097
|
+
sum_ti_sq += t_i * t_i;
|
|
398
1098
|
}
|
|
1099
|
+
s_hat = sum_ti_bi / sum_ti_sq;
|
|
1100
|
+
|
|
1101
|
+
// Compute k from the estimated s_hat and target surprise value
|
|
1102
|
+
float epsilon_hat = s_hat - 1;
|
|
1103
|
+
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
|
|
1104
|
+
|
|
1105
|
+
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
|
|
1106
|
+
llama_sampler_softmax_impl(cur_p);
|
|
1107
|
+
|
|
1108
|
+
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
|
1109
|
+
|
|
1110
|
+
cur_p->selected = idx;
|
|
1111
|
+
|
|
1112
|
+
float observed_surprise = -log2f(cur_p->data[idx].p);
|
|
1113
|
+
float e = observed_surprise - ctx->tau;
|
|
399
1114
|
|
|
400
|
-
//
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
1115
|
+
// Update mu using the learning rate and error
|
|
1116
|
+
ctx->mu = ctx->mu - ctx->eta * e;
|
|
1117
|
+
}
|
|
1118
|
+
|
|
1119
|
+
static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
|
|
1120
|
+
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
|
|
1121
|
+
auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
|
|
1122
|
+
|
|
1123
|
+
// copy the state
|
|
1124
|
+
{
|
|
1125
|
+
auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
|
|
1126
|
+
|
|
1127
|
+
result_ctx->mu = ctx->mu;
|
|
1128
|
+
result_ctx->rng = ctx->rng;
|
|
407
1129
|
}
|
|
408
|
-
|
|
409
|
-
|
|
1130
|
+
|
|
1131
|
+
return result;
|
|
1132
|
+
}
|
|
1133
|
+
|
|
1134
|
+
static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
|
|
1135
|
+
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
|
|
1136
|
+
ctx->mu = 2.0f*ctx->tau;
|
|
1137
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
1138
|
+
ctx->rng.seed(ctx->seed_cur);
|
|
1139
|
+
}
|
|
1140
|
+
|
|
1141
|
+
static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
|
1142
|
+
delete (llama_sampler_mirostat *) smpl->ctx;
|
|
1143
|
+
}
|
|
1144
|
+
|
|
1145
|
+
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
|
1146
|
+
/* .name = */ llama_sampler_mirostat_name,
|
|
1147
|
+
/* .accept = */ nullptr,
|
|
1148
|
+
/* .apply = */ llama_sampler_mirostat_apply,
|
|
1149
|
+
/* .reset = */ llama_sampler_mirostat_reset,
|
|
1150
|
+
/* .clone = */ llama_sampler_mirostat_clone,
|
|
1151
|
+
/* .free = */ llama_sampler_mirostat_free,
|
|
1152
|
+
};
|
|
1153
|
+
|
|
1154
|
+
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
|
1155
|
+
auto seed_cur = get_rng_seed(seed);
|
|
1156
|
+
return new llama_sampler {
|
|
1157
|
+
/* .iface = */ &llama_sampler_mirostat_i,
|
|
1158
|
+
/* .ctx = */ new llama_sampler_mirostat {
|
|
1159
|
+
/* .n_vocab = */ n_vocab,
|
|
1160
|
+
/* .seed = */ seed,
|
|
1161
|
+
/* .seed_cur = */ seed_cur,
|
|
1162
|
+
/* .tau = */ tau,
|
|
1163
|
+
/* .eta = */ eta,
|
|
1164
|
+
/* .m = */ m,
|
|
1165
|
+
/* .mu = */ 2.0f*tau,
|
|
1166
|
+
/* .rng = */ std::mt19937(seed_cur),
|
|
1167
|
+
},
|
|
1168
|
+
};
|
|
1169
|
+
}
|
|
1170
|
+
|
|
1171
|
+
// mirostat v2
|
|
1172
|
+
|
|
1173
|
+
struct llama_sampler_mirostat_v2 {
|
|
1174
|
+
const uint32_t seed;
|
|
1175
|
+
uint32_t seed_cur;
|
|
1176
|
+
|
|
1177
|
+
const float tau;
|
|
1178
|
+
const float eta;
|
|
1179
|
+
|
|
1180
|
+
float mu;
|
|
1181
|
+
|
|
1182
|
+
std::mt19937 rng;
|
|
1183
|
+
};
|
|
1184
|
+
|
|
1185
|
+
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
|
|
1186
|
+
return "mirostat-v2";
|
|
1187
|
+
}
|
|
1188
|
+
|
|
1189
|
+
static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
1190
|
+
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
|
|
1191
|
+
|
|
1192
|
+
llama_sampler_softmax_impl(cur_p);
|
|
1193
|
+
|
|
1194
|
+
// Truncate the words with surprise values greater than mu
|
|
1195
|
+
cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
|
|
1196
|
+
return -log2f(candidate.p) > ctx->mu;
|
|
1197
|
+
}));
|
|
1198
|
+
|
|
1199
|
+
if (cur_p->size == 0) {
|
|
1200
|
+
cur_p->size = 1;
|
|
410
1201
|
}
|
|
411
1202
|
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
1203
|
+
// Normalize the probabilities of the remaining words
|
|
1204
|
+
llama_sampler_softmax_impl(cur_p);
|
|
1205
|
+
|
|
1206
|
+
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
|
1207
|
+
|
|
1208
|
+
cur_p->selected = idx;
|
|
1209
|
+
|
|
1210
|
+
float observed_surprise = -log2f(cur_p->data[idx].p);
|
|
1211
|
+
float e = observed_surprise - ctx->tau;
|
|
1212
|
+
|
|
1213
|
+
// Update mu using the learning rate and error
|
|
1214
|
+
ctx->mu = ctx->mu - ctx->eta * e;
|
|
1215
|
+
}
|
|
1216
|
+
|
|
1217
|
+
static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
|
|
1218
|
+
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
|
|
1219
|
+
ctx->mu = 2.0f*ctx->tau;
|
|
1220
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
1221
|
+
ctx->rng.seed(ctx->seed_cur);
|
|
1222
|
+
}
|
|
1223
|
+
|
|
1224
|
+
static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
|
|
1225
|
+
const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
|
|
1226
|
+
|
|
1227
|
+
auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
|
|
1228
|
+
|
|
1229
|
+
// copy the state
|
|
1230
|
+
{
|
|
1231
|
+
auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
|
|
1232
|
+
|
|
1233
|
+
result_ctx->mu = ctx->mu;
|
|
1234
|
+
result_ctx->rng = ctx->rng;
|
|
417
1235
|
}
|
|
418
|
-
#endif
|
|
419
1236
|
|
|
420
|
-
|
|
421
|
-
|
|
1237
|
+
return result;
|
|
1238
|
+
}
|
|
1239
|
+
|
|
1240
|
+
static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
|
|
1241
|
+
delete (llama_sampler_mirostat_v2 *) smpl->ctx;
|
|
1242
|
+
}
|
|
1243
|
+
|
|
1244
|
+
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
|
1245
|
+
/* .name = */ llama_sampler_mirostat_v2_name,
|
|
1246
|
+
/* .accept = */ nullptr,
|
|
1247
|
+
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
|
1248
|
+
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
|
1249
|
+
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
|
1250
|
+
/* .free = */ llama_sampler_mirostat_v2_free,
|
|
1251
|
+
};
|
|
1252
|
+
|
|
1253
|
+
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
|
1254
|
+
auto seed_cur = get_rng_seed(seed);
|
|
1255
|
+
return new llama_sampler {
|
|
1256
|
+
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
|
1257
|
+
/* .ctx = */ new llama_sampler_mirostat_v2 {
|
|
1258
|
+
/* .seed = */ seed,
|
|
1259
|
+
/* .seed_cur = */ seed_cur,
|
|
1260
|
+
/* .tau = */ tau,
|
|
1261
|
+
/* .eta = */ eta,
|
|
1262
|
+
/* .mu = */ 2.0f*tau,
|
|
1263
|
+
/* .rng = */ std::mt19937(seed_cur),
|
|
1264
|
+
},
|
|
1265
|
+
};
|
|
1266
|
+
}
|
|
1267
|
+
|
|
1268
|
+
// grammar
|
|
1269
|
+
|
|
1270
|
+
struct llama_sampler_grammar {
|
|
1271
|
+
const struct llama_vocab * vocab;
|
|
1272
|
+
|
|
1273
|
+
std::string grammar_str;
|
|
1274
|
+
std::string grammar_root;
|
|
1275
|
+
|
|
1276
|
+
struct llama_grammar * grammar;
|
|
1277
|
+
};
|
|
1278
|
+
|
|
1279
|
+
// Name reported for this sampler through llama_sampler_i::name.
static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "grammar";
    return sampler_name;
}
|
|
1282
|
+
|
|
1283
|
+
static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
|
|
1284
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
|
1285
|
+
if (ctx->grammar) {
|
|
1286
|
+
llama_grammar_accept_impl(*ctx->grammar, token);
|
|
1287
|
+
}
|
|
1288
|
+
}
|
|
1289
|
+
|
|
1290
|
+
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
1291
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
|
1292
|
+
if (ctx->grammar) {
|
|
1293
|
+
llama_grammar_apply_impl(*ctx->grammar, cur_p);
|
|
1294
|
+
}
|
|
1295
|
+
}
|
|
1296
|
+
|
|
1297
|
+
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
|
1298
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
|
1299
|
+
if (!ctx->grammar) {
|
|
1300
|
+
return;
|
|
422
1301
|
}
|
|
1302
|
+
|
|
1303
|
+
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
|
|
1304
|
+
|
|
1305
|
+
llama_grammar_free_impl(ctx->grammar);
|
|
1306
|
+
ctx->grammar = grammar_new;
|
|
423
1307
|
}
|
|
424
1308
|
|
|
425
|
-
|
|
426
|
-
const
|
|
1309
|
+
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
|
|
1310
|
+
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
|
|
1311
|
+
|
|
1312
|
+
auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
|
|
1313
|
+
|
|
1314
|
+
// copy the state
|
|
1315
|
+
{
|
|
1316
|
+
auto * result_ctx = (llama_sampler_grammar *) result->ctx;
|
|
427
1317
|
|
|
428
|
-
|
|
429
|
-
|
|
1318
|
+
if (ctx->grammar) {
|
|
1319
|
+
result_ctx->grammar_str = ctx->grammar_str;
|
|
1320
|
+
result_ctx->grammar_root = ctx->grammar_root;
|
|
1321
|
+
|
|
1322
|
+
result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
|
|
1323
|
+
}
|
|
1324
|
+
}
|
|
1325
|
+
|
|
1326
|
+
return result;
|
|
1327
|
+
}
|
|
1328
|
+
|
|
1329
|
+
static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
|
|
1330
|
+
const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
|
1331
|
+
|
|
1332
|
+
if (ctx->grammar) {
|
|
1333
|
+
llama_grammar_free_impl(ctx->grammar);
|
|
1334
|
+
}
|
|
1335
|
+
|
|
1336
|
+
delete ctx;
|
|
1337
|
+
}
|
|
1338
|
+
|
|
1339
|
+
// Dispatch table for the grammar sampler.
// Entries are positional: name, accept, apply, reset, clone, free.
static struct llama_sampler_i llama_sampler_grammar_i = {
    /* .name   = */ llama_sampler_grammar_name,
    /* .accept = */ llama_sampler_grammar_accept_impl,
    /* .apply  = */ llama_sampler_grammar_apply,
    /* .reset  = */ llama_sampler_grammar_reset,
    /* .clone  = */ llama_sampler_grammar_clone,
    /* .free   = */ llama_sampler_grammar_free,
};
|
|
1347
|
+
|
|
1348
|
+
struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
|
|
1349
|
+
auto * ctx = new llama_sampler_grammar;
|
|
1350
|
+
|
|
1351
|
+
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
|
1352
|
+
*ctx = {
|
|
1353
|
+
/* .vocab = */ &vocab,
|
|
1354
|
+
/* .grammar_str = */ grammar_str,
|
|
1355
|
+
/* .grammar_root = */ grammar_root,
|
|
1356
|
+
/* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
|
|
1357
|
+
};
|
|
1358
|
+
} else {
|
|
1359
|
+
*ctx = {
|
|
1360
|
+
/* .vocab = */ &vocab,
|
|
1361
|
+
/* .grammar_str = */ {},
|
|
1362
|
+
/* .grammar_root = */ {},
|
|
1363
|
+
/* .grammar = */ nullptr,
|
|
1364
|
+
};
|
|
430
1365
|
}
|
|
431
1366
|
|
|
432
|
-
|
|
433
|
-
|
|
1367
|
+
return new llama_sampler {
|
|
1368
|
+
/* .iface = */ &llama_sampler_grammar_i,
|
|
1369
|
+
/* .ctx = */ ctx,
|
|
1370
|
+
};
|
|
1371
|
+
}
|
|
1372
|
+
|
|
1373
|
+
// penalties
|
|
1374
|
+
|
|
1375
|
+
struct llama_sampler_penalties {
|
|
1376
|
+
const int32_t n_vocab;
|
|
1377
|
+
const llama_token special_eos_id;
|
|
1378
|
+
const llama_token linefeed_id;
|
|
1379
|
+
|
|
1380
|
+
const int32_t penalty_last_n;
|
|
1381
|
+
const float penalty_repeat;
|
|
1382
|
+
const float penalty_freq;
|
|
1383
|
+
const float penalty_present;
|
|
1384
|
+
|
|
1385
|
+
const bool penalize_nl;
|
|
1386
|
+
const bool ignore_eos;
|
|
1387
|
+
|
|
1388
|
+
ring_buffer<llama_token> prev;
|
|
1389
|
+
};
|
|
1390
|
+
|
|
1391
|
+
// Name reported for this sampler through llama_sampler_i::name.
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "penalties";
    return sampler_name;
}
|
|
1394
|
+
|
|
1395
|
+
static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
|
|
1396
|
+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
|
1397
|
+
if (ctx->penalty_last_n == 0) {
|
|
1398
|
+
return;
|
|
434
1399
|
}
|
|
1400
|
+
|
|
1401
|
+
ctx->prev.push_back(token);
|
|
435
1402
|
}
|
|
436
1403
|
|
|
437
|
-
void
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
1404
|
+
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
1405
|
+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
|
1406
|
+
|
|
1407
|
+
if (ctx->ignore_eos) {
|
|
1408
|
+
assert(ctx->special_eos_id >= 0);
|
|
1409
|
+
|
|
1410
|
+
// optimistically check if the candidates are not yet sorted/shuffled/truncated
|
|
1411
|
+
if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
|
|
1412
|
+
cur_p->data[ctx->special_eos_id].logit = -INFINITY;
|
|
1413
|
+
} else {
|
|
1414
|
+
// else, search for the special EOS token
|
|
1415
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1416
|
+
if (cur_p->data[i].id == ctx->special_eos_id) {
|
|
1417
|
+
cur_p->data[i].logit = -INFINITY;
|
|
1418
|
+
break;
|
|
1419
|
+
}
|
|
1420
|
+
}
|
|
1421
|
+
}
|
|
1422
|
+
}
|
|
1423
|
+
|
|
1424
|
+
if ((ctx->penalty_last_n == 0) ||
|
|
1425
|
+
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
|
|
446
1426
|
return;
|
|
447
1427
|
}
|
|
448
1428
|
|
|
449
|
-
|
|
1429
|
+
bool nl_found = false;
|
|
1430
|
+
size_t nl_idx = 0;
|
|
1431
|
+
float nl_logit = -INFINITY;
|
|
1432
|
+
if (!ctx->penalize_nl) {
|
|
1433
|
+
assert(ctx->linefeed_id >= 0);
|
|
1434
|
+
|
|
1435
|
+
// optimistically check if the candidates are not yet sorted/shuffled/truncated
|
|
1436
|
+
if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
|
|
1437
|
+
nl_found = true;
|
|
1438
|
+
nl_idx = ctx->linefeed_id;
|
|
1439
|
+
nl_logit = cur_p->data[ctx->linefeed_id].logit;
|
|
1440
|
+
} else {
|
|
1441
|
+
// else, search for the linefeed token
|
|
1442
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1443
|
+
if (cur_p->data[i].id == ctx->linefeed_id) {
|
|
1444
|
+
nl_found = true;
|
|
1445
|
+
nl_idx = i;
|
|
1446
|
+
nl_logit = cur_p->data[i].logit;
|
|
1447
|
+
break;
|
|
1448
|
+
}
|
|
1449
|
+
}
|
|
1450
|
+
}
|
|
1451
|
+
}
|
|
450
1452
|
|
|
451
1453
|
// Create a frequency map to count occurrences of each token in last_tokens
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
1454
|
+
// TODO: optimize this by maintaining the token count in the sampler context
|
|
1455
|
+
using llama_token_cnt = std::unordered_map<llama_token, int>;
|
|
1456
|
+
llama_token_cnt token_count;
|
|
1457
|
+
|
|
1458
|
+
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
|
|
1459
|
+
token_count[ctx->prev.rat(i)]++;
|
|
455
1460
|
}
|
|
456
1461
|
|
|
457
|
-
// Apply frequency and presence penalties to the
|
|
458
|
-
for (size_t i = 0; i <
|
|
459
|
-
const auto token_iter = token_count.find(
|
|
1462
|
+
// Apply frequency and presence penalties to the cur_p
|
|
1463
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1464
|
+
const auto token_iter = token_count.find(cur_p->data[i].id);
|
|
460
1465
|
if (token_iter == token_count.end()) {
|
|
461
1466
|
continue;
|
|
462
1467
|
}
|
|
@@ -465,171 +1470,238 @@ void llama_sample_repetition_penalties_impl(
|
|
|
465
1470
|
|
|
466
1471
|
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
|
467
1472
|
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
|
468
|
-
if (
|
|
469
|
-
|
|
1473
|
+
if (cur_p->data[i].logit <= 0) {
|
|
1474
|
+
cur_p->data[i].logit *= ctx->penalty_repeat;
|
|
470
1475
|
} else {
|
|
471
|
-
|
|
1476
|
+
cur_p->data[i].logit /= ctx->penalty_repeat;
|
|
472
1477
|
}
|
|
473
1478
|
|
|
474
|
-
|
|
1479
|
+
cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
|
|
475
1480
|
}
|
|
476
1481
|
|
|
477
|
-
|
|
1482
|
+
cur_p->sorted = false;
|
|
478
1483
|
|
|
479
|
-
if (
|
|
480
|
-
|
|
1484
|
+
if (!ctx->penalize_nl && nl_found) {
|
|
1485
|
+
// restore the logit of the newline token if it was penalized
|
|
1486
|
+
cur_p->data[nl_idx].logit = nl_logit;
|
|
481
1487
|
}
|
|
482
1488
|
}
|
|
483
1489
|
|
|
484
|
-
void
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
float scale) {
|
|
489
|
-
GGML_ASSERT(smpl);
|
|
490
|
-
|
|
491
|
-
const auto t_start_sample_us = ggml_time_us();
|
|
492
|
-
const auto n_vocab = smpl->n_vocab;
|
|
493
|
-
|
|
494
|
-
llama_log_softmax(logits, n_vocab);
|
|
495
|
-
llama_log_softmax(logits_guidance, n_vocab);
|
|
1490
|
+
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
|
|
1491
|
+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
|
1492
|
+
ctx->prev.clear();
|
|
1493
|
+
}
|
|
496
1494
|
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
1495
|
+
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
|
|
1496
|
+
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
|
|
1497
|
+
auto * result = llama_sampler_init_penalties(
|
|
1498
|
+
ctx->n_vocab,
|
|
1499
|
+
ctx->special_eos_id,
|
|
1500
|
+
ctx->linefeed_id,
|
|
1501
|
+
ctx->penalty_last_n,
|
|
1502
|
+
ctx->penalty_repeat,
|
|
1503
|
+
ctx->penalty_freq,
|
|
1504
|
+
ctx->penalty_present,
|
|
1505
|
+
ctx->penalize_nl,
|
|
1506
|
+
ctx->ignore_eos);
|
|
1507
|
+
|
|
1508
|
+
// copy the state
|
|
1509
|
+
{
|
|
1510
|
+
auto * result_ctx = (llama_sampler_penalties *) result->ctx;
|
|
500
1511
|
|
|
501
|
-
|
|
1512
|
+
result_ctx->prev = ctx->prev;
|
|
502
1513
|
}
|
|
503
1514
|
|
|
504
|
-
|
|
1515
|
+
return result;
|
|
505
1516
|
}
|
|
506
1517
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
const int32_t n_vocab = float(smpl->n_vocab);
|
|
511
|
-
|
|
512
|
-
int64_t t_start_sample_us = ggml_time_us();
|
|
1518
|
+
static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
|
|
1519
|
+
delete (llama_sampler_penalties *) smpl->ctx;
|
|
1520
|
+
}
|
|
513
1521
|
|
|
514
|
-
|
|
1522
|
+
// Dispatch table for the penalties sampler.
// Entries are positional: name, accept, apply, reset, clone, free.
static struct llama_sampler_i llama_sampler_penalties_i = {
    /* .name   = */ llama_sampler_penalties_name,
    /* .accept = */ llama_sampler_penalties_accept,
    /* .apply  = */ llama_sampler_penalties_apply,
    /* .reset  = */ llama_sampler_penalties_reset,
    /* .clone  = */ llama_sampler_penalties_clone,
    /* .free   = */ llama_sampler_penalties_free,
};
|
|
1530
|
+
|
|
1531
|
+
struct llama_sampler * llama_sampler_init_penalties(
|
|
1532
|
+
int32_t n_vocab,
|
|
1533
|
+
llama_token special_eos_id,
|
|
1534
|
+
llama_token linefeed_id,
|
|
1535
|
+
int32_t penalty_last_n,
|
|
1536
|
+
float penalty_repeat,
|
|
1537
|
+
float penalty_freq,
|
|
1538
|
+
float penalty_present,
|
|
1539
|
+
bool penalize_nl,
|
|
1540
|
+
bool ignore_eos) {
|
|
1541
|
+
if (linefeed_id == LLAMA_TOKEN_NULL) {
|
|
1542
|
+
penalize_nl = true;
|
|
1543
|
+
}
|
|
515
1544
|
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
float sum_ti_bi = 0.0;
|
|
519
|
-
float sum_ti_sq = 0.0;
|
|
520
|
-
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
|
|
521
|
-
float t_i = logf(float(i + 2) / float(i + 1));
|
|
522
|
-
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
|
|
523
|
-
sum_ti_bi += t_i * b_i;
|
|
524
|
-
sum_ti_sq += t_i * t_i;
|
|
1545
|
+
if (special_eos_id == LLAMA_TOKEN_NULL) {
|
|
1546
|
+
ignore_eos = false;
|
|
525
1547
|
}
|
|
526
|
-
s_hat = sum_ti_bi / sum_ti_sq;
|
|
527
1548
|
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
1549
|
+
penalty_last_n = std::max(penalty_last_n, 0);
|
|
1550
|
+
|
|
1551
|
+
return new llama_sampler {
|
|
1552
|
+
/* .iface = */ &llama_sampler_penalties_i,
|
|
1553
|
+
/* .ctx = */ new llama_sampler_penalties {
|
|
1554
|
+
/* .n_vocab = */ n_vocab,
|
|
1555
|
+
/* .special_eos_id = */ special_eos_id,
|
|
1556
|
+
/* .linefeed_id = */ linefeed_id,
|
|
1557
|
+
/* .penalty_last_n = */ penalty_last_n,
|
|
1558
|
+
/* .penalty_repeat = */ penalty_repeat,
|
|
1559
|
+
/* .penalty_freq = */ penalty_freq,
|
|
1560
|
+
/* .penalty_present = */ penalty_present,
|
|
1561
|
+
/* .penalize_nl = */ penalize_nl,
|
|
1562
|
+
/* .ignore_eos = */ ignore_eos,
|
|
1563
|
+
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
|
|
1564
|
+
},
|
|
1565
|
+
};
|
|
1566
|
+
}
|
|
531
1567
|
|
|
532
|
-
|
|
533
|
-
llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
|
|
534
|
-
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
535
|
-
llama_token X = llama_sample_token_impl(smpl, candidates);
|
|
536
|
-
t_start_sample_us = ggml_time_us();
|
|
1568
|
+
// logit-bias
|
|
537
1569
|
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
return candidate.id == X;
|
|
541
|
-
}));
|
|
542
|
-
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
|
543
|
-
float e = observed_surprise - tau;
|
|
1570
|
+
struct llama_sampler_logit_bias {
|
|
1571
|
+
const int32_t n_vocab;
|
|
544
1572
|
|
|
545
|
-
|
|
546
|
-
*mu = *mu - eta * e;
|
|
1573
|
+
const std::vector<llama_logit_bias> logit_bias;
|
|
547
1574
|
|
|
548
|
-
|
|
549
|
-
|
|
1575
|
+
std::vector<llama_logit_bias> to_search;
|
|
1576
|
+
};
|
|
1577
|
+
|
|
1578
|
+
// Name reported for this sampler through llama_sampler_i::name.
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "logit-bias";
    return sampler_name;
}
|
|
551
1581
|
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
1582
|
+
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
1583
|
+
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
1584
|
+
|
|
1585
|
+
if (ctx->logit_bias.empty()) {
|
|
1586
|
+
return;
|
|
1587
|
+
}
|
|
555
1588
|
|
|
556
|
-
|
|
1589
|
+
ctx->to_search.clear();
|
|
557
1590
|
|
|
558
|
-
//
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
1591
|
+
// update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
|
|
1592
|
+
for (const auto & lb : ctx->logit_bias) {
|
|
1593
|
+
if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
|
|
1594
|
+
cur_p->data[lb.token].logit += lb.bias;
|
|
1595
|
+
} else {
|
|
1596
|
+
ctx->to_search.push_back(lb);
|
|
1597
|
+
}
|
|
1598
|
+
}
|
|
562
1599
|
|
|
563
|
-
if (
|
|
564
|
-
|
|
1600
|
+
if (ctx->to_search.empty()) {
|
|
1601
|
+
return;
|
|
565
1602
|
}
|
|
566
1603
|
|
|
567
|
-
|
|
568
|
-
|
|
1604
|
+
// search for the remaining candidates that were not found in the previous step
|
|
1605
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1606
|
+
for (const auto & lb : ctx->to_search) {
|
|
1607
|
+
if (cur_p->data[i].id == lb.token) {
|
|
1608
|
+
cur_p->data[i].logit += lb.bias;
|
|
1609
|
+
break;
|
|
1610
|
+
}
|
|
1611
|
+
}
|
|
569
1612
|
}
|
|
1613
|
+
}
|
|
570
1614
|
|
|
571
|
-
|
|
572
|
-
|
|
1615
|
+
static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
|
|
1616
|
+
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
|
|
1617
|
+
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
|
|
1618
|
+
}
|
|
573
1619
|
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
1620
|
+
static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
|
|
1621
|
+
delete (llama_sampler_logit_bias *) smpl->ctx;
|
|
1622
|
+
}
|
|
577
1623
|
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
1624
|
+
// Dispatch table for the logit-bias sampler.
// Entries are positional: name, accept, apply, reset, clone, free.
// This sampler keeps no history, so it has neither an accept nor a reset hook.
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    /* .name   = */ llama_sampler_logit_bias_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_logit_bias_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_logit_bias_clone,
    /* .free   = */ llama_sampler_logit_bias_free,
};
|
|
1632
|
+
|
|
1633
|
+
struct llama_sampler * llama_sampler_init_logit_bias(
|
|
1634
|
+
int32_t n_vocab,
|
|
1635
|
+
int32_t n_logit_bias,
|
|
1636
|
+
const llama_logit_bias * logit_bias) {
|
|
1637
|
+
return new llama_sampler {
|
|
1638
|
+
/* .iface = */ &llama_sampler_logit_bias_i,
|
|
1639
|
+
/* .ctx = */ new llama_sampler_logit_bias {
|
|
1640
|
+
/* .n_vocab = */ n_vocab,
|
|
1641
|
+
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
|
1642
|
+
/* .to_search = */ {},
|
|
1643
|
+
},
|
|
1644
|
+
};
|
|
1645
|
+
}
|
|
584
1646
|
|
|
585
|
-
|
|
586
|
-
*mu = *mu - eta * e;
|
|
1647
|
+
// utils
|
|
587
1648
|
|
|
588
|
-
|
|
589
|
-
|
|
1649
|
+
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
|
1650
|
+
if (smpl->iface == &llama_sampler_dist_i) {
|
|
1651
|
+
return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
|
|
590
1652
|
}
|
|
591
|
-
return X;
|
|
592
|
-
}
|
|
593
1653
|
|
|
594
|
-
|
|
595
|
-
|
|
1654
|
+
if (smpl->iface == &llama_sampler_mirostat_i) {
|
|
1655
|
+
return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
|
|
1656
|
+
}
|
|
596
1657
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
});
|
|
1658
|
+
if (smpl->iface == &llama_sampler_mirostat_v2_i) {
|
|
1659
|
+
return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
|
|
1660
|
+
}
|
|
601
1661
|
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
1662
|
+
if (smpl->iface == &llama_sampler_chain_i) {
|
|
1663
|
+
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
|
1664
|
+
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
|
1665
|
+
const uint32_t seed = llama_sampler_get_seed(*it);
|
|
1666
|
+
if (seed != LLAMA_DEFAULT_SEED) {
|
|
1667
|
+
return seed;
|
|
1668
|
+
}
|
|
1669
|
+
}
|
|
606
1670
|
}
|
|
607
|
-
|
|
1671
|
+
|
|
1672
|
+
return LLAMA_DEFAULT_SEED;
|
|
608
1673
|
}
|
|
609
1674
|
|
|
610
|
-
|
|
611
|
-
GGML_ASSERT(smpl);
|
|
1675
|
+
// perf
|
|
612
1676
|
|
|
613
|
-
|
|
614
|
-
|
|
1677
|
+
struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
|
|
1678
|
+
struct llama_perf_sampler_data data = {};
|
|
615
1679
|
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
|
619
|
-
probs.push_back(candidates->data[i].p);
|
|
1680
|
+
if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
|
|
1681
|
+
GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
|
|
620
1682
|
}
|
|
621
1683
|
|
|
622
|
-
|
|
623
|
-
int idx = dist(rng);
|
|
1684
|
+
const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
|
|
624
1685
|
|
|
625
|
-
|
|
1686
|
+
data.t_sample_ms = 1e-3 * ctx->t_sample_us;
|
|
1687
|
+
data.n_sample = std::max(0, ctx->n_sample);
|
|
626
1688
|
|
|
627
|
-
|
|
628
|
-
|
|
1689
|
+
return data;
|
|
1690
|
+
}
|
|
629
1691
|
|
|
630
|
-
|
|
1692
|
+
void llama_perf_sampler_print(const struct llama_sampler * chain) {
|
|
1693
|
+
const auto data = llama_perf_sampler(chain);
|
|
1694
|
+
|
|
1695
|
+
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
|
1696
|
+
__func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
|
|
631
1697
|
}
|
|
632
1698
|
|
|
633
|
-
|
|
634
|
-
|
|
1699
|
+
void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
|
1700
|
+
if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
|
|
1701
|
+
GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
|
|
1702
|
+
}
|
|
1703
|
+
|
|
1704
|
+
auto * ctx = (struct llama_sampler_chain *) chain->ctx;
|
|
1705
|
+
|
|
1706
|
+
ctx->t_sample_us = ctx->n_sample = 0;
|
|
635
1707
|
}
|