cui-llama.rn 1.1.2 → 1.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +1 -2
- package/android/src/main/jni.cpp +26 -21
- package/cpp/common.cpp +2028 -1520
- package/cpp/common.h +134 -18
- package/cpp/ggml-aarch64.c +612 -0
- package/cpp/ggml-alloc.h +2 -2
- package/cpp/ggml-backend.c +33 -6
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-common.h +20 -0
- package/cpp/ggml-impl.h +4 -7
- package/cpp/ggml-metal.m +63 -2
- package/cpp/ggml-quants.c +690 -2
- package/cpp/ggml-quants.h +15 -0
- package/cpp/ggml.c +1650 -317
- package/cpp/ggml.h +155 -48
- package/cpp/llama-grammar.cpp +721 -122
- package/cpp/llama-grammar.h +120 -15
- package/cpp/llama-impl.h +132 -1
- package/cpp/llama-sampling.cpp +1361 -356
- package/cpp/llama-sampling.h +20 -48
- package/cpp/llama-vocab.cpp +140 -7
- package/cpp/llama-vocab.h +3 -2
- package/cpp/llama.cpp +810 -307
- package/cpp/llama.h +213 -259
- package/cpp/rn-llama.hpp +17 -14
- package/cpp/sampling.cpp +347 -355
- package/cpp/sampling.h +106 -135
- package/cpp/sgemm.cpp +153 -0
- package/package.json +1 -1
- package/cpp/grammar-parser.cpp +0 -539
- package/cpp/grammar-parser.h +0 -29
package/cpp/llama-sampling.cpp
CHANGED
@@ -1,12 +1,56 @@
|
|
1
1
|
#include "llama-sampling.h"
|
2
2
|
|
3
|
+
#include "llama-vocab.h"
|
4
|
+
#include "llama-grammar.h"
|
5
|
+
|
6
|
+
#include <cassert>
|
3
7
|
#include <algorithm>
|
4
8
|
#include <cstring>
|
5
9
|
#include <ctime>
|
6
10
|
#include <cfloat>
|
7
11
|
#include <numeric>
|
12
|
+
#include <random>
|
8
13
|
#include <unordered_map>
|
9
14
|
|
15
|
+
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
|
16
|
+
#if 1
|
17
|
+
probs.resize(cur_p->size);
|
18
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
19
|
+
probs[i] = cur_p->data[i].p;
|
20
|
+
}
|
21
|
+
|
22
|
+
std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
|
23
|
+
#else
|
24
|
+
// avoid the copy with a custom iterator
|
25
|
+
#pragma GCC diagnostic push
|
26
|
+
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
|
27
|
+
|
28
|
+
struct probs_iterator {
|
29
|
+
typedef std::input_iterator_tag iterator_category;
|
30
|
+
typedef float value_type;
|
31
|
+
typedef float * pointer;
|
32
|
+
typedef float & reference;
|
33
|
+
typedef size_t difference_type;
|
34
|
+
|
35
|
+
const llama_token_data_array * data;
|
36
|
+
size_t i;
|
37
|
+
|
38
|
+
bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; }
|
39
|
+
bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; }
|
40
|
+
float operator*() const { return data->data[i].p; }
|
41
|
+
probs_iterator & operator++() { ++i; return *this; }
|
42
|
+
probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; }
|
43
|
+
};
|
44
|
+
#pragma GCC diagnostic pop
|
45
|
+
|
46
|
+
std::discrete_distribution<size_t> dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size});
|
47
|
+
|
48
|
+
LM_GGML_UNUSED(probs);
|
49
|
+
#endif
|
50
|
+
|
51
|
+
return dist(rng);
|
52
|
+
}
|
53
|
+
|
10
54
|
static void llama_log_softmax(float * array, size_t size) {
|
11
55
|
float max_l = *std::max_element(array, array + size);
|
12
56
|
float sum = 0.f;
|
@@ -21,65 +65,50 @@ static void llama_log_softmax(float * array, size_t size) {
|
|
21
65
|
}
|
22
66
|
}
|
23
67
|
|
24
|
-
void
|
25
|
-
|
26
|
-
seed = time(NULL);
|
27
|
-
}
|
28
|
-
|
29
|
-
smpl->rng.seed(seed);
|
30
|
-
}
|
31
|
-
|
32
|
-
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
33
|
-
LM_GGML_ASSERT(candidates->size > 0);
|
34
|
-
|
35
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
68
|
+
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
69
|
+
LM_GGML_ASSERT(cur_p->size > 0);
|
36
70
|
|
37
71
|
// Sort the logits in descending order
|
38
|
-
if (!
|
39
|
-
std::sort(
|
72
|
+
if (!cur_p->sorted) {
|
73
|
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
40
74
|
return a.logit > b.logit;
|
41
75
|
});
|
42
|
-
|
76
|
+
cur_p->sorted = true;
|
43
77
|
}
|
44
78
|
|
45
|
-
float max_l =
|
79
|
+
float max_l = cur_p->data[0].logit;
|
46
80
|
float cum_sum = 0.0f;
|
47
|
-
|
48
|
-
|
49
|
-
|
81
|
+
|
82
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
83
|
+
float p = expf(cur_p->data[i].logit - max_l);
|
84
|
+
cur_p->data[i].p = p;
|
50
85
|
cum_sum += p;
|
51
86
|
}
|
52
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
53
|
-
candidates->data[i].p /= cum_sum;
|
54
|
-
}
|
55
87
|
|
56
|
-
|
57
|
-
|
88
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
89
|
+
cur_p->data[i].p /= cum_sum;
|
58
90
|
}
|
59
91
|
}
|
60
92
|
|
61
|
-
void
|
93
|
+
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
|
62
94
|
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
63
|
-
// if (k >= (int32_t)
|
95
|
+
// if (k >= (int32_t)cur_p->size) {
|
64
96
|
// return;
|
65
97
|
// }
|
66
98
|
|
67
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
68
|
-
|
69
99
|
if (k <= 0) {
|
70
|
-
k =
|
100
|
+
k = cur_p->size;
|
71
101
|
}
|
72
102
|
|
73
|
-
k = std::
|
74
|
-
k = std::min(k, (int) candidates->size);
|
103
|
+
k = std::min(k, (int) cur_p->size);
|
75
104
|
|
76
105
|
// Sort scores in descending order
|
77
|
-
if (!
|
106
|
+
if (!cur_p->sorted) {
|
78
107
|
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
79
108
|
return a.logit > b.logit;
|
80
109
|
};
|
81
110
|
if (k <= 128) {
|
82
|
-
std::partial_sort(
|
111
|
+
std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
|
83
112
|
} else {
|
84
113
|
constexpr int nbuckets = 128;
|
85
114
|
constexpr float bucket_low = -10.0f;
|
@@ -87,11 +116,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|
87
116
|
constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
|
88
117
|
constexpr float bucket_inter = -bucket_low * bucket_scale;
|
89
118
|
|
90
|
-
std::vector<int> bucket_idx(
|
119
|
+
std::vector<int> bucket_idx(cur_p->size);
|
91
120
|
std::vector<int> histo(nbuckets, 0);
|
92
121
|
|
93
|
-
for (int i = 0; i < (int)
|
94
|
-
const float val =
|
122
|
+
for (int i = 0; i < (int)cur_p->size; ++i) {
|
123
|
+
const float val = cur_p->data[i].logit;
|
95
124
|
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
96
125
|
ib = std::max(0, std::min(nbuckets-1, ib));
|
97
126
|
bucket_idx[i] = ib;
|
@@ -101,20 +130,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|
101
130
|
int ib = nbuckets - 1;
|
102
131
|
for ( ; ib >= 0; --ib) {
|
103
132
|
nhave += histo[ib];
|
104
|
-
if (nhave >= k)
|
133
|
+
if (nhave >= k) {
|
134
|
+
break;
|
135
|
+
}
|
105
136
|
}
|
106
137
|
std::vector<llama_token_data> tmp_tokens(nhave);
|
107
|
-
auto ptr = tmp_tokens.data();
|
138
|
+
auto * ptr = tmp_tokens.data();
|
108
139
|
std::vector<llama_token_data*> bucket_ptrs;
|
109
140
|
bucket_ptrs.reserve(nbuckets - ib);
|
110
141
|
for (int j = nbuckets - 1; j >= ib; --j) {
|
111
142
|
bucket_ptrs.push_back(ptr);
|
112
143
|
ptr += histo[j];
|
113
144
|
}
|
114
|
-
for (int i = 0; i < (int)
|
145
|
+
for (int i = 0; i < (int)cur_p->size; ++i) {
|
115
146
|
int j = bucket_idx[i];
|
116
147
|
if (j >= ib) {
|
117
|
-
*bucket_ptrs[nbuckets-1-j]++ =
|
148
|
+
*bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
|
118
149
|
}
|
119
150
|
}
|
120
151
|
|
@@ -127,69 +158,554 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|
127
158
|
}
|
128
159
|
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
|
129
160
|
|
130
|
-
std::memcpy(
|
161
|
+
std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
|
131
162
|
|
132
163
|
}
|
133
|
-
|
164
|
+
cur_p->sorted = true;
|
134
165
|
}
|
135
|
-
|
166
|
+
cur_p->size = k;
|
167
|
+
}
|
136
168
|
|
137
|
-
|
138
|
-
|
169
|
+
// llama_sampler API
|
170
|
+
|
171
|
+
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
172
|
+
if (!smpl->iface) {
|
173
|
+
return "(null)";
|
139
174
|
}
|
175
|
+
|
176
|
+
return smpl->iface->name(smpl);
|
140
177
|
}
|
141
178
|
|
142
|
-
void
|
143
|
-
if (
|
179
|
+
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
180
|
+
if (smpl->iface->accept) {
|
181
|
+
smpl->iface->accept(smpl, token);
|
182
|
+
}
|
183
|
+
}
|
184
|
+
|
185
|
+
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
186
|
+
LM_GGML_ASSERT(smpl->iface->apply);
|
187
|
+
smpl->iface->apply(smpl, cur_p);
|
188
|
+
}
|
189
|
+
|
190
|
+
void llama_sampler_reset(struct llama_sampler * smpl) {
|
191
|
+
if (smpl->iface->reset) {
|
192
|
+
smpl->iface->reset(smpl);
|
193
|
+
}
|
194
|
+
}
|
195
|
+
|
196
|
+
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
197
|
+
if (smpl->iface->clone) {
|
198
|
+
return smpl->iface->clone(smpl);
|
199
|
+
}
|
200
|
+
|
201
|
+
if (smpl->ctx == nullptr) {
|
202
|
+
return new llama_sampler {
|
203
|
+
/* .iface = */ smpl->iface,
|
204
|
+
/* .ctx = */ nullptr,
|
205
|
+
};
|
206
|
+
}
|
207
|
+
|
208
|
+
LM_GGML_ABORT("the sampler does not support cloning");
|
209
|
+
}
|
210
|
+
|
211
|
+
void llama_sampler_free(struct llama_sampler * smpl) {
|
212
|
+
if (smpl == nullptr) {
|
144
213
|
return;
|
145
214
|
}
|
146
215
|
|
147
|
-
|
216
|
+
if (smpl->iface->free) {
|
217
|
+
smpl->iface->free(smpl);
|
218
|
+
}
|
148
219
|
|
149
|
-
|
220
|
+
delete smpl;
|
221
|
+
}
|
222
|
+
|
223
|
+
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
224
|
+
const auto * logits = llama_get_logits_ith(ctx, idx);
|
225
|
+
|
226
|
+
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
227
|
+
|
228
|
+
// TODO: do not allocate each time
|
229
|
+
std::vector<llama_token_data> cur(n_vocab);
|
230
|
+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
231
|
+
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
232
|
+
}
|
233
|
+
|
234
|
+
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
235
|
+
|
236
|
+
llama_sampler_apply(smpl, &cur_p);
|
237
|
+
|
238
|
+
return cur_p.data[cur_p.selected].id;
|
239
|
+
}
|
240
|
+
|
241
|
+
// sampler chain
|
242
|
+
|
243
|
+
static struct llama_sampler_i llama_sampler_chain_i = {
|
244
|
+
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
|
245
|
+
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
|
246
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
247
|
+
|
248
|
+
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
249
|
+
|
250
|
+
for (auto * smpl : chain->samplers) {
|
251
|
+
llama_sampler_accept(smpl, token);
|
252
|
+
}
|
253
|
+
|
254
|
+
chain->n_sample++;
|
255
|
+
},
|
256
|
+
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
257
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
258
|
+
|
259
|
+
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
260
|
+
|
261
|
+
for (auto * smpl : chain->samplers) {
|
262
|
+
llama_sampler_apply(smpl, cur_p);
|
263
|
+
}
|
264
|
+
},
|
265
|
+
/* .reset = */ [](struct llama_sampler * smpl) {
|
266
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
267
|
+
|
268
|
+
for (auto * smpl : chain->samplers) {
|
269
|
+
llama_sampler_reset(smpl);
|
270
|
+
}
|
271
|
+
|
272
|
+
chain->t_sample_us = 0;
|
273
|
+
chain->n_sample = 0;
|
274
|
+
},
|
275
|
+
/* .clone = */ [](const struct llama_sampler * smpl) {
|
276
|
+
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
|
277
|
+
|
278
|
+
auto * result = llama_sampler_chain_init(chain_src->params);
|
279
|
+
|
280
|
+
for (auto * smpl : chain_src->samplers) {
|
281
|
+
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
282
|
+
}
|
283
|
+
|
284
|
+
return result;
|
285
|
+
},
|
286
|
+
/* .free = */ [](struct llama_sampler * smpl) {
|
287
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
288
|
+
|
289
|
+
for (auto * smpl : chain->samplers) {
|
290
|
+
llama_sampler_free(smpl);
|
291
|
+
}
|
292
|
+
|
293
|
+
delete chain;
|
294
|
+
},
|
295
|
+
};
|
296
|
+
|
297
|
+
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
298
|
+
return new llama_sampler {
|
299
|
+
/* .iface = */ &llama_sampler_chain_i,
|
300
|
+
/* .ctx = */ new llama_sampler_chain {
|
301
|
+
/* .params = */ params,
|
302
|
+
/* .samplers = */ {},
|
303
|
+
/* .t_sample_us = */ 0,
|
304
|
+
/* .n_sample = */ 0,
|
305
|
+
},
|
306
|
+
};
|
307
|
+
}
|
308
|
+
|
309
|
+
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
310
|
+
auto * p = (llama_sampler_chain *) chain->ctx;
|
311
|
+
p->samplers.push_back(smpl);
|
312
|
+
}
|
313
|
+
|
314
|
+
llama_sampler_timings llama_sampler_chain_timings(struct llama_sampler * chain) {
|
315
|
+
auto * p = (llama_sampler_chain *) chain->ctx;
|
316
|
+
struct llama_sampler_timings result = {
|
317
|
+
p -> t_sample_us,
|
318
|
+
p -> n_sample
|
319
|
+
};
|
320
|
+
return result;
|
321
|
+
}
|
322
|
+
|
323
|
+
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
324
|
+
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
325
|
+
|
326
|
+
if (i < 0 || i >= (int32_t) p->samplers.size()) {
|
327
|
+
return nullptr;
|
328
|
+
}
|
329
|
+
|
330
|
+
return p->samplers[i];
|
331
|
+
}
|
332
|
+
|
333
|
+
int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
334
|
+
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
335
|
+
|
336
|
+
return p->samplers.size();
|
337
|
+
}
|
338
|
+
|
339
|
+
//
|
340
|
+
// samplers
|
341
|
+
//
|
342
|
+
|
343
|
+
// greedy
|
344
|
+
|
345
|
+
static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
|
346
|
+
return "greedy";
|
347
|
+
}
|
348
|
+
|
349
|
+
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
350
|
+
cur_p->selected = 0;
|
351
|
+
for (size_t i = 1; i < cur_p->size; ++i) {
|
352
|
+
if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
|
353
|
+
cur_p->selected = i;
|
354
|
+
}
|
355
|
+
}
|
356
|
+
}
|
357
|
+
|
358
|
+
static struct llama_sampler_i llama_sampler_greedy_i = {
|
359
|
+
/* .name = */ llama_sampler_greedy_name,
|
360
|
+
/* .accept = */ nullptr,
|
361
|
+
/* .apply = */ llama_sampler_greedy_apply,
|
362
|
+
/* .reset = */ nullptr,
|
363
|
+
/* .clone = */ nullptr,
|
364
|
+
/* .free = */ nullptr,
|
365
|
+
};
|
366
|
+
|
367
|
+
struct llama_sampler * llama_sampler_init_greedy() {
|
368
|
+
return new llama_sampler {
|
369
|
+
/* .iface = */ &llama_sampler_greedy_i,
|
370
|
+
/* .ctx = */ nullptr,
|
371
|
+
};
|
372
|
+
}
|
373
|
+
|
374
|
+
// dist
|
375
|
+
|
376
|
+
struct llama_sampler_dist {
|
377
|
+
const uint32_t seed;
|
378
|
+
|
379
|
+
std::mt19937 rng;
|
380
|
+
|
381
|
+
std::vector<float> probs; // work array
|
382
|
+
};
|
383
|
+
|
384
|
+
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
|
385
|
+
return "dist";
|
386
|
+
}
|
387
|
+
|
388
|
+
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
389
|
+
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
390
|
+
cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
391
|
+
}
|
392
|
+
|
393
|
+
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
394
|
+
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
395
|
+
auto * result = llama_sampler_init_dist(ctx->seed);
|
396
|
+
|
397
|
+
// copy the state
|
398
|
+
{
|
399
|
+
auto * result_ctx = (llama_sampler_dist *) result->ctx;
|
400
|
+
|
401
|
+
result_ctx->rng = ctx->rng;
|
402
|
+
}
|
403
|
+
|
404
|
+
return result;
|
405
|
+
}
|
406
|
+
|
407
|
+
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
408
|
+
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
409
|
+
ctx->rng = std::mt19937(ctx->seed);
|
410
|
+
}
|
411
|
+
|
412
|
+
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
413
|
+
delete (llama_sampler_dist *) smpl->ctx;
|
414
|
+
}
|
415
|
+
|
416
|
+
static struct llama_sampler_i llama_sampler_dist_i = {
|
417
|
+
/* .name = */ llama_sampler_dist_name,
|
418
|
+
/* .accept = */ nullptr,
|
419
|
+
/* .apply = */ llama_sampler_dist_apply,
|
420
|
+
/* .reset = */ llama_sampler_dist_reset,
|
421
|
+
/* .clone = */ llama_sampler_dist_clone,
|
422
|
+
/* .free = */ llama_sampler_dist_free,
|
423
|
+
};
|
424
|
+
|
425
|
+
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
426
|
+
return new llama_sampler {
|
427
|
+
/* .iface = */ &llama_sampler_dist_i,
|
428
|
+
/* .ctx = */ new llama_sampler_dist {
|
429
|
+
/* .seed = */ seed,
|
430
|
+
/* .rng = */ std::mt19937(seed),
|
431
|
+
/* .probs = */ {},
|
432
|
+
},
|
433
|
+
};
|
434
|
+
}
|
435
|
+
|
436
|
+
// softmax
|
437
|
+
|
438
|
+
static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
|
439
|
+
return "softmax";
|
440
|
+
}
|
441
|
+
|
442
|
+
static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
443
|
+
llama_sampler_softmax_impl(cur_p);
|
444
|
+
}
|
445
|
+
|
446
|
+
static struct llama_sampler_i llama_sampler_softmax_i = {
|
447
|
+
/* .name = */ llama_sampler_softmax_name,
|
448
|
+
/* .accept = */ nullptr,
|
449
|
+
/* .apply = */ llama_sampler_softmax_apply,
|
450
|
+
/* .reset = */ nullptr,
|
451
|
+
/* .clone = */ nullptr,
|
452
|
+
/* .free = */ nullptr,
|
453
|
+
};
|
454
|
+
|
455
|
+
struct llama_sampler * llama_sampler_init_softmax() {
|
456
|
+
return new llama_sampler {
|
457
|
+
/* .iface = */ &llama_sampler_softmax_i,
|
458
|
+
/* .ctx = */ nullptr,
|
459
|
+
};
|
460
|
+
}
|
461
|
+
|
462
|
+
// top-k
|
463
|
+
|
464
|
+
struct llama_sampler_top_k {
|
465
|
+
const int32_t k;
|
466
|
+
};
|
467
|
+
|
468
|
+
static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
|
469
|
+
return "top-k";
|
470
|
+
}
|
471
|
+
|
472
|
+
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
473
|
+
const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
|
474
|
+
llama_sampler_top_k_impl(cur_p, ctx->k);
|
475
|
+
}
|
476
|
+
|
477
|
+
static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
|
478
|
+
const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
|
479
|
+
return llama_sampler_init_top_k(ctx->k);
|
480
|
+
}
|
481
|
+
|
482
|
+
static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
483
|
+
delete (llama_sampler_top_k *) smpl->ctx;
|
484
|
+
}
|
485
|
+
|
486
|
+
static struct llama_sampler_i llama_sampler_top_k_i = {
|
487
|
+
/* .name = */ llama_sampler_top_k_name,
|
488
|
+
/* .accept = */ nullptr,
|
489
|
+
/* .apply = */ llama_sampler_top_k_apply,
|
490
|
+
/* .reset = */ nullptr,
|
491
|
+
/* .clone = */ llama_sampler_top_k_clone,
|
492
|
+
/* .free = */ llama_sampler_top_k_free,
|
493
|
+
};
|
494
|
+
|
495
|
+
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
496
|
+
return new llama_sampler {
|
497
|
+
/* .iface = */ &llama_sampler_top_k_i,
|
498
|
+
/* .ctx = */ new llama_sampler_top_k {
|
499
|
+
/* .k = */ k,
|
500
|
+
},
|
501
|
+
};
|
502
|
+
}
|
503
|
+
|
504
|
+
// top-p
|
505
|
+
|
506
|
+
struct llama_sampler_top_p {
|
507
|
+
const float p;
|
508
|
+
const size_t min_keep;
|
509
|
+
};
|
510
|
+
|
511
|
+
static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
|
512
|
+
return "top-p";
|
513
|
+
}
|
514
|
+
|
515
|
+
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
516
|
+
const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
|
517
|
+
|
518
|
+
if (ctx->p >= 1.0f) {
|
519
|
+
return;
|
520
|
+
}
|
521
|
+
|
522
|
+
llama_sampler_softmax_impl(cur_p);
|
150
523
|
|
151
524
|
// Compute the cumulative probabilities
|
152
525
|
float cum_sum = 0.0f;
|
153
|
-
size_t last_idx =
|
526
|
+
size_t last_idx = cur_p->size;
|
154
527
|
|
155
|
-
for (size_t i = 0; i <
|
156
|
-
cum_sum +=
|
528
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
529
|
+
cum_sum += cur_p->data[i].p;
|
157
530
|
|
158
531
|
// Check if the running sum is at least p or if we have kept at least min_keep tokens
|
159
532
|
// we set the last index to i+1 to indicate that the current iterate should be included in the set
|
160
|
-
if (cum_sum >= p && i + 1 >= min_keep) {
|
533
|
+
if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
|
161
534
|
last_idx = i + 1;
|
162
535
|
break;
|
163
536
|
}
|
164
537
|
}
|
165
538
|
|
166
539
|
// Resize the output vector to keep only the top-p tokens
|
167
|
-
|
540
|
+
cur_p->size = last_idx;
|
541
|
+
}
|
542
|
+
|
543
|
+
static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
|
544
|
+
const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
|
545
|
+
return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
|
546
|
+
}
|
547
|
+
|
548
|
+
static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
549
|
+
delete (llama_sampler_top_p *) smpl->ctx;
|
550
|
+
}
|
551
|
+
|
552
|
+
static struct llama_sampler_i llama_sampler_top_p_i = {
|
553
|
+
/* .name = */ llama_sampler_top_p_name,
|
554
|
+
/* .accept = */ nullptr,
|
555
|
+
/* .apply = */ llama_sampler_top_p_apply,
|
556
|
+
/* .reset = */ nullptr,
|
557
|
+
/* .clone = */ llama_sampler_top_p_clone,
|
558
|
+
/* .free = */ llama_sampler_top_p_free,
|
559
|
+
};
|
560
|
+
|
561
|
+
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
562
|
+
return new llama_sampler {
|
563
|
+
/* .iface = */ &llama_sampler_top_p_i,
|
564
|
+
/* .ctx = */ new llama_sampler_top_p {
|
565
|
+
/* .p = */ p,
|
566
|
+
/* .min_keep = */ min_keep,
|
567
|
+
},
|
568
|
+
};
|
569
|
+
}
|
570
|
+
|
571
|
+
// min-p
|
572
|
+
|
573
|
+
struct llama_sampler_min_p {
|
574
|
+
const float p;
|
575
|
+
const size_t min_keep;
|
576
|
+
};
|
577
|
+
|
578
|
+
static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
|
579
|
+
return "min-p";
|
580
|
+
}
|
581
|
+
|
582
|
+
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
583
|
+
const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
|
584
|
+
|
585
|
+
if (ctx->p <= 0.0f || !cur_p->size) {
|
586
|
+
return;
|
587
|
+
}
|
588
|
+
|
589
|
+
bool min_p_applied = false;
|
590
|
+
|
591
|
+
// if the cur_p aren't sorted, try the unsorted implementation first
|
592
|
+
if (!cur_p->sorted) {
|
593
|
+
std::vector<llama_token_data> filtered_tokens;
|
594
|
+
|
595
|
+
float max_logit = -FLT_MAX;
|
596
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
597
|
+
max_logit = std::max(max_logit, cur_p->data[i].logit);
|
598
|
+
}
|
599
|
+
const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
|
600
|
+
|
601
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
602
|
+
if (cur_p->data[i].logit >= min_logit) {
|
603
|
+
filtered_tokens.push_back(cur_p->data[i]);
|
604
|
+
}
|
605
|
+
}
|
606
|
+
|
607
|
+
// if we have enough values the operation was a success
|
608
|
+
if (filtered_tokens.size() >= ctx->min_keep) {
|
609
|
+
memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
610
|
+
cur_p->size = filtered_tokens.size();
|
611
|
+
min_p_applied = true;
|
612
|
+
}
|
613
|
+
}
|
614
|
+
|
615
|
+
// if the cur_p are sorted or the unsorted implementation failed, use this implementation
|
616
|
+
if (!min_p_applied) {
|
617
|
+
// Sort the logits in descending order
|
618
|
+
if (!cur_p->sorted) {
|
619
|
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
620
|
+
return a.logit > b.logit;
|
621
|
+
});
|
622
|
+
cur_p->sorted = true;
|
623
|
+
}
|
624
|
+
|
625
|
+
const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
|
626
|
+
size_t i = 1; // first token always matches
|
627
|
+
|
628
|
+
for (; i < cur_p->size; ++i) {
|
629
|
+
if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
|
630
|
+
break; // prob too small
|
631
|
+
}
|
632
|
+
}
|
168
633
|
|
169
|
-
|
170
|
-
|
634
|
+
// Resize the output vector to keep only the matching tokens
|
635
|
+
cur_p->size = i;
|
171
636
|
}
|
172
637
|
}
|
173
638
|
|
174
|
-
|
175
|
-
|
639
|
+
static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
|
640
|
+
const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
|
641
|
+
return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
|
642
|
+
}
|
643
|
+
|
644
|
+
static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
645
|
+
delete (llama_sampler_min_p *) smpl->ctx;
|
646
|
+
}
|
647
|
+
|
648
|
+
static struct llama_sampler_i llama_sampler_min_p_i = {
|
649
|
+
/* .name = */ llama_sampler_min_p_name,
|
650
|
+
/* .accept = */ nullptr,
|
651
|
+
/* .apply = */ llama_sampler_min_p_apply,
|
652
|
+
/* .reset = */ nullptr,
|
653
|
+
/* .clone = */ llama_sampler_min_p_clone,
|
654
|
+
/* .free = */ llama_sampler_min_p_free,
|
655
|
+
};
|
656
|
+
|
657
|
+
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
658
|
+
return new llama_sampler {
|
659
|
+
/* .iface = */ &llama_sampler_min_p_i,
|
660
|
+
/* .ctx = */ new llama_sampler_min_p {
|
661
|
+
/* .p = */ p,
|
662
|
+
/* .min_keep = */ min_keep,
|
663
|
+
},
|
664
|
+
};
|
665
|
+
}
|
666
|
+
|
667
|
+
// xtc
|
668
|
+
|
669
|
+
struct llama_sampler_xtc {
|
670
|
+
const uint32_t seed;
|
671
|
+
std::mt19937 rng;
|
672
|
+
const float xtc_p;
|
673
|
+
const float xtc_t;
|
674
|
+
const size_t min_keep;
|
675
|
+
};
|
676
|
+
|
677
|
+
static const char * llama_sampler_xtc_name(const struct llama_sampler * /* smpl */) {
|
678
|
+
return "xtc";
|
679
|
+
}
|
680
|
+
|
681
|
+
static void llama_sampler_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
682
|
+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
683
|
+
|
684
|
+
size_t min_keep = ctx -> min_keep;
|
685
|
+
std::mt19937 rng = ctx -> rng;
|
686
|
+
|
687
|
+
float xtc_threshold = ctx -> xtc_t;
|
688
|
+
float xtc_probability = ctx -> xtc_p;
|
689
|
+
|
690
|
+
|
691
|
+
if(xtc_threshold <= 0.0f || !cur_p-> size) {
|
176
692
|
return;
|
177
693
|
}
|
178
694
|
|
179
695
|
bool xtc_applied = false;
|
180
696
|
const int64_t t_start_sample_us = lm_ggml_time_us();
|
181
|
-
|
697
|
+
llama_sampler_softmax_impl(cur_p);
|
182
698
|
|
183
699
|
// unsorted iteration
|
184
|
-
if (!
|
700
|
+
if (!cur_p->sorted) {
|
185
701
|
std::vector<llama_token_data> top_tokens, low_tokens;
|
186
702
|
|
187
703
|
// split candidates into two arrays for low and high tokens
|
188
|
-
for (size_t i = 0; i <
|
189
|
-
if (
|
190
|
-
top_tokens.push_back(
|
704
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
705
|
+
if (cur_p->data[i].logit >= xtc_threshold) {
|
706
|
+
top_tokens.push_back(cur_p->data[i]);
|
191
707
|
} else {
|
192
|
-
low_tokens.push_back(
|
708
|
+
low_tokens.push_back(cur_p-> data[i]);
|
193
709
|
}
|
194
710
|
}
|
195
711
|
// if there is only one or no top_tokens, do not truncate
|
@@ -212,28 +728,28 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
|
|
212
728
|
}
|
213
729
|
}
|
214
730
|
if(low_tokens.size() >= min_keep) {
|
215
|
-
memcpy(
|
216
|
-
|
731
|
+
memcpy(cur_p->data, low_tokens.data(), low_tokens.size()*sizeof(llama_token_data));
|
732
|
+
cur_p->size = low_tokens.size();
|
217
733
|
xtc_applied = true;
|
218
734
|
}
|
219
735
|
}
|
220
736
|
// sorted iteration
|
221
|
-
|
737
|
+
|
222
738
|
if (!xtc_applied) {
|
223
739
|
// Sort the logits in descending order
|
224
|
-
if (!
|
225
|
-
std::sort(
|
740
|
+
if (!cur_p->sorted) {
|
741
|
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
226
742
|
return a.logit > b.logit;
|
227
743
|
});
|
228
|
-
|
744
|
+
cur_p->sorted = true;
|
229
745
|
}
|
230
746
|
|
231
747
|
// find last token over threshold
|
232
748
|
|
233
749
|
size_t last_index = 0;
|
234
750
|
|
235
|
-
for (; last_index <
|
236
|
-
if(
|
751
|
+
for (; last_index < cur_p -> size; ++last_index) {
|
752
|
+
if(cur_p -> data[last_index].p < xtc_threshold) {
|
237
753
|
break;
|
238
754
|
}
|
239
755
|
}
|
@@ -244,102 +760,80 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
|
|
244
760
|
}
|
245
761
|
last_index--;
|
246
762
|
// items beyond safe index will be ignored
|
247
|
-
size_t safe_index =
|
763
|
+
size_t safe_index = cur_p -> size;
|
248
764
|
|
249
765
|
// remove tokens until last threshold item
|
250
766
|
std::uniform_real_distribution<float> random_float(0.0 , 1.0);
|
251
767
|
for (size_t i = 0; i < last_index; i++) {
|
252
768
|
if(random_float(rng) < xtc_probability) {
|
253
|
-
std::swap(
|
769
|
+
std::swap(cur_p-> data[i], cur_p->data[safe_index - 1]);
|
254
770
|
safe_index--;
|
255
|
-
if (
|
256
|
-
|
771
|
+
if (cur_p-> sorted) {
|
772
|
+
cur_p -> sorted = false;
|
257
773
|
}
|
258
774
|
}
|
259
775
|
}
|
260
|
-
|
261
|
-
}
|
262
|
-
|
263
|
-
if (smpl) {
|
264
|
-
smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
|
776
|
+
cur_p -> size = safe_index;
|
265
777
|
}
|
266
778
|
}
|
267
779
|
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
274
|
-
|
275
|
-
bool min_p_applied = false;
|
276
|
-
|
277
|
-
// if the candidates aren't sorted, try the unsorted implementation first
|
278
|
-
if (!candidates->sorted) {
|
279
|
-
std::vector<llama_token_data> filtered_tokens;
|
280
|
-
|
281
|
-
float max_logit = -FLT_MAX;
|
282
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
283
|
-
max_logit = std::max(max_logit, candidates->data[i].logit);
|
284
|
-
}
|
285
|
-
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
286
|
-
|
287
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
288
|
-
if (candidates->data[i].logit >= min_logit) {
|
289
|
-
filtered_tokens.push_back(candidates->data[i]);
|
290
|
-
}
|
291
|
-
}
|
292
|
-
|
293
|
-
// if we have enough values the operation was a success
|
294
|
-
if (filtered_tokens.size() >= min_keep) {
|
295
|
-
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
296
|
-
candidates->size = filtered_tokens.size();
|
297
|
-
min_p_applied = true;
|
298
|
-
}
|
299
|
-
}
|
780
|
+
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
781
|
+
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
782
|
+
return llama_sampler_init_xtc(ctx->xtc_p, ctx->xtc_t, ctx->min_keep, ctx->seed);
|
783
|
+
}
|
300
784
|
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
if (!candidates->sorted) {
|
305
|
-
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
306
|
-
return a.logit > b.logit;
|
307
|
-
});
|
308
|
-
candidates->sorted = true;
|
309
|
-
}
|
785
|
+
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
786
|
+
delete (const llama_sampler_xtc *) smpl->ctx;
|
787
|
+
}
|
310
788
|
|
311
|
-
|
312
|
-
|
789
|
+
static struct llama_sampler_i llama_sampler_xtc_i = {
|
790
|
+
/* .name = */ llama_sampler_xtc_name,
|
791
|
+
/* .accept = */ nullptr,
|
792
|
+
/* .apply = */ llama_sampler_xtc_apply,
|
793
|
+
/* .reset = */ nullptr,
|
794
|
+
/* .clone = */ llama_sampler_xtc_clone,
|
795
|
+
/* .free = */ llama_sampler_xtc_free,
|
796
|
+
};
|
797
|
+
|
798
|
+
struct llama_sampler * llama_sampler_init_xtc(float xtc_p, float xtc_t, size_t min_keep, uint32_t seed) {
|
799
|
+
return new llama_sampler {
|
800
|
+
/* .iface = */ &llama_sampler_xtc_i,
|
801
|
+
/* .ctx = */ new llama_sampler_xtc {
|
802
|
+
/* .seed = */ seed,
|
803
|
+
/* .rng = */ std::mt19937(seed),
|
804
|
+
/* .xtc_p = */ xtc_p,
|
805
|
+
/* .xtc_t = */ xtc_t,
|
806
|
+
/* .min_keep = */ min_keep
|
807
|
+
},
|
808
|
+
};
|
809
|
+
}
|
313
810
|
|
314
|
-
|
315
|
-
if (candidates->data[i].logit < min_logit && i >= min_keep) {
|
316
|
-
break; // prob too small
|
317
|
-
}
|
318
|
-
}
|
811
|
+
// tail-free
|
319
812
|
|
320
|
-
|
321
|
-
|
322
|
-
|
813
|
+
struct llama_sampler_tail_free {
|
814
|
+
const float z;
|
815
|
+
const size_t min_keep;
|
816
|
+
};
|
323
817
|
|
324
|
-
|
325
|
-
|
326
|
-
}
|
818
|
+
static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
|
819
|
+
return "tail-free";
|
327
820
|
}
|
328
821
|
|
329
|
-
void
|
330
|
-
|
822
|
+
static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
823
|
+
const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
|
824
|
+
|
825
|
+
if (ctx->z >= 1.0f || cur_p->size <= 2) {
|
331
826
|
return;
|
332
827
|
}
|
333
828
|
|
334
|
-
|
335
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
829
|
+
llama_sampler_softmax_impl(cur_p);
|
336
830
|
|
337
831
|
// Compute the first and second derivatives
|
338
|
-
std::vector<float> first_derivatives(
|
339
|
-
std::vector<float> second_derivatives(
|
832
|
+
std::vector<float> first_derivatives(cur_p->size - 1);
|
833
|
+
std::vector<float> second_derivatives(cur_p->size - 2);
|
340
834
|
|
341
835
|
for (size_t i = 0; i < first_derivatives.size(); ++i) {
|
342
|
-
first_derivatives[i] =
|
836
|
+
first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
|
343
837
|
}
|
344
838
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
345
839
|
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
|
@@ -366,51 +860,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
|
|
366
860
|
}
|
367
861
|
|
368
862
|
float cum_sum = 0.0f;
|
369
|
-
size_t last_idx =
|
863
|
+
size_t last_idx = cur_p->size;
|
370
864
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
371
865
|
cum_sum += second_derivatives[i];
|
372
866
|
|
373
867
|
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
374
|
-
if (cum_sum > z && i >= min_keep) {
|
868
|
+
if (cum_sum > ctx->z && i >= ctx->min_keep) {
|
375
869
|
last_idx = i;
|
376
870
|
break;
|
377
871
|
}
|
378
872
|
}
|
379
873
|
|
380
874
|
// Resize the output vector to keep only the tokens above the tail location
|
381
|
-
|
875
|
+
cur_p->size = last_idx;
|
876
|
+
}
|
382
877
|
|
383
|
-
|
384
|
-
|
385
|
-
|
878
|
+
static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
|
879
|
+
const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
|
880
|
+
return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
|
881
|
+
}
|
882
|
+
|
883
|
+
static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
|
884
|
+
delete (llama_sampler_tail_free *) smpl->ctx;
|
386
885
|
}
|
387
886
|
|
388
|
-
|
887
|
+
static struct llama_sampler_i llama_sampler_tail_free_i = {
|
888
|
+
/* .name = */ llama_sampler_tail_free_name,
|
889
|
+
/* .accept = */ nullptr,
|
890
|
+
/* .apply = */ llama_sampler_tail_free_apply,
|
891
|
+
/* .reset = */ nullptr,
|
892
|
+
/* .clone = */ llama_sampler_tail_free_clone,
|
893
|
+
/* .free = */ llama_sampler_tail_free_free,
|
894
|
+
};
|
895
|
+
|
896
|
+
struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
|
897
|
+
return new llama_sampler {
|
898
|
+
/* .iface = */ &llama_sampler_tail_free_i,
|
899
|
+
/* .ctx = */ new llama_sampler_tail_free {
|
900
|
+
/* .z = */ z,
|
901
|
+
/*. min_keep = */ min_keep,
|
902
|
+
},
|
903
|
+
};
|
904
|
+
}
|
905
|
+
|
906
|
+
// typical
|
907
|
+
|
908
|
+
struct llama_sampler_typical {
|
909
|
+
const float p;
|
910
|
+
const size_t min_keep;
|
911
|
+
};
|
912
|
+
|
913
|
+
static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
|
914
|
+
return "typical";
|
915
|
+
}
|
916
|
+
|
917
|
+
static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
918
|
+
const auto * ctx = (llama_sampler_typical *) smpl->ctx;
|
919
|
+
|
389
920
|
// Reference implementation:
|
390
921
|
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
391
|
-
if (p >= 1.0f) {
|
922
|
+
if (ctx->p >= 1.0f) {
|
392
923
|
return;
|
393
924
|
}
|
394
925
|
|
395
926
|
// Compute the softmax of logits and calculate entropy
|
396
|
-
|
397
|
-
|
398
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
927
|
+
llama_sampler_softmax_impl(cur_p);
|
399
928
|
|
400
929
|
float entropy = 0.0f;
|
401
|
-
for (size_t i = 0; i <
|
402
|
-
entropy += -
|
930
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
931
|
+
entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
|
403
932
|
}
|
404
933
|
|
405
934
|
// Compute the absolute difference between negative log probability and entropy for each candidate
|
406
935
|
std::vector<float> shifted_scores;
|
407
|
-
for (size_t i = 0; i <
|
408
|
-
float shifted_score = fabsf(-logf(
|
936
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
937
|
+
float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
|
409
938
|
shifted_scores.push_back(shifted_score);
|
410
939
|
}
|
411
940
|
|
412
941
|
// Sort tokens based on the shifted_scores and their corresponding indices
|
413
|
-
std::vector<size_t> indices(
|
942
|
+
std::vector<size_t> indices(cur_p->size);
|
414
943
|
std::iota(indices.begin(), indices.end(), 0);
|
415
944
|
|
416
945
|
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
|
@@ -423,197 +952,250 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|
423
952
|
|
424
953
|
for (size_t i = 0; i < indices.size(); ++i) {
|
425
954
|
size_t idx = indices[i];
|
426
|
-
cum_sum +=
|
955
|
+
cum_sum += cur_p->data[idx].p;
|
427
956
|
|
428
957
|
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
429
|
-
if (cum_sum > p && i >= min_keep - 1) {
|
958
|
+
if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
|
430
959
|
last_idx = i + 1;
|
431
960
|
break;
|
432
961
|
}
|
433
962
|
}
|
434
963
|
|
435
964
|
// Resize the output vector to keep only the locally typical tokens
|
436
|
-
std::vector<llama_token_data>
|
965
|
+
std::vector<llama_token_data> cur_p_new;
|
437
966
|
for (size_t i = 0; i < last_idx; ++i) {
|
438
967
|
size_t idx = indices[i];
|
439
|
-
|
968
|
+
cur_p_new.push_back(cur_p->data[idx]);
|
440
969
|
}
|
441
970
|
|
442
|
-
// Replace the data in
|
443
|
-
std::copy(
|
444
|
-
|
445
|
-
|
971
|
+
// Replace the data in cur_p with the cur_p_new data
|
972
|
+
std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
|
973
|
+
cur_p->size = cur_p_new.size();
|
974
|
+
cur_p->sorted = false;
|
975
|
+
}
|
446
976
|
|
447
|
-
|
448
|
-
|
449
|
-
|
977
|
+
static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
|
978
|
+
const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
|
979
|
+
return llama_sampler_init_typical(ctx->p, ctx->min_keep);
|
450
980
|
}
|
451
981
|
|
452
|
-
void
|
453
|
-
|
982
|
+
static void llama_sampler_typical_free(struct llama_sampler * smpl) {
|
983
|
+
delete (llama_sampler_typical *) smpl->ctx;
|
984
|
+
}
|
454
985
|
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
986
|
+
static struct llama_sampler_i llama_sampler_typical_i = {
|
987
|
+
/* .name = */ llama_sampler_typical_name,
|
988
|
+
/* .accept = */ nullptr,
|
989
|
+
/* .apply = */ llama_sampler_typical_apply,
|
990
|
+
/* .reset = */ nullptr,
|
991
|
+
/* .clone = */ llama_sampler_typical_clone,
|
992
|
+
/* .free = */ llama_sampler_typical_free,
|
993
|
+
};
|
994
|
+
|
995
|
+
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
996
|
+
return new llama_sampler {
|
997
|
+
/* .iface = */ &llama_sampler_typical_i,
|
998
|
+
/* .ctx = */ new llama_sampler_typical {
|
999
|
+
/* .p = */ p,
|
1000
|
+
/* .min_keep = */ min_keep,
|
1001
|
+
},
|
1002
|
+
};
|
1003
|
+
}
|
459
1004
|
|
460
|
-
|
461
|
-
float max_entropy = -logf(1.0f / candidates->size);
|
1005
|
+
// temp
|
462
1006
|
|
463
|
-
|
1007
|
+
struct llama_sampler_temp {
|
1008
|
+
const float temp;
|
1009
|
+
};
|
464
1010
|
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
float prob = candidates->data[i].p;
|
469
|
-
if (prob > 0.0f) { // Ensure no log(0)
|
470
|
-
entropy -= prob * logf(prob);
|
471
|
-
}
|
472
|
-
}
|
1011
|
+
static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
|
1012
|
+
return "temp";
|
1013
|
+
}
|
473
1014
|
|
474
|
-
|
475
|
-
|
1015
|
+
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1016
|
+
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
1017
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1018
|
+
cur_p->data[i].logit /= ctx->temp;
|
1019
|
+
}
|
1020
|
+
}
|
476
1021
|
|
477
|
-
|
478
|
-
|
1022
|
+
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
1023
|
+
const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
|
1024
|
+
return llama_sampler_init_temp(ctx->temp);
|
1025
|
+
}
|
479
1026
|
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
|
484
|
-
LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
|
485
|
-
LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
|
486
|
-
LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
|
487
|
-
#endif
|
1027
|
+
static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
1028
|
+
delete (llama_sampler_temp *) smpl->ctx;
|
1029
|
+
}
|
488
1030
|
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
1031
|
+
static struct llama_sampler_i llama_sampler_temp_i = {
|
1032
|
+
/* .name = */ llama_sampler_temp_name,
|
1033
|
+
/* .accept = */ nullptr,
|
1034
|
+
/* .apply = */ llama_sampler_temp_apply,
|
1035
|
+
/* .reset = */ nullptr,
|
1036
|
+
/* .clone = */ llama_sampler_temp_clone,
|
1037
|
+
/* .free = */ llama_sampler_temp_free,
|
1038
|
+
};
|
1039
|
+
|
1040
|
+
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
1041
|
+
return new llama_sampler {
|
1042
|
+
/* .iface = */ &llama_sampler_temp_i,
|
1043
|
+
/* .ctx = */ new llama_sampler_temp {
|
1044
|
+
/*.temp = */ temp,
|
1045
|
+
},
|
1046
|
+
};
|
1047
|
+
}
|
493
1048
|
|
494
|
-
|
495
|
-
double max_l_double = candidates->data[0].logit;
|
496
|
-
double cum_sum_double = 0.0;
|
497
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
498
|
-
double p = exp(candidates->data[i].logit - max_l_double);
|
499
|
-
candidates->data[i].p = p; // Store the scaled probability
|
500
|
-
cum_sum_double += p;
|
501
|
-
}
|
502
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
503
|
-
candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
504
|
-
}
|
1049
|
+
// temp-ext
|
505
1050
|
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
}
|
512
|
-
#endif
|
1051
|
+
struct llama_sampler_temp_ext {
|
1052
|
+
const float temp;
|
1053
|
+
const float delta;
|
1054
|
+
const float exponent;
|
1055
|
+
};
|
513
1056
|
|
514
|
-
|
515
|
-
|
516
|
-
}
|
1057
|
+
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
|
1058
|
+
return "temp-ext";
|
517
1059
|
}
|
518
1060
|
|
519
|
-
void
|
520
|
-
const
|
1061
|
+
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1062
|
+
const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
|
1063
|
+
if (ctx->delta > 0) {
|
1064
|
+
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
1065
|
+
const float max_temp = ctx->temp + ctx->delta;
|
1066
|
+
float exponent_val = ctx->exponent;
|
521
1067
|
|
522
|
-
|
523
|
-
|
524
|
-
|
1068
|
+
// no need to do anything if there is only one (or zero) candidates
|
1069
|
+
if (cur_p->size <= 1) {
|
1070
|
+
return;
|
1071
|
+
}
|
525
1072
|
|
526
|
-
|
527
|
-
|
528
|
-
}
|
529
|
-
}
|
1073
|
+
// Calculate maximum possible entropy
|
1074
|
+
float max_entropy = -logf(1.0f / cur_p->size);
|
530
1075
|
|
531
|
-
|
532
|
-
struct llama_sampling * smpl,
|
533
|
-
llama_token_data_array * candidates,
|
534
|
-
const llama_token * last_tokens,
|
535
|
-
size_t penalty_last_n,
|
536
|
-
float penalty_repeat,
|
537
|
-
float penalty_freq,
|
538
|
-
float penalty_present) {
|
539
|
-
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
540
|
-
return;
|
541
|
-
}
|
1076
|
+
llama_sampler_softmax_impl(cur_p);
|
542
1077
|
|
543
|
-
|
1078
|
+
// Calculate entropy of the softmax probabilities
|
1079
|
+
float entropy = 0.0f;
|
1080
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1081
|
+
float prob = cur_p->data[i].p;
|
1082
|
+
if (prob > 0.0f) { // Ensure no log(0)
|
1083
|
+
entropy -= prob * logf(prob);
|
1084
|
+
}
|
1085
|
+
}
|
544
1086
|
|
545
|
-
|
546
|
-
|
547
|
-
for (size_t i = 0; i < penalty_last_n; ++i) {
|
548
|
-
token_count[last_tokens[i]]++;
|
549
|
-
}
|
1087
|
+
// Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
|
1088
|
+
float normalized_entropy = entropy / max_entropy;
|
550
1089
|
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
1090
|
+
// Map the normalized entropy to the desired temperature range using the power function
|
1091
|
+
float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
|
1092
|
+
|
1093
|
+
#ifdef DEBUG
|
1094
|
+
LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
|
1095
|
+
LLAMA_LOG_INFO("Entropy: %f\n", entropy);
|
1096
|
+
LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
|
1097
|
+
LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
|
1098
|
+
LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
|
1099
|
+
LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
|
1100
|
+
#endif
|
1101
|
+
|
1102
|
+
// Apply the dynamically calculated temperature scaling
|
1103
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1104
|
+
cur_p->data[i].logit /= dyn_temp;
|
556
1105
|
}
|
557
1106
|
|
558
|
-
|
1107
|
+
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
1108
|
+
const double max_l_double = cur_p->data[0].logit;
|
559
1109
|
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
1110
|
+
double cum_sum_double = 0.0;
|
1111
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1112
|
+
double p = exp(cur_p->data[i].logit - max_l_double);
|
1113
|
+
cur_p->data[i].p = p; // Store the scaled probability
|
1114
|
+
cum_sum_double += p;
|
1115
|
+
}
|
1116
|
+
|
1117
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1118
|
+
cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
566
1119
|
}
|
567
1120
|
|
568
|
-
|
1121
|
+
#ifdef DEBUG
|
1122
|
+
// Print the updated top 25 probabilities after temperature scaling
|
1123
|
+
LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
|
1124
|
+
for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
|
1125
|
+
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
|
1126
|
+
}
|
1127
|
+
#endif
|
1128
|
+
} else {
|
1129
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1130
|
+
cur_p->data[i].logit /= ctx->temp;
|
1131
|
+
}
|
569
1132
|
}
|
1133
|
+
}
|
570
1134
|
|
571
|
-
|
1135
|
+
static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
|
1136
|
+
const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
|
1137
|
+
return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
|
1138
|
+
}
|
572
1139
|
|
573
|
-
|
574
|
-
|
575
|
-
}
|
1140
|
+
static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
1141
|
+
delete (llama_sampler_temp_ext *) smpl->ctx;
|
576
1142
|
}
|
577
1143
|
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
1144
|
+
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
1145
|
+
/* .name = */ llama_sampler_temp_ext_name,
|
1146
|
+
/* .accept = */ nullptr,
|
1147
|
+
/* .apply = */ llama_sampler_temp_ext_apply,
|
1148
|
+
/* .reset = */ nullptr,
|
1149
|
+
/* .clone = */ llama_sampler_temp_ext_clone,
|
1150
|
+
/* .free = */ llama_sampler_temp_ext_free,
|
1151
|
+
};
|
1152
|
+
|
1153
|
+
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
1154
|
+
return new llama_sampler {
|
1155
|
+
/* .iface = */ &llama_sampler_temp_ext_i,
|
1156
|
+
/* .ctx = */ new llama_sampler_temp_ext {
|
1157
|
+
/* .temp = */ temp,
|
1158
|
+
/* .delta = */ delta,
|
1159
|
+
/* .exponent = */ exponent,
|
1160
|
+
},
|
1161
|
+
};
|
1162
|
+
}
|
584
1163
|
|
585
|
-
|
586
|
-
const auto n_vocab = smpl->n_vocab;
|
1164
|
+
// mirostat
|
587
1165
|
|
588
|
-
|
589
|
-
|
1166
|
+
struct llama_sampler_mirostat {
|
1167
|
+
const int32_t n_vocab;
|
590
1168
|
|
591
|
-
|
592
|
-
auto & l = logits[i];
|
593
|
-
const auto & g = logits_guidance[i];
|
1169
|
+
const uint32_t seed;
|
594
1170
|
|
595
|
-
|
596
|
-
|
1171
|
+
const float tau;
|
1172
|
+
const float eta;
|
597
1173
|
|
598
|
-
|
599
|
-
}
|
1174
|
+
const int32_t m;
|
600
1175
|
|
601
|
-
|
602
|
-
LM_GGML_ASSERT(smpl);
|
1176
|
+
float mu;
|
603
1177
|
|
604
|
-
|
1178
|
+
std::mt19937 rng;
|
605
1179
|
|
606
|
-
|
1180
|
+
std::vector<float> probs;
|
1181
|
+
};
|
1182
|
+
|
1183
|
+
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
|
1184
|
+
return "mirostat";
|
1185
|
+
}
|
607
1186
|
|
608
|
-
|
1187
|
+
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1188
|
+
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
|
1189
|
+
|
1190
|
+
llama_sampler_softmax_impl(cur_p);
|
609
1191
|
|
610
1192
|
// Estimate s_hat using the most probable m tokens
|
611
1193
|
float s_hat = 0.0;
|
612
1194
|
float sum_ti_bi = 0.0;
|
613
1195
|
float sum_ti_sq = 0.0;
|
614
|
-
for (size_t i = 0; i < size_t(m - 1) && i <
|
1196
|
+
for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
|
615
1197
|
float t_i = logf(float(i + 2) / float(i + 1));
|
616
|
-
float b_i = logf(
|
1198
|
+
float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
|
617
1199
|
sum_ti_bi += t_i * b_i;
|
618
1200
|
sum_ti_sq += t_i * t_i;
|
619
1201
|
}
|
@@ -621,109 +1203,532 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
|
|
621
1203
|
|
622
1204
|
// Compute k from the estimated s_hat and target surprise value
|
623
1205
|
float epsilon_hat = s_hat - 1;
|
624
|
-
float k = powf((epsilon_hat * powf(2,
|
1206
|
+
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
|
625
1207
|
|
626
|
-
|
627
|
-
|
628
|
-
smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
|
629
|
-
llama_token X = llama_sample_token_impl(smpl, candidates);
|
630
|
-
t_start_sample_us = lm_ggml_time_us();
|
1208
|
+
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
|
1209
|
+
llama_sampler_softmax_impl(cur_p);
|
631
1210
|
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
float observed_surprise = -log2f(
|
637
|
-
float e = observed_surprise - tau;
|
1211
|
+
const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
1212
|
+
|
1213
|
+
cur_p->selected = idx;
|
1214
|
+
|
1215
|
+
float observed_surprise = -log2f(cur_p->data[idx].p);
|
1216
|
+
float e = observed_surprise - ctx->tau;
|
638
1217
|
|
639
1218
|
// Update mu using the learning rate and error
|
640
|
-
|
1219
|
+
ctx->mu = ctx->mu - ctx->eta * e;
|
1220
|
+
}
|
1221
|
+
|
1222
|
+
static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
|
1223
|
+
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
|
1224
|
+
auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
|
1225
|
+
|
1226
|
+
// copy the state
|
1227
|
+
{
|
1228
|
+
auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
|
1229
|
+
|
1230
|
+
result_ctx->mu = ctx->mu;
|
1231
|
+
result_ctx->rng = ctx->rng;
|
1232
|
+
}
|
1233
|
+
|
1234
|
+
return result;
|
1235
|
+
}
|
1236
|
+
|
1237
|
+
static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
|
1238
|
+
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
|
1239
|
+
ctx->mu = 2.0f*ctx->tau;
|
1240
|
+
ctx->rng = std::mt19937(ctx->seed);
|
1241
|
+
}
|
1242
|
+
|
1243
|
+
static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
1244
|
+
delete (llama_sampler_mirostat *) smpl->ctx;
|
1245
|
+
}
|
641
1246
|
|
642
|
-
|
643
|
-
|
1247
|
+
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
1248
|
+
/* .name = */ llama_sampler_mirostat_name,
|
1249
|
+
/* .accept = */ nullptr,
|
1250
|
+
/* .apply = */ llama_sampler_mirostat_apply,
|
1251
|
+
/* .reset = */ llama_sampler_mirostat_reset,
|
1252
|
+
/* .clone = */ llama_sampler_mirostat_clone,
|
1253
|
+
/* .free = */ llama_sampler_mirostat_free,
|
1254
|
+
};
|
1255
|
+
|
1256
|
+
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
1257
|
+
return new llama_sampler {
|
1258
|
+
/* .iface = */ &llama_sampler_mirostat_i,
|
1259
|
+
/* .ctx = */ new llama_sampler_mirostat {
|
1260
|
+
/* .n_vocab = */ n_vocab,
|
1261
|
+
/* .seed = */ seed,
|
1262
|
+
/* .tau = */ tau,
|
1263
|
+
/* .eta = */ eta,
|
1264
|
+
/* .m = */ m,
|
1265
|
+
/* .mu = */ 2.0f*tau,
|
1266
|
+
/* .rng = */ std::mt19937(seed),
|
1267
|
+
/* .probs = */ {},
|
1268
|
+
},
|
1269
|
+
};
|
644
1270
|
}
|
645
1271
|
|
646
|
-
|
647
|
-
|
648
|
-
|
1272
|
+
// mirostat v2
|
1273
|
+
|
1274
|
+
struct llama_sampler_mirostat_v2 {
|
1275
|
+
const uint32_t seed;
|
649
1276
|
|
650
|
-
|
1277
|
+
const float tau;
|
1278
|
+
const float eta;
|
1279
|
+
|
1280
|
+
float mu;
|
1281
|
+
|
1282
|
+
std::mt19937 rng;
|
1283
|
+
|
1284
|
+
std::vector<float> probs;
|
1285
|
+
};
|
1286
|
+
|
1287
|
+
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
|
1288
|
+
return "mirostat-v2";
|
1289
|
+
}
|
1290
|
+
|
1291
|
+
static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1292
|
+
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
|
1293
|
+
|
1294
|
+
llama_sampler_softmax_impl(cur_p);
|
651
1295
|
|
652
1296
|
// Truncate the words with surprise values greater than mu
|
653
|
-
|
654
|
-
return -log2f(candidate.p) >
|
1297
|
+
cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
|
1298
|
+
return -log2f(candidate.p) > ctx->mu;
|
655
1299
|
}));
|
656
1300
|
|
657
|
-
if (
|
658
|
-
|
659
|
-
}
|
660
|
-
|
661
|
-
if (smpl) {
|
662
|
-
smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
|
1301
|
+
if (cur_p->size == 0) {
|
1302
|
+
cur_p->size = 1;
|
663
1303
|
}
|
664
1304
|
|
665
1305
|
// Normalize the probabilities of the remaining words
|
666
|
-
|
1306
|
+
llama_sampler_softmax_impl(cur_p);
|
667
1307
|
|
668
|
-
|
669
|
-
llama_token X = llama_sample_token_impl(smpl, candidates);
|
670
|
-
t_start_sample_us = lm_ggml_time_us();
|
1308
|
+
const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
671
1309
|
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
677
|
-
float e = observed_surprise - tau;
|
1310
|
+
cur_p->selected = idx;
|
1311
|
+
|
1312
|
+
float observed_surprise = -log2f(cur_p->data[idx].p);
|
1313
|
+
float e = observed_surprise - ctx->tau;
|
678
1314
|
|
679
1315
|
// Update mu using the learning rate and error
|
680
|
-
|
1316
|
+
ctx->mu = ctx->mu - ctx->eta * e;
|
1317
|
+
}
|
1318
|
+
|
1319
|
+
static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
|
1320
|
+
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
|
1321
|
+
ctx->mu = 2.0f*ctx->tau;
|
1322
|
+
ctx->rng = std::mt19937(ctx->seed);
|
1323
|
+
}
|
1324
|
+
|
1325
|
+
static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
|
1326
|
+
const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
|
1327
|
+
|
1328
|
+
auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
|
1329
|
+
|
1330
|
+
// copy the state
|
1331
|
+
{
|
1332
|
+
auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
|
681
1333
|
|
682
|
-
|
683
|
-
|
1334
|
+
result_ctx->mu = ctx->mu;
|
1335
|
+
result_ctx->rng = ctx->rng;
|
684
1336
|
}
|
685
|
-
|
1337
|
+
|
1338
|
+
return result;
|
686
1339
|
}
|
687
1340
|
|
688
|
-
|
689
|
-
|
1341
|
+
static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
|
1342
|
+
delete (llama_sampler_mirostat_v2 *) smpl->ctx;
|
1343
|
+
}
|
690
1344
|
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
1345
|
+
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
1346
|
+
/* .name = */ llama_sampler_mirostat_v2_name,
|
1347
|
+
/* .accept = */ nullptr,
|
1348
|
+
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
1349
|
+
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
1350
|
+
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
1351
|
+
/* .free = */ llama_sampler_mirostat_v2_free,
|
1352
|
+
};
|
1353
|
+
|
1354
|
+
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
1355
|
+
return new llama_sampler {
|
1356
|
+
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
1357
|
+
/* .ctx = */ new llama_sampler_mirostat_v2 {
|
1358
|
+
/* .seed = */ seed,
|
1359
|
+
/* .tau = */ tau,
|
1360
|
+
/* .eta = */ eta,
|
1361
|
+
/* .mu = */ 2.0f*tau,
|
1362
|
+
/* .rng = */ std::mt19937(seed),
|
1363
|
+
/* .probs = */ {},
|
1364
|
+
},
|
1365
|
+
};
|
1366
|
+
}
|
695
1367
|
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
1368
|
+
// grammar
|
1369
|
+
|
1370
|
+
struct llama_sampler_grammar {
|
1371
|
+
const struct llama_vocab * vocab;
|
1372
|
+
|
1373
|
+
std::string grammar_str;
|
1374
|
+
std::string grammar_root;
|
1375
|
+
|
1376
|
+
struct llama_grammar * grammar;
|
1377
|
+
};
|
1378
|
+
|
1379
|
+
static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
|
1380
|
+
return "grammar";
|
1381
|
+
}
|
1382
|
+
|
1383
|
+
static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
|
1384
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
1385
|
+
if (ctx->grammar) {
|
1386
|
+
llama_grammar_accept_impl(*ctx->grammar, token);
|
700
1387
|
}
|
1388
|
+
}
|
1389
|
+
|
1390
|
+
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1391
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
1392
|
+
if (ctx->grammar) {
|
1393
|
+
llama_grammar_apply_impl(*ctx->grammar, cur_p);
|
1394
|
+
}
|
1395
|
+
}
|
1396
|
+
|
1397
|
+
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
1398
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
1399
|
+
if (!ctx->grammar) {
|
1400
|
+
return;
|
1401
|
+
}
|
1402
|
+
|
1403
|
+
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
|
1404
|
+
|
1405
|
+
llama_grammar_free_impl(ctx->grammar);
|
1406
|
+
ctx->grammar = grammar_new;
|
1407
|
+
}
|
1408
|
+
|
1409
|
+
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
|
1410
|
+
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
|
1411
|
+
|
1412
|
+
auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
|
1413
|
+
|
1414
|
+
// copy the state
|
1415
|
+
{
|
1416
|
+
auto * result_ctx = (llama_sampler_grammar *) result->ctx;
|
1417
|
+
|
1418
|
+
if (ctx->grammar) {
|
1419
|
+
result_ctx->grammar_str = ctx->grammar_str;
|
1420
|
+
result_ctx->grammar_root = ctx->grammar_root;
|
1421
|
+
|
1422
|
+
result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
|
1423
|
+
}
|
1424
|
+
}
|
1425
|
+
|
701
1426
|
return result;
|
702
1427
|
}
|
703
1428
|
|
704
|
-
|
705
|
-
|
1429
|
+
static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
|
1430
|
+
const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
706
1431
|
|
707
|
-
|
708
|
-
|
1432
|
+
if (ctx->grammar) {
|
1433
|
+
llama_grammar_free_impl(ctx->grammar);
|
1434
|
+
}
|
709
1435
|
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
1436
|
+
delete ctx;
|
1437
|
+
}
|
1438
|
+
|
1439
|
+
static struct llama_sampler_i llama_sampler_grammar_i = {
|
1440
|
+
/* .name = */ llama_sampler_grammar_name,
|
1441
|
+
/* .accept = */ llama_sampler_grammar_accept_impl,
|
1442
|
+
/* .apply = */ llama_sampler_grammar_apply,
|
1443
|
+
/* .reset = */ llama_sampler_grammar_reset,
|
1444
|
+
/* .clone = */ llama_sampler_grammar_clone,
|
1445
|
+
/* .free = */ llama_sampler_grammar_free,
|
1446
|
+
};
|
1447
|
+
|
1448
|
+
struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
|
1449
|
+
auto * ctx = new llama_sampler_grammar;
|
1450
|
+
|
1451
|
+
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
1452
|
+
*ctx = {
|
1453
|
+
/* .vocab = */ &vocab,
|
1454
|
+
/* .grammar_str = */ grammar_str,
|
1455
|
+
/* .grammar_root = */ grammar_root,
|
1456
|
+
/* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
|
1457
|
+
};
|
1458
|
+
} else {
|
1459
|
+
*ctx = {
|
1460
|
+
/* .vocab = */ &vocab,
|
1461
|
+
/* .grammar_str = */ {},
|
1462
|
+
/* .grammar_root = */ {},
|
1463
|
+
/* .grammar = */ nullptr,
|
1464
|
+
};
|
1465
|
+
}
|
1466
|
+
|
1467
|
+
return new llama_sampler {
|
1468
|
+
/* .iface = */ &llama_sampler_grammar_i,
|
1469
|
+
/* .ctx = */ ctx,
|
1470
|
+
};
|
1471
|
+
}
|
1472
|
+
|
1473
|
+
// penalties
|
1474
|
+
|
1475
|
+
struct llama_sampler_penalties {
|
1476
|
+
const int32_t n_vocab;
|
1477
|
+
const llama_token special_eos_id;
|
1478
|
+
const llama_token linefeed_id;
|
1479
|
+
|
1480
|
+
const int32_t penalty_last_n;
|
1481
|
+
const float penalty_repeat;
|
1482
|
+
const float penalty_freq;
|
1483
|
+
const float penalty_present;
|
1484
|
+
|
1485
|
+
const bool penalize_nl;
|
1486
|
+
const bool ignore_eos;
|
1487
|
+
|
1488
|
+
ring_buffer<llama_token> prev;
|
1489
|
+
};
|
1490
|
+
|
1491
|
+
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
|
1492
|
+
return "penalties";
|
1493
|
+
}
|
1494
|
+
|
1495
|
+
static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
|
1496
|
+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
1497
|
+
if (ctx->penalty_last_n == 0) {
|
1498
|
+
return;
|
1499
|
+
}
|
1500
|
+
|
1501
|
+
ctx->prev.push_back(token);
|
1502
|
+
}
|
1503
|
+
|
1504
|
+
// Apply EOS suppression and repetition/frequency/presence penalties to the
// candidate logits in `cur_p`, based on the tokens recorded in `ctx->prev`.
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;

    // Suppress EOS (if requested) regardless of whether penalties are enabled.
    if (ctx->ignore_eos) {
        assert(ctx->special_eos_id >= 0);

        // optimistically check if the candidates are not yet sorted/shuffled/truncated
        // (i.e. candidate index still equals token id)
        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
        } else {
            // else, search for the special EOS token
            for (size_t i = 0; i < cur_p->size; ++i) {
                if (cur_p->data[i].id == ctx->special_eos_id) {
                    cur_p->data[i].logit = -INFINITY;
                    break;
                }
            }
        }
    }

    // Nothing to penalize when the window is empty or all penalties are neutral.
    if ((ctx->penalty_last_n == 0) ||
        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
        return;
    }

    // When the newline must not be penalized, remember its logit so it can be
    // restored after the penalty pass below.
    bool   nl_found = false;
    size_t nl_idx   = 0;
    float  nl_logit = -INFINITY;
    if (!ctx->penalize_nl) {
        assert(ctx->linefeed_id >= 0);

        // optimistically check if the candidates are not yet sorted/shuffled/truncated
        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
            nl_found = true;
            nl_idx   = ctx->linefeed_id;
            nl_logit = cur_p->data[ctx->linefeed_id].logit;
        } else {
            // else, search for the linefeed token
            for (size_t i = 0; i < cur_p->size; ++i) {
                if (cur_p->data[i].id == ctx->linefeed_id) {
                    nl_found = true;
                    nl_idx   = i;
                    nl_logit = cur_p->data[i].logit;
                    break;
                }
            }
        }
    }

    // Create a frequency map to count occurrences of each token in last_tokens
    // TODO: optimize this by maintaining the token count in the sampler context
    using llama_token_cnt = std::unordered_map<llama_token, int>;
    llama_token_cnt token_count;

    // rat(i) reads the ring buffer from the most recent element backwards.
    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
        token_count[ctx->prev.rat(i)]++;
    }

    // Apply frequency and presence penalties to the cur_p
    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto token_iter = token_count.find(cur_p->data[i].id);
        if (token_iter == token_count.end()) {
            continue;
        }

        const int count = token_iter->second;

        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
        if (cur_p->data[i].logit <= 0) {
            cur_p->data[i].logit *= ctx->penalty_repeat;
        } else {
            cur_p->data[i].logit /= ctx->penalty_repeat;
        }

        cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
    }

    // logits changed, so any previous sort order is invalid
    cur_p->sorted = false;

    if (!ctx->penalize_nl && nl_found) {
        // restore the logit of the newline token if it was penalized
        cur_p->data[nl_idx].logit = nl_logit;
    }
}
|
1589
|
+
|
1590
|
+
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
|
1591
|
+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
1592
|
+
ctx->prev.clear();
|
1593
|
+
}
|
1594
|
+
|
1595
|
+
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
|
1596
|
+
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
|
1597
|
+
auto * result = llama_sampler_init_penalties(
|
1598
|
+
ctx->n_vocab,
|
1599
|
+
ctx->special_eos_id,
|
1600
|
+
ctx->linefeed_id,
|
1601
|
+
ctx->penalty_last_n,
|
1602
|
+
ctx->penalty_repeat,
|
1603
|
+
ctx->penalty_freq,
|
1604
|
+
ctx->penalty_present,
|
1605
|
+
ctx->penalize_nl,
|
1606
|
+
ctx->ignore_eos);
|
1607
|
+
|
1608
|
+
// copy the state
|
1609
|
+
{
|
1610
|
+
auto * result_ctx = (llama_sampler_penalties *) result->ctx;
|
1611
|
+
|
1612
|
+
result_ctx->prev = ctx->prev;
|
1613
|
+
}
|
723
1614
|
|
724
1615
|
return result;
|
725
1616
|
}
|
726
1617
|
|
727
|
-
|
728
|
-
|
1618
|
+
static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
|
1619
|
+
delete (llama_sampler_penalties *) smpl->ctx;
|
1620
|
+
}
|
1621
|
+
|
1622
|
+
// Interface (vtable) binding the penalty sampler callbacks.
static struct llama_sampler_i llama_sampler_penalties_i = {
    /* .name   = */ llama_sampler_penalties_name,
    /* .accept = */ llama_sampler_penalties_accept,
    /* .apply  = */ llama_sampler_penalties_apply,
    /* .reset  = */ llama_sampler_penalties_reset,
    /* .clone  = */ llama_sampler_penalties_clone,
    /* .free   = */ llama_sampler_penalties_free,
};
|
1630
|
+
|
1631
|
+
// Construct a penalty sampler. Invalid special-token configuration is
// normalized here so the apply/accept callbacks can rely on valid ids.
struct llama_sampler * llama_sampler_init_penalties(
        int32_t n_vocab,
        llama_token special_eos_id,
        llama_token linefeed_id,
        int32_t penalty_last_n,
        float penalty_repeat,
        float penalty_freq,
        float penalty_present,
        bool penalize_nl,
        bool ignore_eos) {
    // without a known linefeed token there is no logit to save/restore,
    // so the "don't penalize newline" feature must be disabled
    if (linefeed_id == LLAMA_TOKEN_NULL) {
        penalize_nl = true;
    }

    // without a known EOS token there is nothing to suppress
    if (special_eos_id == LLAMA_TOKEN_NULL) {
        ignore_eos = false;
    }

    return new llama_sampler {
        /* .iface = */ &llama_sampler_penalties_i,
        /* .ctx   = */ new llama_sampler_penalties {
            /* .n_vocab         = */ n_vocab,
            /* .special_eos_id  = */ special_eos_id,
            /* .linefeed_id     = */ linefeed_id,
            /* .penalty_last_n  = */ penalty_last_n,
            /* .penalty_repeat  = */ penalty_repeat,
            /* .penalty_freq    = */ penalty_freq,
            /* .penalty_present = */ penalty_present,
            /* .penalize_nl     = */ penalize_nl,
            /* .ignore_eos      = */ ignore_eos,
            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
        },
    };
}
|
1665
|
+
|
1666
|
+
// logit-bias
|
1667
|
+
|
1668
|
+
// State for the logit-bias sampler: a fixed list of (token, bias) pairs
// plus a scratch vector reused across apply() calls.
struct llama_sampler_logit_bias {
    const int32_t n_vocab;  // vocabulary size of the target model

