@fugood/llama.node 0.3.6 → 0.3.8
This diff shows the published contents of these package versions as they appear in their respective public registries. It is provided for informational purposes only.
- package/README.md +17 -2
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +3 -1
- package/lib/index.js +16 -1
- package/lib/index.ts +16 -0
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +4 -3
- package/src/LlamaCompletionWorker.cpp +4 -2
- package/src/LlamaContext.cpp +61 -6
- package/src/LlamaContext.h +1 -0
- package/src/common.hpp +6 -11
- package/src/llama.cpp/.github/workflows/build.yml +19 -17
- package/src/llama.cpp/.github/workflows/docker.yml +77 -30
- package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +22 -3
- package/src/llama.cpp/CMakeLists.txt +49 -24
- package/src/llama.cpp/common/arg.cpp +82 -26
- package/src/llama.cpp/common/arg.h +3 -0
- package/src/llama.cpp/common/common.cpp +192 -72
- package/src/llama.cpp/common/common.h +51 -18
- package/src/llama.cpp/common/ngram-cache.cpp +12 -12
- package/src/llama.cpp/common/ngram-cache.h +2 -2
- package/src/llama.cpp/common/sampling.cpp +11 -6
- package/src/llama.cpp/common/speculative.cpp +18 -15
- package/src/llama.cpp/docs/build.md +2 -0
- package/src/llama.cpp/examples/batched/batched.cpp +9 -7
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
- package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
- package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
- package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
- package/src/llama.cpp/examples/infill/infill.cpp +23 -24
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
- package/src/llama.cpp/examples/llava/clip.cpp +4 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
- package/src/llama.cpp/examples/llava/llava.cpp +2 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
- package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
- package/src/llama.cpp/examples/main/main.cpp +51 -29
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
- package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
- package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
- package/src/llama.cpp/examples/run/run.cpp +175 -61
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
- package/src/llama.cpp/examples/server/httplib.h +1295 -409
- package/src/llama.cpp/examples/server/server.cpp +387 -181
- package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
- package/src/llama.cpp/examples/server/utils.hpp +170 -58
- package/src/llama.cpp/examples/simple/simple.cpp +9 -8
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
- package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
- package/src/llama.cpp/examples/tts/tts.cpp +64 -23
- package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
- package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
- package/src/llama.cpp/ggml/include/ggml.h +36 -145
- package/src/llama.cpp/ggml/include/gguf.h +202 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
- package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml.c +117 -1327
- package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
- package/src/llama.cpp/include/llama-cpp.h +6 -1
- package/src/llama.cpp/include/llama.h +138 -75
- package/src/llama.cpp/src/CMakeLists.txt +13 -1
- package/src/llama.cpp/src/llama-adapter.cpp +347 -0
- package/src/llama.cpp/src/llama-adapter.h +74 -0
- package/src/llama.cpp/src/llama-arch.cpp +1487 -0
- package/src/llama.cpp/src/llama-arch.h +400 -0
- package/src/llama.cpp/src/llama-batch.cpp +368 -0
- package/src/llama.cpp/src/llama-batch.h +88 -0
- package/src/llama.cpp/src/llama-chat.cpp +578 -0
- package/src/llama.cpp/src/llama-chat.h +52 -0
- package/src/llama.cpp/src/llama-context.cpp +1775 -0
- package/src/llama.cpp/src/llama-context.h +128 -0
- package/src/llama.cpp/src/llama-cparams.cpp +1 -0
- package/src/llama.cpp/src/llama-cparams.h +37 -0
- package/src/llama.cpp/src/llama-grammar.cpp +5 -4
- package/src/llama.cpp/src/llama-grammar.h +3 -1
- package/src/llama.cpp/src/llama-hparams.cpp +71 -0
- package/src/llama.cpp/src/llama-hparams.h +139 -0
- package/src/llama.cpp/src/llama-impl.cpp +167 -0
- package/src/llama.cpp/src/llama-impl.h +16 -136
- package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
- package/src/llama.cpp/src/llama-kv-cache.h +218 -0
- package/src/llama.cpp/src/llama-mmap.cpp +589 -0
- package/src/llama.cpp/src/llama-mmap.h +67 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
- package/src/llama.cpp/src/llama-model-loader.h +167 -0
- package/src/llama.cpp/src/llama-model.cpp +3953 -0
- package/src/llama.cpp/src/llama-model.h +370 -0
- package/src/llama.cpp/src/llama-quant.cpp +934 -0
- package/src/llama.cpp/src/llama-quant.h +1 -0
- package/src/llama.cpp/src/llama-sampling.cpp +147 -32
- package/src/llama.cpp/src/llama-sampling.h +3 -19
- package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
- package/src/llama.cpp/src/llama-vocab.h +97 -142
- package/src/llama.cpp/src/llama.cpp +7160 -20314
- package/src/llama.cpp/src/unicode.cpp +8 -3
- package/src/llama.cpp/tests/CMakeLists.txt +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
- package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
- package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
- package/src/llama.cpp/tests/test-gguf.cpp +222 -187
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +0 -1
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
package/src/llama.cpp/src/llama-quant.h (new file)
@@ -0,0 +1 @@
+#pragma once
package/src/llama.cpp/src/llama-sampling.cpp
@@ -1,5 +1,6 @@
 #include "llama-sampling.h"
 
+#include "llama-impl.h"
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
@@ -14,6 +15,118 @@
 #include <numeric>
 #include <random>
 #include <unordered_map>
+#include <stdexcept>
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    //T & operator[](size_t i) {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    //const T & at(size_t i) const {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + sz - i - 1) % capacity];
+    }
+
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+
+    std::vector<T> data;
+};
 
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
     // iterator for the probabilities
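
The ring_buffer added above is the fixed-capacity history container the samplers keep for recently seen tokens (e.g. for penalty-style samplers). A minimal usage sketch, assuming the template definition from the hunk above is in scope:

    // illustrative only: exercises the ring_buffer defined in the hunk above
    ring_buffer<int> prev(3);          // fixed capacity of 3
    for (int tok = 1; tok <= 5; ++tok) {
        prev.push_back(tok);           // 4 and 5 overwrite the oldest entries
    }
    // rat(0) is the newest element, rat(size() - 1) the oldest one kept;
    // here: prev.rat(0) == 5, prev.rat(2) == 3, prev.size() == 3
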
@@ -144,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
         for (int i = 0; i < (int)cur_p->size; ++i) {
             const float val = cur_p->data[i].logit;
             int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-            ib = std::max(0, std::min(nbuckets-1, ib));
+            ib = std::max(0, std::min(nbuckets - 1, ib));
             bucket_idx[i] = ib;
             ++histo[ib];
         }
@@ -167,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
         for (int i = 0; i < (int)cur_p->size; ++i) {
             int j = bucket_idx[i];
             if (j >= ib) {
-                *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
+                *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
             }
         }
 
         ptr = tmp_tokens.data();
         int ndone = 0;
-        for (int j = nbuckets-1; j > ib; --j) {
+        for (int j = nbuckets - 1; j > ib; --j) {
             std::sort(ptr, ptr + histo[j], comp);
             ptr += histo[j];
             ndone += histo[j];
@@ -258,7 +371,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
 llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
     const auto * logits = llama_get_logits_ith(ctx, idx);
 
-    const
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
     // TODO: do not allocate each time
     std::vector<llama_token_data> cur;
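
llama_sampler_sample now reaches the vocabulary size through the new model/vocab accessors rather than a direct lookup. A sketch of the chain, using only the calls visible in the hunk above (assumes a valid llama_context * ctx):

    const llama_model * model = llama_get_model(ctx);          // context -> model
    const llama_vocab * vocab = llama_model_get_vocab(model);  // model -> vocab
    const int n_vocab = llama_vocab_n_tokens(vocab);           // vocabulary size
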
@@ -1332,7 +1448,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
-    auto * result =
+    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
 
     // copy the state
     {
@@ -1368,19 +1484,19 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free = */ llama_sampler_grammar_free,
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
     auto * ctx = new llama_sampler_grammar;
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
-            /* .vocab = */
+            /* .vocab = */ vocab,
             /* .grammar_str = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar = */ llama_grammar_init_impl(
+            /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
         };
     } else {
         *ctx = {
-            /* .vocab = */
+            /* .vocab = */ vocab,
             /* .grammar_str = */ {},
             /* .grammar_root = */ {},
             /* .grammar = */ nullptr,
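
The grammar constructor drops its _impl suffix and now takes const llama_vocab * instead of a reference. A hypothetical caller of the new public signature (the grammar string and root name are placeholders, not taken from the diff):

    const char * gbnf = "root ::= \"yes\" | \"no\"";  // placeholder GBNF grammar
    llama_sampler * smpl = llama_sampler_init_grammar(vocab, gbnf, "root");
    // ... sample with it, then release it
    llama_sampler_free(smpl);
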
@@ -1550,8 +1666,8 @@ struct llama_sampler_dry {
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    for (llama_token token_id = 0; token_id < (llama_token)vocab.
-    std::string word =
+    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
+        std::string word = vocab.detokenize({token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
@@ -1568,7 +1684,7 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
             }
         }
         if (match) {
-            std::vector<llama_token> tokenization =
+            std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
             if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                 tokenization.resize(max_tail_len);
             }
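
Tokenizer helpers are now llama_vocab member functions (n_tokens, detokenize, tokenize) rather than free functions. A minimal sketch of the new calls, assuming a const llama_vocab & vocab and a llama_token token_id in scope:

    const int n = (int) vocab.n_tokens();                    // vocabulary size
    std::string piece = vocab.detokenize({token_id}, true);  // tokens -> text
    std::vector<llama_token> toks = vocab.tokenize(piece, false, false);  // text -> tokens
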
@@ -1719,7 +1835,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
                 ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                 if (n > 0) {
                     lt = k;
-                    rt = k+n-1;
+                    rt = k + n - 1;
                 }
             } else {
                 // If k is inside the current Z-box, consider two cases.
@@ -1824,7 +1940,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
     llama_vocab dummy_vocab;
 
     // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
-    auto * result =
+    auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
 
     // Copy the state, including the processed breakers
     {
@@ -1851,7 +1967,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
     /* .free = */ llama_sampler_dry_free,
 };
 
-struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char ** seq_breakers, size_t num_breakers) {
+struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
     int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
     std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
     const int MAX_CHAR_LEN = 40;
@@ -1878,7 +1994,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
                 sequence_break.resize(MAX_CHAR_LEN);
             }
 
-            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
+            get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
         }
     }
 
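
llama_sampler_init_dry is likewise public now and takes the vocab pointer first. A hypothetical call matching the signature above; all parameter values are illustrative, except that -1 for dry_penalty_last_n means "use context_size", per the effective_dry_penalty_last_n line:

    const char * breakers[] = { "\n", ":" };  // illustrative sequence breakers
    llama_sampler * dry = llama_sampler_init_dry(
        vocab,    // const llama_vocab *
        4096,     // context_size
        0.8f,     // dry_multiplier
        1.75f,    // dry_base
        2,        // dry_allowed_length
        -1,       // dry_penalty_last_n: -1 -> use context_size
        breakers, 2);
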
@@ -1901,7 +2017,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
 // wrapper for test-sampling.cpp
 struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
     llama_vocab dummy_vocab;
-    auto * result =
+    auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
     auto * ctx = (llama_sampler_dry *) result->ctx;
 
     // Process the token-based sequence breakers
@@ -2040,7 +2156,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     float p_eog_sum = 0.0f;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (
+        if (ctx->vocab->is_eog(cur_p->data[i].id)) {
             p_eog_sum += cur_p->data[i].p;
         } else {
             p_txt_sum += cur_p->data[i].p;
@@ -2062,7 +2178,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     float p_sum = 0.0f;
 
     for (size_t i = 0; i < size_org; ++i) {
-        if (
+        if (ctx->vocab->is_eog(cur_p->data[i].id)) {
             p_sum += cur_p->data[i].p;
 
             cur_p->data[cur_p->size++] = cur_p->data[i];
@@ -2090,17 +2206,17 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
             continue;
         }
 
-        int len0 =
+        int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
         if (len0 < 0) {
             ctx->buf0.resize(len0);
-            len0 =
+            len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
             assert(len0 > 0);
         }
 
-        int len1 =
+        int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
         if (len1 < 0) {
             ctx->buf1.resize(len1);
-            len1 =
+            len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
             assert(len1 > 0);
         }
 
@@ -2135,7 +2251,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog =
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2156,7 +2272,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     // if no non-EOG tokens are left -> reduce cur_p to single EOT token
     if (n_non_eog == 0) {
         cur_p->size = 1;
-        cur_p->data[0].id =
+        cur_p->data[0].id = ctx->vocab->token_eot();
         cur_p->data[0].logit = 1.0f;
 
         return;
@@ -2178,7 +2294,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog =
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2201,7 +2317,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
 static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
-    return
+    return llama_sampler_init_infill(ctx->vocab);
 }
 
 static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@@ -2217,14 +2333,13 @@ static struct llama_sampler_i llama_sampler_infill_i = {
     /* .free = */ llama_sampler_infill_free,
 };
 
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab) {
+struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx = */ new llama_sampler_infill {
-            /* .vocab = */
-            /* .buf0
-            /* .buf1
+            /* .vocab = */ vocab,
+            /* .buf0 = */ std::vector<char>(512),
+            /* .buf1 = */ std::vector<char>(512),
         },
     };
 }
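
The infill constructor completes the pattern: every sampler that needs tokenizer data now receives a const llama_vocab * through the public API, and the _impl variants are deleted from the private header below. A hypothetical before/after:

    // before (vendored llama.cpp in 0.3.6): private, reference-based
    //   llama_sampler * smpl = llama_sampler_init_infill_impl(*vocab);
    // after (0.3.8): public, pointer-based
    llama_sampler * smpl = llama_sampler_init_infill(vocab);
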
package/src/llama.cpp/src/llama-sampling.h
@@ -2,7 +2,9 @@
 
 // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
-#include "llama
+#include "llama.h"
+
+#include <vector>
 
 struct llama_vocab;
 struct llama_grammar;
@@ -21,24 +23,6 @@ struct llama_sampler_chain {
     mutable int32_t n_sample;
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(
-        const struct llama_vocab & vocab,
-        const char * grammar_str,
-        const char * grammar_root);
-
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab);
-
-struct llama_sampler * llama_sampler_init_dry_impl(
-        const struct llama_vocab & vocab,
-        int32_t context_size,
-        float dry_multiplier,
-        float dry_base,
-        int32_t dry_allowed_length,
-        int32_t dry_penalty_last_n,
-        const char ** seq_breakers,
-        size_t num_breakers);
-
 struct llama_sampler * llama_sampler_init_dry_testing(
         int32_t context_size,
         float dry_multiplier,