    const std::vector<llama_logit_bias> logit_bias;  // biases to add to the listed token logits

    std::vector<llama_logit_bias> to_search;  // scratch: entries not found via direct indexing
};
|
1675
|
+
|
1676
|
+
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
    // Static type identifier; the sampler instance itself is not consulted.
    (void) smpl;
    return "logit-bias";
}
|
1679
|
+
|
1680
|
+
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1681
|
+
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
|
1682
|
+
|
1683
|
+
ctx->to_search.clear();
|
1684
|
+
|
1685
|
+
// update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
|
1686
|
+
for (const auto & lb : ctx->logit_bias) {
|
1687
|
+
if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
|
1688
|
+
cur_p->data[lb.token].logit += lb.bias;
|
1689
|
+
} else {
|
1690
|
+
ctx->to_search.push_back(lb);
|
1691
|
+
}
|
1692
|
+
}
|
1693
|
+
|
1694
|
+
// search for the remaining candidates that were not found in the previous step
|
1695
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1696
|
+
for (const auto & lb : ctx->to_search) {
|
1697
|
+
if (cur_p->data[i].id == lb.token) {
|
1698
|
+
cur_p->data[i].logit += lb.bias;
|
1699
|
+
break;
|
1700
|
+
}
|
1701
|
+
}
|
1702
|
+
}
|
1703
|
+
}
|
1704
|
+
static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
|
1705
|
+
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
|
1706
|
+
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
|
1707
|
+
}
|
1708
|
+
|
1709
|
+
static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
|
1710
|
+
delete (llama_sampler_logit_bias *) smpl->ctx;
|
1711
|
+
}
|
1712
|
+
|
1713
|
+
// Interface (vtable) binding the logit-bias sampler callbacks.
// accept/reset are nullptr: this sampler keeps no per-sequence state.
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    /* .name   = */ llama_sampler_logit_bias_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_logit_bias_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_logit_bias_clone,
    /* .free   = */ llama_sampler_logit_bias_free,
};
|
1721
|
+
|
1722
|
+
struct llama_sampler * llama_sampler_init_logit_bias(
|
1723
|
+
int32_t n_vocab,
|
1724
|
+
int32_t n_logit_bias,
|
1725
|
+
const llama_logit_bias * logit_bias) {
|
1726
|
+
return new llama_sampler {
|
1727
|
+
/* .iface = */ &llama_sampler_logit_bias_i,
|
1728
|
+
/* .ctx = */ new llama_sampler_logit_bias {
|
1729
|
+
/* .n_vocab = */ n_vocab,
|
1730
|
+
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
1731
|
+
/* .to_search = */ {},
|
1732
|
+
},
|
1733
|
+
};
|
729
1734
|
}
|