cui-llama.rn 1.1.2 → 1.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +1 -2
- package/android/src/main/jni.cpp +26 -21
- package/cpp/common.cpp +181 -1584
- package/cpp/common.h +131 -52
- package/cpp/ggml-aarch64.c +612 -0
- package/cpp/ggml-alloc.h +2 -2
- package/cpp/ggml-backend.c +33 -6
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-common.h +20 -0
- package/cpp/ggml-impl.h +36 -7
- package/cpp/ggml-metal.m +68 -8
- package/cpp/ggml-quants.c +932 -50
- package/cpp/ggml-quants.h +15 -0
- package/cpp/ggml.c +1712 -325
- package/cpp/ggml.h +169 -100
- package/cpp/llama-grammar.cpp +721 -122
- package/cpp/llama-grammar.h +120 -15
- package/cpp/llama-impl.h +132 -1
- package/cpp/llama-sampling.cpp +1483 -354
- package/cpp/llama-sampling.h +20 -48
- package/cpp/llama-vocab.cpp +140 -7
- package/cpp/llama-vocab.h +3 -2
- package/cpp/llama.cpp +824 -327
- package/cpp/llama.h +235 -256
- package/cpp/rn-llama.hpp +18 -14
- package/cpp/sampling.cpp +353 -354
- package/cpp/sampling.h +62 -143
- package/cpp/sgemm.cpp +153 -0
- package/package.json +1 -1
- package/cpp/grammar-parser.cpp +0 -539
- package/cpp/grammar-parser.h +0 -29
package/cpp/llama-sampling.cpp
CHANGED
@@ -1,12 +1,52 @@
|
|
1
1
|
#include "llama-sampling.h"
|
2
2
|
|
3
|
+
#include "llama-vocab.h"
|
4
|
+
#include "llama-grammar.h"
|
5
|
+
|
6
|
+
#include <cassert>
|
3
7
|
#include <algorithm>
|
4
8
|
#include <cstring>
|
5
9
|
#include <ctime>
|
6
10
|
#include <cfloat>
|
11
|
+
#include <chrono>
|
12
|
+
#include <cmath>
|
7
13
|
#include <numeric>
|
14
|
+
#include <random>
|
8
15
|
#include <unordered_map>
|
9
16
|
|
17
|
+
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
|
18
|
+
// iterator for the probabilities
|
19
|
+
#ifdef __GNUC__
|
20
|
+
#pragma GCC diagnostic push
|
21
|
+
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
|
22
|
+
#endif
|
23
|
+
|
24
|
+
struct probs_iterator {
|
25
|
+
typedef std::input_iterator_tag iterator_category;
|
26
|
+
typedef float value_type;
|
27
|
+
typedef float * pointer;
|
28
|
+
typedef float & reference;
|
29
|
+
typedef ptrdiff_t difference_type;
|
30
|
+
|
31
|
+
const llama_token_data * data;
|
32
|
+
|
33
|
+
bool operator==(const probs_iterator & other) const { return data == other.data; }
|
34
|
+
bool operator!=(const probs_iterator & other) const { return data != other.data; }
|
35
|
+
const float & operator*() const { return data->p; }
|
36
|
+
probs_iterator & operator++() { ++data; return *this; }
|
37
|
+
probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
|
38
|
+
};
|
39
|
+
|
40
|
+
#ifdef __GNUC__
|
41
|
+
#pragma GCC diagnostic pop
|
42
|
+
#endif
|
43
|
+
|
44
|
+
std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
|
45
|
+
|
46
|
+
return dist(rng);
|
47
|
+
}
|
48
|
+
|
49
|
+
/*
|
10
50
|
static void llama_log_softmax(float * array, size_t size) {
|
11
51
|
float max_l = *std::max_element(array, array + size);
|
12
52
|
float sum = 0.f;
|
@@ -20,66 +60,52 @@ static void llama_log_softmax(float * array, size_t size) {
|
|
20
60
|
array[i] = logf(array[i] / sum);
|
21
61
|
}
|
22
62
|
}
|
63
|
+
*/
|
23
64
|
|
24
|
-
void
|
25
|
-
|
26
|
-
seed = time(NULL);
|
27
|
-
}
|
28
|
-
|
29
|
-
smpl->rng.seed(seed);
|
30
|
-
}
|
31
|
-
|
32
|
-
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
33
|
-
LM_GGML_ASSERT(candidates->size > 0);
|
34
|
-
|
35
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
65
|
+
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
66
|
+
LM_GGML_ASSERT(cur_p->size > 0);
|
36
67
|
|
37
68
|
// Sort the logits in descending order
|
38
|
-
if (!
|
39
|
-
std::sort(
|
69
|
+
if (!cur_p->sorted) {
|
70
|
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
40
71
|
return a.logit > b.logit;
|
41
72
|
});
|
42
|
-
|
73
|
+
cur_p->sorted = true;
|
43
74
|
}
|
44
75
|
|
45
|
-
float max_l =
|
76
|
+
float max_l = cur_p->data[0].logit;
|
46
77
|
float cum_sum = 0.0f;
|
47
|
-
|
48
|
-
|
49
|
-
|
78
|
+
|
79
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
80
|
+
float p = expf(cur_p->data[i].logit - max_l);
|
81
|
+
cur_p->data[i].p = p;
|
50
82
|
cum_sum += p;
|
51
83
|
}
|
52
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
53
|
-
candidates->data[i].p /= cum_sum;
|
54
|
-
}
|
55
84
|
|
56
|
-
|
57
|
-
|
85
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
86
|
+
cur_p->data[i].p /= cum_sum;
|
58
87
|
}
|
59
88
|
}
|
60
89
|
|
61
|
-
void
|
90
|
+
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
|
62
91
|
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
63
|
-
// if (k >= (int32_t)
|
92
|
+
// if (k >= (int32_t)cur_p->size) {
|
64
93
|
// return;
|
65
94
|
// }
|
66
95
|
|
67
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
68
|
-
|
69
96
|
if (k <= 0) {
|
70
|
-
k =
|
97
|
+
k = cur_p->size;
|
71
98
|
}
|
72
99
|
|
73
|
-
k = std::
|
74
|
-
k = std::min(k, (int) candidates->size);
|
100
|
+
k = std::min(k, (int) cur_p->size);
|
75
101
|
|
76
102
|
// Sort scores in descending order
|
77
|
-
if (!
|
103
|
+
if (!cur_p->sorted) {
|
78
104
|
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
79
105
|
return a.logit > b.logit;
|
80
106
|
};
|
81
107
|
if (k <= 128) {
|
82
|
-
std::partial_sort(
|
108
|
+
std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
|
83
109
|
} else {
|
84
110
|
constexpr int nbuckets = 128;
|
85
111
|
constexpr float bucket_low = -10.0f;
|
@@ -87,11 +113,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|
87
113
|
constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
|
88
114
|
constexpr float bucket_inter = -bucket_low * bucket_scale;
|
89
115
|
|
90
|
-
std::vector<int> bucket_idx(
|
116
|
+
std::vector<int> bucket_idx(cur_p->size);
|
91
117
|
std::vector<int> histo(nbuckets, 0);
|
92
118
|
|
93
|
-
for (int i = 0; i < (int)
|
94
|
-
const float val =
|
119
|
+
for (int i = 0; i < (int)cur_p->size; ++i) {
|
120
|
+
const float val = cur_p->data[i].logit;
|
95
121
|
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
96
122
|
ib = std::max(0, std::min(nbuckets-1, ib));
|
97
123
|
bucket_idx[i] = ib;
|
@@ -101,20 +127,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|
101
127
|
int ib = nbuckets - 1;
|
102
128
|
for ( ; ib >= 0; --ib) {
|
103
129
|
nhave += histo[ib];
|
104
|
-
if (nhave >= k)
|
130
|
+
if (nhave >= k) {
|
131
|
+
break;
|
132
|
+
}
|
105
133
|
}
|
106
134
|
std::vector<llama_token_data> tmp_tokens(nhave);
|
107
|
-
auto ptr = tmp_tokens.data();
|
135
|
+
auto * ptr = tmp_tokens.data();
|
108
136
|
std::vector<llama_token_data*> bucket_ptrs;
|
109
137
|
bucket_ptrs.reserve(nbuckets - ib);
|
110
138
|
for (int j = nbuckets - 1; j >= ib; --j) {
|
111
139
|
bucket_ptrs.push_back(ptr);
|
112
140
|
ptr += histo[j];
|
113
141
|
}
|
114
|
-
for (int i = 0; i < (int)
|
142
|
+
for (int i = 0; i < (int)cur_p->size; ++i) {
|
115
143
|
int j = bucket_idx[i];
|
116
144
|
if (j >= ib) {
|
117
|
-
*bucket_ptrs[nbuckets-1-j]++ =
|
145
|
+
*bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
|
118
146
|
}
|
119
147
|
}
|
120
148
|
|
@@ -127,69 +155,606 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|
127
155
|
}
|
128
156
|
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
|
129
157
|
|
130
|
-
std::memcpy(
|
158
|
+
std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
|
159
|
+
|
160
|
+
}
|
161
|
+
cur_p->sorted = true;
|
162
|
+
}
|
163
|
+
cur_p->size = k;
|
164
|
+
}
|
131
165
|
|
166
|
+
static uint32_t get_rng_seed(uint32_t seed) {
|
167
|
+
if (seed == LLAMA_DEFAULT_SEED) {
|
168
|
+
// use system clock if std::random_device is not a true RNG
|
169
|
+
static bool is_rd_prng = std::random_device().entropy() == 0;
|
170
|
+
if (is_rd_prng) {
|
171
|
+
return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
|
132
172
|
}
|
133
|
-
|
173
|
+
std::random_device rd;
|
174
|
+
return rd();
|
175
|
+
}
|
176
|
+
return seed;
|
177
|
+
}
|
178
|
+
|
179
|
+
// llama_sampler API
|
180
|
+
|
181
|
+
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
182
|
+
if (!smpl->iface) {
|
183
|
+
return "(null)";
|
134
184
|
}
|
135
|
-
candidates->size = k;
|
136
185
|
|
137
|
-
|
138
|
-
|
186
|
+
return smpl->iface->name(smpl);
|
187
|
+
}
|
188
|
+
|
189
|
+
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
190
|
+
if (smpl->iface->accept) {
|
191
|
+
smpl->iface->accept(smpl, token);
|
192
|
+
}
|
193
|
+
}
|
194
|
+
|
195
|
+
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
196
|
+
LM_GGML_ASSERT(smpl->iface->apply);
|
197
|
+
smpl->iface->apply(smpl, cur_p);
|
198
|
+
}
|
199
|
+
|
200
|
+
void llama_sampler_reset(struct llama_sampler * smpl) {
|
201
|
+
if (smpl->iface->reset) {
|
202
|
+
smpl->iface->reset(smpl);
|
203
|
+
}
|
204
|
+
}
|
205
|
+
|
206
|
+
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
207
|
+
if (smpl->iface->clone) {
|
208
|
+
return smpl->iface->clone(smpl);
|
209
|
+
}
|
210
|
+
|
211
|
+
if (smpl->ctx == nullptr) {
|
212
|
+
return new llama_sampler {
|
213
|
+
/* .iface = */ smpl->iface,
|
214
|
+
/* .ctx = */ nullptr,
|
215
|
+
};
|
139
216
|
}
|
217
|
+
|
218
|
+
LM_GGML_ABORT("the sampler does not support cloning");
|
140
219
|
}
|
141
220
|
|
142
|
-
void
|
143
|
-
if (
|
221
|
+
void llama_sampler_free(struct llama_sampler * smpl) {
|
222
|
+
if (smpl == nullptr) {
|
144
223
|
return;
|
145
224
|
}
|
146
225
|
|
147
|
-
|
226
|
+
if (smpl->iface->free) {
|
227
|
+
smpl->iface->free(smpl);
|
228
|
+
}
|
148
229
|
|
149
|
-
|
230
|
+
delete smpl;
|
231
|
+
}
|
232
|
+
|
233
|
+
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
234
|
+
const auto * logits = llama_get_logits_ith(ctx, idx);
|
235
|
+
|
236
|
+
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
237
|
+
|
238
|
+
// TODO: do not allocate each time
|
239
|
+
std::vector<llama_token_data> cur(n_vocab);
|
240
|
+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
241
|
+
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
242
|
+
}
|
243
|
+
|
244
|
+
llama_token_data_array cur_p = {
|
245
|
+
/* .data = */ cur.data(),
|
246
|
+
/* .size = */ cur.size(),
|
247
|
+
/* .selected = */ -1,
|
248
|
+
/* .sorted = */ false,
|
249
|
+
};
|
250
|
+
|
251
|
+
llama_sampler_apply(smpl, &cur_p);
|
252
|
+
|
253
|
+
LM_GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
|
254
|
+
|
255
|
+
auto token = cur_p.data[cur_p.selected].id;
|
256
|
+
|
257
|
+
llama_sampler_accept(smpl, token);
|
258
|
+
|
259
|
+
return token;
|
260
|
+
}
|
261
|
+
|
262
|
+
// sampler chain
|
263
|
+
|
264
|
+
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
|
265
|
+
return "chain";
|
266
|
+
}
|
267
|
+
|
268
|
+
static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
|
269
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
270
|
+
|
271
|
+
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
272
|
+
|
273
|
+
for (auto * smpl : chain->samplers) {
|
274
|
+
llama_sampler_accept(smpl, token);
|
275
|
+
}
|
276
|
+
|
277
|
+
chain->n_sample++;
|
278
|
+
}
|
279
|
+
|
280
|
+
static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
281
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
282
|
+
|
283
|
+
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
284
|
+
|
285
|
+
for (auto * smpl : chain->samplers) {
|
286
|
+
llama_sampler_apply(smpl, cur_p);
|
287
|
+
}
|
288
|
+
}
|
289
|
+
|
290
|
+
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
|
291
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
292
|
+
|
293
|
+
for (auto * smpl : chain->samplers) {
|
294
|
+
llama_sampler_reset(smpl);
|
295
|
+
}
|
296
|
+
|
297
|
+
chain->t_sample_us = 0;
|
298
|
+
chain->n_sample = 0;
|
299
|
+
}
|
300
|
+
|
301
|
+
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
|
302
|
+
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
|
303
|
+
|
304
|
+
auto * result = llama_sampler_chain_init(chain_src->params);
|
305
|
+
|
306
|
+
for (auto * smpl : chain_src->samplers) {
|
307
|
+
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
308
|
+
}
|
309
|
+
|
310
|
+
return result;
|
311
|
+
}
|
312
|
+
|
313
|
+
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
314
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
315
|
+
|
316
|
+
for (auto * smpl : chain->samplers) {
|
317
|
+
llama_sampler_free(smpl);
|
318
|
+
}
|
319
|
+
|
320
|
+
delete chain;
|
321
|
+
}
|
322
|
+
|
323
|
+
static struct llama_sampler_i llama_sampler_chain_i = {
|
324
|
+
/* .name = */ llama_sampler_chain_name,
|
325
|
+
/* .accept = */ llama_sampler_chain_accept,
|
326
|
+
/* .apply = */ llama_sampler_chain_apply,
|
327
|
+
/* .reset = */ llama_sampler_chain_reset,
|
328
|
+
/* .clone = */ llama_sampler_chain_clone,
|
329
|
+
/* .free = */ llama_sampler_chain_free,
|
330
|
+
};
|
331
|
+
|
332
|
+
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
333
|
+
return new llama_sampler {
|
334
|
+
/* .iface = */ &llama_sampler_chain_i,
|
335
|
+
/* .ctx = */ new llama_sampler_chain {
|
336
|
+
/* .params = */ params,
|
337
|
+
/* .samplers = */ {},
|
338
|
+
/* .t_sample_us = */ 0,
|
339
|
+
/* .n_sample = */ 0,
|
340
|
+
},
|
341
|
+
};
|
342
|
+
}
|
343
|
+
|
344
|
+
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
345
|
+
auto * p = (llama_sampler_chain *) chain->ctx;
|
346
|
+
p->samplers.push_back(smpl);
|
347
|
+
}
|
348
|
+
|
349
|
+
llama_sampler_timings llama_sampler_chain_timings(struct llama_sampler * chain) {
|
350
|
+
auto * p = (llama_sampler_chain *) chain->ctx;
|
351
|
+
struct llama_sampler_timings result = {
|
352
|
+
p -> t_sample_us,
|
353
|
+
p -> n_sample
|
354
|
+
};
|
355
|
+
return result;
|
356
|
+
}
|
357
|
+
|
358
|
+
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
359
|
+
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
360
|
+
|
361
|
+
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
362
|
+
return nullptr;
|
363
|
+
}
|
364
|
+
|
365
|
+
return p->samplers[i];
|
366
|
+
}
|
367
|
+
|
368
|
+
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
|
369
|
+
auto * p = (llama_sampler_chain *) chain->ctx;
|
370
|
+
|
371
|
+
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
372
|
+
return nullptr;
|
373
|
+
}
|
374
|
+
|
375
|
+
auto * result = p->samplers[i];
|
376
|
+
p->samplers.erase(p->samplers.begin() + i);
|
377
|
+
|
378
|
+
return result;
|
379
|
+
}
|
380
|
+
|
381
|
+
int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
382
|
+
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
383
|
+
|
384
|
+
return p->samplers.size();
|
385
|
+
}
|
386
|
+
|
387
|
+
//
|
388
|
+
// samplers
|
389
|
+
//
|
390
|
+
|
391
|
+
// greedy
|
392
|
+
|
393
|
+
static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
|
394
|
+
return "greedy";
|
395
|
+
}
|
396
|
+
|
397
|
+
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
398
|
+
cur_p->selected = 0;
|
399
|
+
for (size_t i = 1; i < cur_p->size; ++i) {
|
400
|
+
if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
|
401
|
+
cur_p->selected = i;
|
402
|
+
}
|
403
|
+
}
|
404
|
+
}
|
405
|
+
|
406
|
+
static struct llama_sampler_i llama_sampler_greedy_i = {
|
407
|
+
/* .name = */ llama_sampler_greedy_name,
|
408
|
+
/* .accept = */ nullptr,
|
409
|
+
/* .apply = */ llama_sampler_greedy_apply,
|
410
|
+
/* .reset = */ nullptr,
|
411
|
+
/* .clone = */ nullptr,
|
412
|
+
/* .free = */ nullptr,
|
413
|
+
};
|
414
|
+
|
415
|
+
struct llama_sampler * llama_sampler_init_greedy() {
|
416
|
+
return new llama_sampler {
|
417
|
+
/* .iface = */ &llama_sampler_greedy_i,
|
418
|
+
/* .ctx = */ nullptr,
|
419
|
+
};
|
420
|
+
}
|
421
|
+
|
422
|
+
// dist
|
423
|
+
|
424
|
+
struct llama_sampler_dist {
|
425
|
+
const uint32_t seed;
|
426
|
+
uint32_t seed_cur;
|
427
|
+
|
428
|
+
std::mt19937 rng;
|
429
|
+
};
|
430
|
+
|
431
|
+
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
|
432
|
+
return "dist";
|
433
|
+
}
|
434
|
+
|
435
|
+
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
436
|
+
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
437
|
+
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
438
|
+
}
|
439
|
+
|
440
|
+
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
441
|
+
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
442
|
+
auto * result = llama_sampler_init_dist(ctx->seed);
|
443
|
+
|
444
|
+
// copy the state
|
445
|
+
{
|
446
|
+
auto * result_ctx = (llama_sampler_dist *) result->ctx;
|
447
|
+
|
448
|
+
result_ctx->rng = ctx->rng;
|
449
|
+
}
|
450
|
+
|
451
|
+
return result;
|
452
|
+
}
|
453
|
+
|
454
|
+
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
455
|
+
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
456
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
457
|
+
ctx->rng.seed(ctx->seed_cur);
|
458
|
+
}
|
459
|
+
|
460
|
+
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
461
|
+
delete (llama_sampler_dist *) smpl->ctx;
|
462
|
+
}
|
463
|
+
|
464
|
+
static struct llama_sampler_i llama_sampler_dist_i = {
|
465
|
+
/* .name = */ llama_sampler_dist_name,
|
466
|
+
/* .accept = */ nullptr,
|
467
|
+
/* .apply = */ llama_sampler_dist_apply,
|
468
|
+
/* .reset = */ llama_sampler_dist_reset,
|
469
|
+
/* .clone = */ llama_sampler_dist_clone,
|
470
|
+
/* .free = */ llama_sampler_dist_free,
|
471
|
+
};
|
472
|
+
|
473
|
+
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
474
|
+
auto seed_cur = get_rng_seed(seed);
|
475
|
+
return new llama_sampler {
|
476
|
+
/* .iface = */ &llama_sampler_dist_i,
|
477
|
+
/* .ctx = */ new llama_sampler_dist {
|
478
|
+
/* .seed = */ seed,
|
479
|
+
/* .seed_cur = */ seed_cur,
|
480
|
+
/* .rng = */ std::mt19937(seed_cur),
|
481
|
+
},
|
482
|
+
};
|
483
|
+
}
|
484
|
+
|
485
|
+
// softmax
|
486
|
+
|
487
|
+
static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
|
488
|
+
return "softmax";
|
489
|
+
}
|
490
|
+
|
491
|
+
static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
492
|
+
llama_sampler_softmax_impl(cur_p);
|
493
|
+
}
|
494
|
+
|
495
|
+
static struct llama_sampler_i llama_sampler_softmax_i = {
|
496
|
+
/* .name = */ llama_sampler_softmax_name,
|
497
|
+
/* .accept = */ nullptr,
|
498
|
+
/* .apply = */ llama_sampler_softmax_apply,
|
499
|
+
/* .reset = */ nullptr,
|
500
|
+
/* .clone = */ nullptr,
|
501
|
+
/* .free = */ nullptr,
|
502
|
+
};
|
503
|
+
|
504
|
+
struct llama_sampler * llama_sampler_init_softmax() {
|
505
|
+
return new llama_sampler {
|
506
|
+
/* .iface = */ &llama_sampler_softmax_i,
|
507
|
+
/* .ctx = */ nullptr,
|
508
|
+
};
|
509
|
+
}
|
510
|
+
|
511
|
+
// top-k
|
512
|
+
|
513
|
+
struct llama_sampler_top_k {
|
514
|
+
const int32_t k;
|
515
|
+
};
|
516
|
+
|
517
|
+
static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
|
518
|
+
return "top-k";
|
519
|
+
}
|
520
|
+
|
521
|
+
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
522
|
+
const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
|
523
|
+
llama_sampler_top_k_impl(cur_p, ctx->k);
|
524
|
+
}
|
525
|
+
|
526
|
+
static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
|
527
|
+
const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
|
528
|
+
return llama_sampler_init_top_k(ctx->k);
|
529
|
+
}
|
530
|
+
|
531
|
+
static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
532
|
+
delete (llama_sampler_top_k *) smpl->ctx;
|
533
|
+
}
|
534
|
+
|
535
|
+
static struct llama_sampler_i llama_sampler_top_k_i = {
|
536
|
+
/* .name = */ llama_sampler_top_k_name,
|
537
|
+
/* .accept = */ nullptr,
|
538
|
+
/* .apply = */ llama_sampler_top_k_apply,
|
539
|
+
/* .reset = */ nullptr,
|
540
|
+
/* .clone = */ llama_sampler_top_k_clone,
|
541
|
+
/* .free = */ llama_sampler_top_k_free,
|
542
|
+
};
|
543
|
+
|
544
|
+
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
545
|
+
return new llama_sampler {
|
546
|
+
/* .iface = */ &llama_sampler_top_k_i,
|
547
|
+
/* .ctx = */ new llama_sampler_top_k {
|
548
|
+
/* .k = */ k,
|
549
|
+
},
|
550
|
+
};
|
551
|
+
}
|
552
|
+
|
553
|
+
// top-p
|
554
|
+
|
555
|
+
struct llama_sampler_top_p {
|
556
|
+
const float p;
|
557
|
+
const size_t min_keep;
|
558
|
+
};
|
559
|
+
|
560
|
+
static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
|
561
|
+
return "top-p";
|
562
|
+
}
|
563
|
+
|
564
|
+
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
565
|
+
const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
|
566
|
+
|
567
|
+
if (ctx->p >= 1.0f) {
|
568
|
+
return;
|
569
|
+
}
|
570
|
+
|
571
|
+
llama_sampler_softmax_impl(cur_p);
|
150
572
|
|
151
573
|
// Compute the cumulative probabilities
|
152
574
|
float cum_sum = 0.0f;
|
153
|
-
size_t last_idx =
|
575
|
+
size_t last_idx = cur_p->size;
|
154
576
|
|
155
|
-
for (size_t i = 0; i <
|
156
|
-
cum_sum +=
|
577
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
578
|
+
cum_sum += cur_p->data[i].p;
|
157
579
|
|
158
580
|
// Check if the running sum is at least p or if we have kept at least min_keep tokens
|
159
581
|
// we set the last index to i+1 to indicate that the current iterate should be included in the set
|
160
|
-
if (cum_sum >= p && i + 1 >= min_keep) {
|
582
|
+
if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
|
161
583
|
last_idx = i + 1;
|
162
584
|
break;
|
163
585
|
}
|
164
586
|
}
|
165
587
|
|
166
588
|
// Resize the output vector to keep only the top-p tokens
|
167
|
-
|
589
|
+
cur_p->size = last_idx;
|
590
|
+
}
|
591
|
+
|
592
|
+
static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
|
593
|
+
const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
|
594
|
+
return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
|
595
|
+
}
|
596
|
+
|
597
|
+
static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
598
|
+
delete (llama_sampler_top_p *) smpl->ctx;
|
599
|
+
}
|
600
|
+
|
601
|
+
static struct llama_sampler_i llama_sampler_top_p_i = {
|
602
|
+
/* .name = */ llama_sampler_top_p_name,
|
603
|
+
/* .accept = */ nullptr,
|
604
|
+
/* .apply = */ llama_sampler_top_p_apply,
|
605
|
+
/* .reset = */ nullptr,
|
606
|
+
/* .clone = */ llama_sampler_top_p_clone,
|
607
|
+
/* .free = */ llama_sampler_top_p_free,
|
608
|
+
};
|
609
|
+
|
610
|
+
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
611
|
+
return new llama_sampler {
|
612
|
+
/* .iface = */ &llama_sampler_top_p_i,
|
613
|
+
/* .ctx = */ new llama_sampler_top_p {
|
614
|
+
/* .p = */ p,
|
615
|
+
/* .min_keep = */ min_keep,
|
616
|
+
},
|
617
|
+
};
|
618
|
+
}
|
619
|
+
|
620
|
+
// min-p
|
621
|
+
|
622
|
+
struct llama_sampler_min_p {
|
623
|
+
const float p;
|
624
|
+
const size_t min_keep;
|
625
|
+
};
|
626
|
+
|
627
|
+
static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
|
628
|
+
return "min-p";
|
629
|
+
}
|
630
|
+
|
631
|
+
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
632
|
+
const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
|
633
|
+
|
634
|
+
if (ctx->p <= 0.0f || !cur_p->size) {
|
635
|
+
return;
|
636
|
+
}
|
637
|
+
|
638
|
+
bool min_p_applied = false;
|
639
|
+
|
640
|
+
// if the cur_p aren't sorted, try the unsorted implementation first
|
641
|
+
if (!cur_p->sorted) {
|
642
|
+
std::vector<llama_token_data> filtered_tokens;
|
643
|
+
|
644
|
+
float max_logit = -FLT_MAX;
|
645
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
646
|
+
max_logit = std::max(max_logit, cur_p->data[i].logit);
|
647
|
+
}
|
648
|
+
const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
|
649
|
+
|
650
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
651
|
+
if (cur_p->data[i].logit >= min_logit) {
|
652
|
+
filtered_tokens.push_back(cur_p->data[i]);
|
653
|
+
}
|
654
|
+
}
|
655
|
+
|
656
|
+
// if we have enough values the operation was a success
|
657
|
+
if (filtered_tokens.size() >= ctx->min_keep) {
|
658
|
+
memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
659
|
+
cur_p->size = filtered_tokens.size();
|
660
|
+
min_p_applied = true;
|
661
|
+
}
|
662
|
+
}
|
663
|
+
|
664
|
+
// if the cur_p are sorted or the unsorted implementation failed, use this implementation
|
665
|
+
if (!min_p_applied) {
|
666
|
+
// Sort the logits in descending order
|
667
|
+
if (!cur_p->sorted) {
|
668
|
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
669
|
+
return a.logit > b.logit;
|
670
|
+
});
|
671
|
+
cur_p->sorted = true;
|
672
|
+
}
|
168
673
|
|
169
|
-
|
170
|
-
|
674
|
+
const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
|
675
|
+
size_t i = 1; // first token always matches
|
676
|
+
|
677
|
+
for (; i < cur_p->size; ++i) {
|
678
|
+
if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
|
679
|
+
break; // prob too small
|
680
|
+
}
|
681
|
+
}
|
682
|
+
|
683
|
+
// Resize the output vector to keep only the matching tokens
|
684
|
+
cur_p->size = i;
|
171
685
|
}
|
172
686
|
}
|
173
687
|
|
174
|
-
|
175
|
-
|
688
|
+
static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
|
689
|
+
const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
|
690
|
+
return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
|
691
|
+
}
|
692
|
+
|
693
|
+
static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
694
|
+
delete (llama_sampler_min_p *) smpl->ctx;
|
695
|
+
}
|
696
|
+
|
697
|
+
static struct llama_sampler_i llama_sampler_min_p_i = {
|
698
|
+
/* .name = */ llama_sampler_min_p_name,
|
699
|
+
/* .accept = */ nullptr,
|
700
|
+
/* .apply = */ llama_sampler_min_p_apply,
|
701
|
+
/* .reset = */ nullptr,
|
702
|
+
/* .clone = */ llama_sampler_min_p_clone,
|
703
|
+
/* .free = */ llama_sampler_min_p_free,
|
704
|
+
};
|
705
|
+
|
706
|
+
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
707
|
+
return new llama_sampler {
|
708
|
+
/* .iface = */ &llama_sampler_min_p_i,
|
709
|
+
/* .ctx = */ new llama_sampler_min_p {
|
710
|
+
/* .p = */ p,
|
711
|
+
/* .min_keep = */ min_keep,
|
712
|
+
},
|
713
|
+
};
|
714
|
+
}
|
715
|
+
|
716
|
+
// xtc
|
717
|
+
|
718
|
+
struct llama_sampler_xtc {
|
719
|
+
const uint32_t seed;
|
720
|
+
std::mt19937 rng;
|
721
|
+
const float xtc_p;
|
722
|
+
const float xtc_t;
|
723
|
+
const size_t min_keep;
|
724
|
+
};
|
725
|
+
|
726
|
+
static const char * llama_sampler_xtc_name(const struct llama_sampler * /* smpl */) {
|
727
|
+
return "xtc";
|
728
|
+
}
|
729
|
+
|
730
|
+
static void llama_sampler_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
731
|
+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
732
|
+
|
733
|
+
size_t min_keep = ctx -> min_keep;
|
734
|
+
std::mt19937 rng = ctx -> rng;
|
735
|
+
|
736
|
+
float xtc_threshold = ctx -> xtc_t;
|
737
|
+
float xtc_probability = ctx -> xtc_p;
|
738
|
+
|
739
|
+
|
740
|
+
if(xtc_threshold <= 0.0f || !cur_p-> size) {
|
176
741
|
return;
|
177
742
|
}
|
178
743
|
|
179
744
|
bool xtc_applied = false;
|
180
745
|
const int64_t t_start_sample_us = lm_ggml_time_us();
|
181
|
-
|
746
|
+
llama_sampler_softmax_impl(cur_p);
|
182
747
|
|
183
748
|
// unsorted iteration
|
184
|
-
if (!
|
749
|
+
if (!cur_p->sorted) {
|
185
750
|
std::vector<llama_token_data> top_tokens, low_tokens;
|
186
751
|
|
187
752
|
// split candidates into two arrays for low and high tokens
|
188
|
-
for (size_t i = 0; i <
|
189
|
-
if (
|
190
|
-
top_tokens.push_back(
|
753
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
754
|
+
if (cur_p->data[i].logit >= xtc_threshold) {
|
755
|
+
top_tokens.push_back(cur_p->data[i]);
|
191
756
|
} else {
|
192
|
-
low_tokens.push_back(
|
757
|
+
low_tokens.push_back(cur_p-> data[i]);
|
193
758
|
}
|
194
759
|
}
|
195
760
|
// if there is only one or no top_tokens, do not truncate
|
@@ -212,28 +777,28 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
|
|
212
777
|
}
|
213
778
|
}
|
214
779
|
if(low_tokens.size() >= min_keep) {
|
215
|
-
memcpy(
|
216
|
-
|
780
|
+
memcpy(cur_p->data, low_tokens.data(), low_tokens.size()*sizeof(llama_token_data));
|
781
|
+
cur_p->size = low_tokens.size();
|
217
782
|
xtc_applied = true;
|
218
783
|
}
|
219
784
|
}
|
220
785
|
// sorted iteration
|
221
|
-
|
786
|
+
|
222
787
|
if (!xtc_applied) {
|
223
788
|
// Sort the logits in descending order
|
224
|
-
if (!
|
225
|
-
std::sort(
|
789
|
+
if (!cur_p->sorted) {
|
790
|
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
226
791
|
return a.logit > b.logit;
|
227
792
|
});
|
228
|
-
|
793
|
+
cur_p->sorted = true;
|
229
794
|
}
|
230
795
|
|
231
796
|
// find last token over threshold
|
232
797
|
|
233
798
|
size_t last_index = 0;
|
234
799
|
|
235
|
-
for (; last_index <
|
236
|
-
if(
|
800
|
+
for (; last_index < cur_p -> size; ++last_index) {
|
801
|
+
if(cur_p -> data[last_index].p < xtc_threshold) {
|
237
802
|
break;
|
238
803
|
}
|
239
804
|
}
|
@@ -244,102 +809,80 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
|
|
244
809
|
}
|
245
810
|
last_index--;
|
246
811
|
// items beyond safe index will be ignored
|
247
|
-
size_t safe_index =
|
812
|
+
size_t safe_index = cur_p -> size;
|
248
813
|
|
249
814
|
// remove tokens until last threshold item
|
250
815
|
std::uniform_real_distribution<float> random_float(0.0 , 1.0);
|
251
816
|
for (size_t i = 0; i < last_index; i++) {
|
252
817
|
if(random_float(rng) < xtc_probability) {
|
253
|
-
std::swap(
|
818
|
+
std::swap(cur_p-> data[i], cur_p->data[safe_index - 1]);
|
254
819
|
safe_index--;
|
255
|
-
if (
|
256
|
-
|
820
|
+
if (cur_p-> sorted) {
|
821
|
+
cur_p -> sorted = false;
|
257
822
|
}
|
258
823
|
}
|
259
824
|
}
|
260
|
-
|
825
|
+
cur_p -> size = safe_index;
|
261
826
|
}
|
827
|
+
}
|
262
828
|
|
263
|
-
|
264
|
-
|
265
|
-
|
829
|
+
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
830
|
+
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
831
|
+
return llama_sampler_init_xtc(ctx->xtc_p, ctx->xtc_t, ctx->min_keep, ctx->seed);
|
266
832
|
}
|
267
833
|
|
268
|
-
void
|
269
|
-
|
270
|
-
|
271
|
-
}
|
834
|
+
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
835
|
+
delete (const llama_sampler_xtc *) smpl->ctx;
|
836
|
+
}
|
272
837
|
|
273
|
-
|
838
|
+
// Callback table for the XTC sampler; no accept or reset callback is installed.
static struct llama_sampler_i llama_sampler_xtc_i = {
    /* .name   = */ llama_sampler_xtc_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_xtc_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_xtc_clone,
    /* .free   = */ llama_sampler_xtc_free,
};
|
846
|
+
|
847
|
+
struct llama_sampler * llama_sampler_init_xtc(float xtc_p, float xtc_t, size_t min_keep, uint32_t seed) {
|
848
|
+
return new llama_sampler {
|
849
|
+
/* .iface = */ &llama_sampler_xtc_i,
|
850
|
+
/* .ctx = */ new llama_sampler_xtc {
|
851
|
+
/* .seed = */ seed,
|
852
|
+
/* .rng = */ std::mt19937(seed),
|
853
|
+
/* .xtc_p = */ xtc_p,
|
854
|
+
/* .xtc_t = */ xtc_t,
|
855
|
+
/* .min_keep = */ min_keep
|
856
|
+
},
|
857
|
+
};
|
858
|
+
}
|
274
859
|
|
275
|
-
|
860
|
+
// tail-free
|
276
861
|
|
277
|
-
|
278
|
-
|
279
|
-
|
862
|
+
// Parameters of the tail-free sampler; both are fixed at construction.
struct llama_sampler_tail_free {
    // threshold for the running sum of second derivatives; the sampler does nothing when z >= 1.0f
    const float  z;
    // truncation is only allowed once the scan index has reached this value
    const size_t min_keep;
};
|
280
866
|
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
}
|
285
|
-
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
867
|
+
// Name callback of the tail-free sampler; the sampler argument is not used.
static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
    return "tail-free";
}
|
286
870
|
|
287
|
-
|
288
|
-
|
289
|
-
filtered_tokens.push_back(candidates->data[i]);
|
290
|
-
}
|
291
|
-
}
|
871
|
+
static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
872
|
+
const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
|
292
873
|
|
293
|
-
|
294
|
-
|
295
|
-
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
296
|
-
candidates->size = filtered_tokens.size();
|
297
|
-
min_p_applied = true;
|
298
|
-
}
|
874
|
+
if (ctx->z >= 1.0f || cur_p->size <= 2) {
|
875
|
+
return;
|
299
876
|
}
|
300
877
|
|
301
|
-
|
302
|
-
if (!min_p_applied) {
|
303
|
-
// Sort the logits in descending order
|
304
|
-
if (!candidates->sorted) {
|
305
|
-
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
306
|
-
return a.logit > b.logit;
|
307
|
-
});
|
308
|
-
candidates->sorted = true;
|
309
|
-
}
|
878
|
+
llama_sampler_softmax_impl(cur_p);
|
310
879
|
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
for (; i < candidates->size; ++i) {
|
315
|
-
if (candidates->data[i].logit < min_logit && i >= min_keep) {
|
316
|
-
break; // prob too small
|
317
|
-
}
|
318
|
-
}
|
319
|
-
|
320
|
-
// Resize the output vector to keep only the matching tokens
|
321
|
-
candidates->size = i;
|
322
|
-
}
|
323
|
-
|
324
|
-
if (smpl) {
|
325
|
-
smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
|
326
|
-
}
|
327
|
-
}
|
328
|
-
|
329
|
-
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
|
330
|
-
if (z >= 1.0f || candidates->size <= 2) {
|
331
|
-
return;
|
332
|
-
}
|
333
|
-
|
334
|
-
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
335
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
336
|
-
|
337
|
-
// Compute the first and second derivatives
|
338
|
-
std::vector<float> first_derivatives(candidates->size - 1);
|
339
|
-
std::vector<float> second_derivatives(candidates->size - 2);
|
880
|
+
// Compute the first and second derivatives
|
881
|
+
std::vector<float> first_derivatives(cur_p->size - 1);
|
882
|
+
std::vector<float> second_derivatives(cur_p->size - 2);
|
340
883
|
|
341
884
|
for (size_t i = 0; i < first_derivatives.size(); ++i) {
|
342
|
-
first_derivatives[i] =
|
885
|
+
first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
|
343
886
|
}
|
344
887
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
345
888
|
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
|
@@ -366,51 +909,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
|
|
366
909
|
}
|
367
910
|
|
368
911
|
float cum_sum = 0.0f;
|
369
|
-
size_t last_idx =
|
912
|
+
size_t last_idx = cur_p->size;
|
370
913
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
371
914
|
cum_sum += second_derivatives[i];
|
372
915
|
|
373
916
|
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
374
|
-
if (cum_sum > z && i >= min_keep) {
|
917
|
+
if (cum_sum > ctx->z && i >= ctx->min_keep) {
|
375
918
|
last_idx = i;
|
376
919
|
break;
|
377
920
|
}
|
378
921
|
}
|
379
922
|
|
380
923
|
// Resize the output vector to keep only the tokens above the tail location
|
381
|
-
|
924
|
+
cur_p->size = last_idx;
|
925
|
+
}
|
382
926
|
|
383
|
-
|
384
|
-
|
385
|
-
|
927
|
+
static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
|
928
|
+
const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
|
929
|
+
return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
|
930
|
+
}
|
931
|
+
|
932
|
+
static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
|
933
|
+
delete (llama_sampler_tail_free *) smpl->ctx;
|
934
|
+
}
|
935
|
+
|
936
|
+
// Callback table for the tail-free sampler; no accept or reset callback is installed.
static struct llama_sampler_i llama_sampler_tail_free_i = {
    /* .name   = */ llama_sampler_tail_free_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_tail_free_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_tail_free_clone,
    /* .free   = */ llama_sampler_tail_free_free,
};
|
944
|
+
|
945
|
+
struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
|
946
|
+
return new llama_sampler {
|
947
|
+
/* .iface = */ &llama_sampler_tail_free_i,
|
948
|
+
/* .ctx = */ new llama_sampler_tail_free {
|
949
|
+
/* .z = */ z,
|
950
|
+
/*. min_keep = */ min_keep,
|
951
|
+
},
|
952
|
+
};
|
953
|
+
}
|
954
|
+
|
955
|
+
// typical
|
956
|
+
|
957
|
+
// Parameters of the locally-typical sampler; both are fixed at construction.
struct llama_sampler_typical {
    // cumulative-probability cut-off; the sampler does nothing when p >= 1.0f
    const float  p;
    // lower bound used together with p when deciding where to truncate
    const size_t min_keep;
};
|
961
|
+
|
962
|
+
// Name callback of the typical sampler; the sampler argument is not used.
static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
    return "typical";
}
|
387
965
|
|
388
|
-
void
|
966
|
+
static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
967
|
+
const auto * ctx = (llama_sampler_typical *) smpl->ctx;
|
968
|
+
|
389
969
|
// Reference implementation:
|
390
970
|
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
391
|
-
if (p >= 1.0f) {
|
971
|
+
if (ctx->p >= 1.0f) {
|
392
972
|
return;
|
393
973
|
}
|
394
974
|
|
395
975
|
// Compute the softmax of logits and calculate entropy
|
396
|
-
|
397
|
-
|
398
|
-
const int64_t t_start_sample_us = lm_ggml_time_us();
|
976
|
+
llama_sampler_softmax_impl(cur_p);
|
399
977
|
|
400
978
|
float entropy = 0.0f;
|
401
|
-
for (size_t i = 0; i <
|
402
|
-
entropy += -
|
979
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
980
|
+
entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
|
403
981
|
}
|
404
982
|
|
405
983
|
// Compute the absolute difference between negative log probability and entropy for each candidate
|
406
984
|
std::vector<float> shifted_scores;
|
407
|
-
for (size_t i = 0; i <
|
408
|
-
float shifted_score = fabsf(-logf(
|
985
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
986
|
+
float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
|
409
987
|
shifted_scores.push_back(shifted_score);
|
410
988
|
}
|
411
989
|
|
412
990
|
// Sort tokens based on the shifted_scores and their corresponding indices
|
413
|
-
std::vector<size_t> indices(
|
991
|
+
std::vector<size_t> indices(cur_p->size);
|
414
992
|
std::iota(indices.begin(), indices.end(), 0);
|
415
993
|
|
416
994
|
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
|
@@ -423,134 +1001,618 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|
423
1001
|
|
424
1002
|
for (size_t i = 0; i < indices.size(); ++i) {
|
425
1003
|
size_t idx = indices[i];
|
426
|
-
cum_sum +=
|
1004
|
+
cum_sum += cur_p->data[idx].p;
|
427
1005
|
|
428
1006
|
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
429
|
-
if (cum_sum > p && i >= min_keep - 1) {
|
1007
|
+
if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
|
430
1008
|
last_idx = i + 1;
|
431
1009
|
break;
|
432
1010
|
}
|
433
1011
|
}
|
434
1012
|
|
435
1013
|
// Resize the output vector to keep only the locally typical tokens
|
436
|
-
std::vector<llama_token_data>
|
1014
|
+
std::vector<llama_token_data> cur_p_new;
|
437
1015
|
for (size_t i = 0; i < last_idx; ++i) {
|
438
1016
|
size_t idx = indices[i];
|
439
|
-
|
1017
|
+
cur_p_new.push_back(cur_p->data[idx]);
|
440
1018
|
}
|
441
1019
|
|
442
|
-
// Replace the data in
|
443
|
-
std::copy(
|
444
|
-
|
445
|
-
|
1020
|
+
// Replace the data in cur_p with the cur_p_new data
|
1021
|
+
std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
|
1022
|
+
cur_p->size = cur_p_new.size();
|
1023
|
+
cur_p->sorted = false;
|
1024
|
+
}
|
446
1025
|
|
447
|
-
|
448
|
-
|
449
|
-
|
1026
|
+
static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
|
1027
|
+
const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
|
1028
|
+
return llama_sampler_init_typical(ctx->p, ctx->min_keep);
|
450
1029
|
}
|
451
1030
|
|
452
|
-
void
|
453
|
-
|
1031
|
+
static void llama_sampler_typical_free(struct llama_sampler * smpl) {
|
1032
|
+
delete (llama_sampler_typical *) smpl->ctx;
|
1033
|
+
}
|
454
1034
|
|
455
|
-
|
456
|
-
|
457
|
-
|
1035
|
+
// Callback table for the typical sampler; no accept or reset callback is installed.
static struct llama_sampler_i llama_sampler_typical_i = {
    /* .name   = */ llama_sampler_typical_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_typical_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_typical_clone,
    /* .free   = */ llama_sampler_typical_free,
};
|
1043
|
+
|
1044
|
+
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
1045
|
+
return new llama_sampler {
|
1046
|
+
/* .iface = */ &llama_sampler_typical_i,
|
1047
|
+
/* .ctx = */ new llama_sampler_typical {
|
1048
|
+
/* .p = */ p,
|
1049
|
+
/* .min_keep = */ min_keep,
|
1050
|
+
},
|
1051
|
+
};
|
1052
|
+
}
|
1053
|
+
|
1054
|
+
// temp
|
1055
|
+
|
1056
|
+
// Parameter of the plain temperature sampler.
struct llama_sampler_temp {
    // every candidate logit is divided by this value
    const float temp;
};
|
1059
|
+
|
1060
|
+
// Name callback of the temperature sampler; the sampler argument is not used.
static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
    return "temp";
}
|
1063
|
+
|
1064
|
+
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1065
|
+
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
1066
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1067
|
+
cur_p->data[i].logit /= ctx->temp;
|
458
1068
|
}
|
1069
|
+
}
|
459
1070
|
|
460
|
-
|
461
|
-
|
1071
|
+
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
1072
|
+
const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
|
1073
|
+
return llama_sampler_init_temp(ctx->temp);
|
1074
|
+
}
|
462
1075
|
|
463
|
-
|
1076
|
+
static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
1077
|
+
delete (llama_sampler_temp *) smpl->ctx;
|
1078
|
+
}
|
464
1079
|
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
1080
|
+
// Callback table for the temperature sampler; no accept or reset callback is installed.
static struct llama_sampler_i llama_sampler_temp_i = {
    /* .name   = */ llama_sampler_temp_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_temp_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_temp_clone,
    /* .free   = */ llama_sampler_temp_free,
};
|
1088
|
+
|
1089
|
+
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
1090
|
+
return new llama_sampler {
|
1091
|
+
/* .iface = */ &llama_sampler_temp_i,
|
1092
|
+
/* .ctx = */ new llama_sampler_temp {
|
1093
|
+
/*.temp = */ temp,
|
1094
|
+
},
|
1095
|
+
};
|
1096
|
+
}
|
1097
|
+
|
1098
|
+
// temp-ext
|
1099
|
+
|
1100
|
+
// Parameters of the extended (dynamic) temperature sampler.
struct llama_sampler_temp_ext {
    // base temperature; used directly when delta is not positive
    const float temp;
    // when > 0, the temperature is picked from [max(0, temp - delta), temp + delta]
    const float delta;
    // power applied to the normalized entropy when mapping it into that range
    const float exponent;
};
|
1105
|
+
|
1106
|
+
// Name callback of the extended temperature sampler; the sampler argument is not used.
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
    return "temp-ext";
}
|
1109
|
+
|
1110
|
+
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1111
|
+
const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
|
1112
|
+
if (ctx->delta > 0) {
|
1113
|
+
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
1114
|
+
const float max_temp = ctx->temp + ctx->delta;
|
1115
|
+
float exponent_val = ctx->exponent;
|
1116
|
+
|
1117
|
+
// no need to do anything if there is only one (or zero) candidates
|
1118
|
+
if (cur_p->size <= 1) {
|
1119
|
+
return;
|
1120
|
+
}
|
1121
|
+
|
1122
|
+
// Calculate maximum possible entropy
|
1123
|
+
float max_entropy = -logf(1.0f / cur_p->size);
|
1124
|
+
|
1125
|
+
llama_sampler_softmax_impl(cur_p);
|
1126
|
+
|
1127
|
+
// Calculate entropy of the softmax probabilities
|
1128
|
+
float entropy = 0.0f;
|
1129
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1130
|
+
float prob = cur_p->data[i].p;
|
1131
|
+
if (prob > 0.0f) { // Ensure no log(0)
|
1132
|
+
entropy -= prob * logf(prob);
|
1133
|
+
}
|
1134
|
+
}
|
1135
|
+
|
1136
|
+
// Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
|
1137
|
+
float normalized_entropy = entropy / max_entropy;
|
1138
|
+
|
1139
|
+
// Map the normalized entropy to the desired temperature range using the power function
|
1140
|
+
float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
|
1141
|
+
|
1142
|
+
#ifdef DEBUG
|
1143
|
+
LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
|
1144
|
+
LLAMA_LOG_INFO("Entropy: %f\n", entropy);
|
1145
|
+
LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
|
1146
|
+
LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
|
1147
|
+
LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
|
1148
|
+
LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
|
1149
|
+
#endif
|
1150
|
+
|
1151
|
+
// Apply the dynamically calculated temperature scaling
|
1152
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1153
|
+
cur_p->data[i].logit /= dyn_temp;
|
1154
|
+
}
|
1155
|
+
|
1156
|
+
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
1157
|
+
const double max_l_double = cur_p->data[0].logit;
|
1158
|
+
|
1159
|
+
double cum_sum_double = 0.0;
|
1160
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1161
|
+
double p = exp(cur_p->data[i].logit - max_l_double);
|
1162
|
+
cur_p->data[i].p = p; // Store the scaled probability
|
1163
|
+
cum_sum_double += p;
|
1164
|
+
}
|
1165
|
+
|
1166
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1167
|
+
cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
1168
|
+
}
|
1169
|
+
|
1170
|
+
#ifdef DEBUG
|
1171
|
+
// Print the updated top 25 probabilities after temperature scaling
|
1172
|
+
LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
|
1173
|
+
for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
|
1174
|
+
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
|
1175
|
+
}
|
1176
|
+
#endif
|
1177
|
+
} else {
|
1178
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1179
|
+
cur_p->data[i].logit /= ctx->temp;
|
471
1180
|
}
|
472
1181
|
}
|
1182
|
+
}
|
473
1183
|
|
474
|
-
|
475
|
-
|
1184
|
+
static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
|
1185
|
+
const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
|
1186
|
+
return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
|
1187
|
+
}
|
476
1188
|
|
477
|
-
|
478
|
-
|
1189
|
+
static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
1190
|
+
delete (llama_sampler_temp_ext *) smpl->ctx;
|
1191
|
+
}
|
479
1192
|
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
1193
|
+
// Callback table for the extended temperature sampler; no accept or reset callback is installed.
static struct llama_sampler_i llama_sampler_temp_ext_i = {
    /* .name   = */ llama_sampler_temp_ext_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_temp_ext_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_temp_ext_clone,
    /* .free   = */ llama_sampler_temp_ext_free,
};
|
1201
|
+
|
1202
|
+
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
1203
|
+
return new llama_sampler {
|
1204
|
+
/* .iface = */ &llama_sampler_temp_ext_i,
|
1205
|
+
/* .ctx = */ new llama_sampler_temp_ext {
|
1206
|
+
/* .temp = */ temp,
|
1207
|
+
/* .delta = */ delta,
|
1208
|
+
/* .exponent = */ exponent,
|
1209
|
+
},
|
1210
|
+
};
|
1211
|
+
}
|
1212
|
+
|
1213
|
+
// mirostat
|
1214
|
+
|
1215
|
+
// Parameters and running state of the mirostat (v1) sampler.
struct llama_sampler_mirostat {
    // vocabulary size, used when computing k from the estimated s_hat
    const int32_t n_vocab;

    // seed requested by the caller
    const uint32_t seed;
    // value returned by get_rng_seed(seed); this is what the RNG is actually seeded with
    uint32_t       seed_cur;

    // target surprise value
    const float tau;
    // learning rate used when updating mu
    const float eta;

    // upper bound on how many of the most probable tokens are used to estimate s_hat
    const int32_t m;

    // running state: starts at 2*tau and is adjusted after every sampled token
    float mu;

    // random generator used to pick the token
    std::mt19937 rng;
};
|
1230
|
+
|
1231
|
+
// Name callback of the mirostat sampler; the sampler argument is not used.
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat";
}
|
1234
|
+
|
1235
|
+
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1236
|
+
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
|
1237
|
+
|
1238
|
+
llama_sampler_softmax_impl(cur_p);
|
1239
|
+
|
1240
|
+
// Estimate s_hat using the most probable m tokens
|
1241
|
+
float s_hat = 0.0;
|
1242
|
+
float sum_ti_bi = 0.0;
|
1243
|
+
float sum_ti_sq = 0.0;
|
1244
|
+
for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
|
1245
|
+
float t_i = logf(float(i + 2) / float(i + 1));
|
1246
|
+
float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
|
1247
|
+
sum_ti_bi += t_i * b_i;
|
1248
|
+
sum_ti_sq += t_i * t_i;
|
1249
|
+
}
|
1250
|
+
s_hat = sum_ti_bi / sum_ti_sq;
|
1251
|
+
|
1252
|
+
// Compute k from the estimated s_hat and target surprise value
|
1253
|
+
float epsilon_hat = s_hat - 1;
|
1254
|
+
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
|
1255
|
+
|
1256
|
+
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
|
1257
|
+
llama_sampler_softmax_impl(cur_p);
|
1258
|
+
|
1259
|
+
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
1260
|
+
|
1261
|
+
cur_p->selected = idx;
|
1262
|
+
|
1263
|
+
float observed_surprise = -log2f(cur_p->data[idx].p);
|
1264
|
+
float e = observed_surprise - ctx->tau;
|
1265
|
+
|
1266
|
+
// Update mu using the learning rate and error
|
1267
|
+
ctx->mu = ctx->mu - ctx->eta * e;
|
1268
|
+
}
|
1269
|
+
|
1270
|
+
static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
|
1271
|
+
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
|
1272
|
+
auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
|
1273
|
+
|
1274
|
+
// copy the state
|
1275
|
+
{
|
1276
|
+
auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
|
1277
|
+
|
1278
|
+
result_ctx->mu = ctx->mu;
|
1279
|
+
result_ctx->rng = ctx->rng;
|
1280
|
+
}
|
1281
|
+
|
1282
|
+
return result;
|
1283
|
+
}
|
1284
|
+
|
1285
|
+
static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
|
1286
|
+
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
|
1287
|
+
ctx->mu = 2.0f*ctx->tau;
|
1288
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
1289
|
+
ctx->rng.seed(ctx->seed_cur);
|
1290
|
+
}
|
1291
|
+
|
1292
|
+
static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
1293
|
+
delete (llama_sampler_mirostat *) smpl->ctx;
|
1294
|
+
}
|
1295
|
+
|
1296
|
+
// Callback table for the mirostat sampler; it has a reset callback but no accept callback.
static struct llama_sampler_i llama_sampler_mirostat_i = {
    /* .name   = */ llama_sampler_mirostat_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_mirostat_apply,
    /* .reset  = */ llama_sampler_mirostat_reset,
    /* .clone  = */ llama_sampler_mirostat_clone,
    /* .free   = */ llama_sampler_mirostat_free,
};
|
1304
|
+
|
1305
|
+
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
1306
|
+
auto seed_cur = get_rng_seed(seed);
|
1307
|
+
return new llama_sampler {
|
1308
|
+
/* .iface = */ &llama_sampler_mirostat_i,
|
1309
|
+
/* .ctx = */ new llama_sampler_mirostat {
|
1310
|
+
/* .n_vocab = */ n_vocab,
|
1311
|
+
/* .seed = */ seed,
|
1312
|
+
/* .seed_cur = */ seed_cur,
|
1313
|
+
/* .tau = */ tau,
|
1314
|
+
/* .eta = */ eta,
|
1315
|
+
/* .m = */ m,
|
1316
|
+
/* .mu = */ 2.0f*tau,
|
1317
|
+
/* .rng = */ std::mt19937(seed_cur),
|
1318
|
+
},
|
1319
|
+
};
|
1320
|
+
}
|
1321
|
+
|
1322
|
+
// mirostat v2
|
1323
|
+
|
1324
|
+
// Parameters and running state of the mirostat v2 sampler.
struct llama_sampler_mirostat_v2 {
    // seed requested by the caller
    const uint32_t seed;
    // value returned by get_rng_seed(seed); this is what the RNG is actually seeded with
    uint32_t       seed_cur;

    // target surprise value
    const float tau;
    // learning rate used when updating mu
    const float eta;

    // running state: starts at 2*tau; candidates whose surprise exceeds it are cut off
    float mu;

    // random generator used to pick the token
    std::mt19937 rng;
};
|
1335
|
+
|
1336
|
+
// Name callback of the mirostat v2 sampler; the sampler argument is not used.
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat-v2";
}
|
1339
|
+
|
1340
|
+
static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1341
|
+
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
|
1342
|
+
|
1343
|
+
llama_sampler_softmax_impl(cur_p);
|
1344
|
+
|
1345
|
+
// Truncate the words with surprise values greater than mu
|
1346
|
+
cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
|
1347
|
+
return -log2f(candidate.p) > ctx->mu;
|
1348
|
+
}));
|
1349
|
+
|
1350
|
+
if (cur_p->size == 0) {
|
1351
|
+
cur_p->size = 1;
|
1352
|
+
}
|
488
1353
|
|
489
|
-
//
|
490
|
-
|
491
|
-
|
1354
|
+
// Normalize the probabilities of the remaining words
|
1355
|
+
llama_sampler_softmax_impl(cur_p);
|
1356
|
+
|
1357
|
+
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
1358
|
+
|
1359
|
+
cur_p->selected = idx;
|
1360
|
+
|
1361
|
+
float observed_surprise = -log2f(cur_p->data[idx].p);
|
1362
|
+
float e = observed_surprise - ctx->tau;
|
1363
|
+
|
1364
|
+
// Update mu using the learning rate and error
|
1365
|
+
ctx->mu = ctx->mu - ctx->eta * e;
|
1366
|
+
}
|
1367
|
+
|
1368
|
+
static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
|
1369
|
+
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
|
1370
|
+
ctx->mu = 2.0f*ctx->tau;
|
1371
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
1372
|
+
ctx->rng.seed(ctx->seed_cur);
|
1373
|
+
}
|
1374
|
+
|
1375
|
+
static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
|
1376
|
+
const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
|
1377
|
+
|
1378
|
+
auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
|
1379
|
+
|
1380
|
+
// copy the state
|
1381
|
+
{
|
1382
|
+
auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
|
1383
|
+
|
1384
|
+
result_ctx->mu = ctx->mu;
|
1385
|
+
result_ctx->rng = ctx->rng;
|
492
1386
|
}
|
493
1387
|
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
1388
|
+
return result;
|
1389
|
+
}
|
1390
|
+
|
1391
|
+
static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
|
1392
|
+
delete (llama_sampler_mirostat_v2 *) smpl->ctx;
|
1393
|
+
}
|
1394
|
+
|
1395
|
+
// Callback table for the mirostat v2 sampler; it has a reset callback but no accept callback.
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
    /* .name   = */ llama_sampler_mirostat_v2_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_mirostat_v2_apply,
    /* .reset  = */ llama_sampler_mirostat_v2_reset,
    /* .clone  = */ llama_sampler_mirostat_v2_clone,
    /* .free   = */ llama_sampler_mirostat_v2_free,
};
|
1403
|
+
|
1404
|
+
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
1405
|
+
auto seed_cur = get_rng_seed(seed);
|
1406
|
+
return new llama_sampler {
|
1407
|
+
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
1408
|
+
/* .ctx = */ new llama_sampler_mirostat_v2 {
|
1409
|
+
/* .seed = */ seed,
|
1410
|
+
/* .seed_cur = */ seed_cur,
|
1411
|
+
/* .tau = */ tau,
|
1412
|
+
/* .eta = */ eta,
|
1413
|
+
/* .mu = */ 2.0f*tau,
|
1414
|
+
/* .rng = */ std::mt19937(seed_cur),
|
1415
|
+
},
|
1416
|
+
};
|
1417
|
+
}
|
1418
|
+
|
1419
|
+
// grammar
|
1420
|
+
|
1421
|
+
// State of the grammar sampler.
struct llama_sampler_grammar {
    // vocabulary the sampler was created with (not owned)
    const struct llama_vocab * vocab;

    // grammar source text and root rule name; left empty when no grammar is set
    std::string grammar_str;
    std::string grammar_root;

    // grammar object; nullptr when no grammar is set, in which case accept/apply/reset do nothing
    struct llama_grammar * grammar;
};
|
1429
|
+
|
1430
|
+
// Name callback of the grammar sampler; the sampler argument is not used.
static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
    return "grammar";
}
|
1433
|
+
|
1434
|
+
static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
|
1435
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
1436
|
+
if (ctx->grammar) {
|
1437
|
+
llama_grammar_accept_impl(*ctx->grammar, token);
|
501
1438
|
}
|
502
|
-
|
503
|
-
|
1439
|
+
}
|
1440
|
+
|
1441
|
+
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1442
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
1443
|
+
if (ctx->grammar) {
|
1444
|
+
llama_grammar_apply_impl(*ctx->grammar, cur_p);
|
504
1445
|
}
|
1446
|
+
}
|
505
1447
|
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
|
1448
|
+
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
1449
|
+
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
1450
|
+
if (!ctx->grammar) {
|
1451
|
+
return;
|
511
1452
|
}
|
512
|
-
#endif
|
513
1453
|
|
514
|
-
|
515
|
-
|
1454
|
+
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
|
1455
|
+
|
1456
|
+
llama_grammar_free_impl(ctx->grammar);
|
1457
|
+
ctx->grammar = grammar_new;
|
1458
|
+
}
|
1459
|
+
|
1460
|
+
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
|
1461
|
+
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
|
1462
|
+
|
1463
|
+
auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
|
1464
|
+
|
1465
|
+
// copy the state
|
1466
|
+
{
|
1467
|
+
auto * result_ctx = (llama_sampler_grammar *) result->ctx;
|
1468
|
+
|
1469
|
+
if (ctx->grammar) {
|
1470
|
+
result_ctx->grammar_str = ctx->grammar_str;
|
1471
|
+
result_ctx->grammar_root = ctx->grammar_root;
|
1472
|
+
|
1473
|
+
result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
|
1474
|
+
}
|
516
1475
|
}
|
1476
|
+
|
1477
|
+
return result;
|
517
1478
|
}
|
518
1479
|
|
519
|
-
void
|
520
|
-
const
|
1480
|
+
static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
|
1481
|
+
const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
521
1482
|
|
522
|
-
|
523
|
-
|
1483
|
+
if (ctx->grammar) {
|
1484
|
+
llama_grammar_free_impl(ctx->grammar);
|
524
1485
|
}
|
525
1486
|
|
526
|
-
|
527
|
-
|
1487
|
+
delete ctx;
|
1488
|
+
}
|
1489
|
+
|
1490
|
+
// Callback table for the grammar sampler; every callback is provided.
static struct llama_sampler_i llama_sampler_grammar_i = {
    /* .name   = */ llama_sampler_grammar_name,
    /* .accept = */ llama_sampler_grammar_accept_impl,
    /* .apply  = */ llama_sampler_grammar_apply,
    /* .reset  = */ llama_sampler_grammar_reset,
    /* .clone  = */ llama_sampler_grammar_clone,
    /* .free   = */ llama_sampler_grammar_free,
};
|
1498
|
+
|
1499
|
+
struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
|
1500
|
+
auto * ctx = new llama_sampler_grammar;
|
1501
|
+
|
1502
|
+
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
1503
|
+
*ctx = {
|
1504
|
+
/* .vocab = */ &vocab,
|
1505
|
+
/* .grammar_str = */ grammar_str,
|
1506
|
+
/* .grammar_root = */ grammar_root,
|
1507
|
+
/* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
|
1508
|
+
};
|
1509
|
+
} else {
|
1510
|
+
*ctx = {
|
1511
|
+
/* .vocab = */ &vocab,
|
1512
|
+
/* .grammar_str = */ {},
|
1513
|
+
/* .grammar_root = */ {},
|
1514
|
+
/* .grammar = */ nullptr,
|
1515
|
+
};
|
528
1516
|
}
|
1517
|
+
|
1518
|
+
return new llama_sampler {
|
1519
|
+
/* .iface = */ &llama_sampler_grammar_i,
|
1520
|
+
/* .ctx = */ ctx,
|
1521
|
+
};
|
529
1522
|
}
|
530
1523
|
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
1524
|
+
// penalties
|
1525
|
+
|
1526
|
+
struct llama_sampler_penalties {
|
1527
|
+
const int32_t n_vocab;
|
1528
|
+
const llama_token special_eos_id;
|
1529
|
+
const llama_token linefeed_id;
|
1530
|
+
|
1531
|
+
const int32_t penalty_last_n;
|
1532
|
+
const float penalty_repeat;
|
1533
|
+
const float penalty_freq;
|
1534
|
+
const float penalty_present;
|
1535
|
+
|
1536
|
+
const bool penalize_nl;
|
1537
|
+
const bool ignore_eos;
|
1538
|
+
|
1539
|
+
ring_buffer<llama_token> prev;
|
1540
|
+
};
|
1541
|
+
|
1542
|
+
// Name callback of the penalties sampler; the sampler argument is not used.
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
    return "penalties";
}
|
1545
|
+
|
1546
|
+
static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
|
1547
|
+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
1548
|
+
if (ctx->penalty_last_n == 0) {
|
540
1549
|
return;
|
541
1550
|
}
|
542
1551
|
|
543
|
-
|
1552
|
+
ctx->prev.push_back(token);
|
1553
|
+
}
|
1554
|
+
|
1555
|
+
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1556
|
+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
1557
|
+
|
1558
|
+
if (ctx->ignore_eos) {
|
1559
|
+
assert(ctx->special_eos_id >= 0);
|
1560
|
+
|
1561
|
+
// optimistically check if the candidates are not yet sorted/shuffled/truncated
|
1562
|
+
if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
|
1563
|
+
cur_p->data[ctx->special_eos_id].logit = -INFINITY;
|
1564
|
+
} else {
|
1565
|
+
// else, search for the special EOS token
|
1566
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1567
|
+
if (cur_p->data[i].id == ctx->special_eos_id) {
|
1568
|
+
cur_p->data[i].logit = -INFINITY;
|
1569
|
+
break;
|
1570
|
+
}
|
1571
|
+
}
|
1572
|
+
}
|
1573
|
+
}
|
1574
|
+
|
1575
|
+
if ((ctx->penalty_last_n == 0) ||
|
1576
|
+
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
|
1577
|
+
return;
|
1578
|
+
}
|
1579
|
+
|
1580
|
+
bool nl_found = false;
|
1581
|
+
size_t nl_idx = 0;
|
1582
|
+
float nl_logit = -INFINITY;
|
1583
|
+
if (!ctx->penalize_nl) {
|
1584
|
+
assert(ctx->linefeed_id >= 0);
|
1585
|
+
|
1586
|
+
// optimistically check if the candidates are not yet sorted/shuffled/truncated
|
1587
|
+
if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
|
1588
|
+
nl_found = true;
|
1589
|
+
nl_idx = ctx->linefeed_id;
|
1590
|
+
nl_logit = cur_p->data[ctx->linefeed_id].logit;
|
1591
|
+
} else {
|
1592
|
+
// else, search for the linefeed token
|
1593
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1594
|
+
if (cur_p->data[i].id == ctx->linefeed_id) {
|
1595
|
+
nl_found = true;
|
1596
|
+
nl_idx = i;
|
1597
|
+
nl_logit = cur_p->data[i].logit;
|
1598
|
+
break;
|
1599
|
+
}
|
1600
|
+
}
|
1601
|
+
}
|
1602
|
+
}
|
544
1603
|
|
545
1604
|
// Create a frequency map to count occurrences of each token in last_tokens
|
546
|
-
|
547
|
-
|
548
|
-
|
1605
|
+
// TODO: optimize this by maintaining the token count in the sampler context
|
1606
|
+
using llama_token_cnt = std::unordered_map<llama_token, int>;
|
1607
|
+
llama_token_cnt token_count;
|
1608
|
+
|
1609
|
+
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
|
1610
|
+
token_count[ctx->prev.rat(i)]++;
|
549
1611
|
}
|
550
1612
|
|
551
|
-
// Apply frequency and presence penalties to the
|
552
|
-
for (size_t i = 0; i <
|
553
|
-
const auto token_iter = token_count.find(
|
1613
|
+
// Apply frequency and presence penalties to the cur_p
|
1614
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1615
|
+
const auto token_iter = token_count.find(cur_p->data[i].id);
|
554
1616
|
if (token_iter == token_count.end()) {
|
555
1617
|
continue;
|
556
1618
|
}
|
@@ -559,171 +1621,238 @@ void llama_sample_repetition_penalties_impl(
|
|
559
1621
|
|
560
1622
|
// The academic publication that described this technique only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
561
1623
|
// The common fix for this problem is to multiply by the penalty instead of dividing.
|
562
|
-
if (
|
563
|
-
|
1624
|
+
if (cur_p->data[i].logit <= 0) {
|
1625
|
+
cur_p->data[i].logit *= ctx->penalty_repeat;
|
564
1626
|
} else {
|
565
|
-
|
1627
|
+
cur_p->data[i].logit /= ctx->penalty_repeat;
|
566
1628
|
}
|
567
1629
|
|
568
|
-
|
1630
|
+
cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
|
569
1631
|
}
|
570
1632
|
|
571
|
-
|
1633
|
+
cur_p->sorted = false;
|
572
1634
|
|
573
|
-
if (
|
574
|
-
|
1635
|
+
if (!ctx->penalize_nl && nl_found) {
|
1636
|
+
// restore the logit of the newline token if it was penalized
|
1637
|
+
cur_p->data[nl_idx].logit = nl_logit;
|
575
1638
|
}
|
576
1639
|
}
|
577
1640
|
|
578
|
-
void
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
float scale) {
|
583
|
-
LM_GGML_ASSERT(smpl);
|
584
|
-
|
585
|
-
const auto t_start_sample_us = lm_ggml_time_us();
|
586
|
-
const auto n_vocab = smpl->n_vocab;
|
587
|
-
|
588
|
-
llama_log_softmax(logits, n_vocab);
|
589
|
-
llama_log_softmax(logits_guidance, n_vocab);
|
1641
|
+
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
|
1642
|
+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
1643
|
+
ctx->prev.clear();
|
1644
|
+
}
|
590
1645
|
|
591
|
-
|
592
|
-
|
593
|
-
|
1646
|
+
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
|
1647
|
+
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
|
1648
|
+
auto * result = llama_sampler_init_penalties(
|
1649
|
+
ctx->n_vocab,
|
1650
|
+
ctx->special_eos_id,
|
1651
|
+
ctx->linefeed_id,
|
1652
|
+
ctx->penalty_last_n,
|
1653
|
+
ctx->penalty_repeat,
|
1654
|
+
ctx->penalty_freq,
|
1655
|
+
ctx->penalty_present,
|
1656
|
+
ctx->penalize_nl,
|
1657
|
+
ctx->ignore_eos);
|
1658
|
+
|
1659
|
+
// copy the state
|
1660
|
+
{
|
1661
|
+
auto * result_ctx = (llama_sampler_penalties *) result->ctx;
|
594
1662
|
|
595
|
-
|
1663
|
+
result_ctx->prev = ctx->prev;
|
596
1664
|
}
|
597
1665
|
|
598
|
-
|
1666
|
+
return result;
|
599
1667
|
}
|
600
1668
|
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
const int32_t n_vocab = float(smpl->n_vocab);
|
605
|
-
|
606
|
-
int64_t t_start_sample_us = lm_ggml_time_us();
|
1669
|
+
static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
|
1670
|
+
delete (llama_sampler_penalties *) smpl->ctx;
|
1671
|
+
}
|
607
1672
|
|
608
|
-
|
1673
|
+
// Interface table of the penalties sampler (entries are positional).
static struct llama_sampler_i llama_sampler_penalties_i = {
    llama_sampler_penalties_name,   // name
    llama_sampler_penalties_accept, // accept
    llama_sampler_penalties_apply,  // apply
    llama_sampler_penalties_reset,  // reset
    llama_sampler_penalties_clone,  // clone
    llama_sampler_penalties_free,   // free
};
|
1681
|
+
|
1682
|
+
struct llama_sampler * llama_sampler_init_penalties(
|
1683
|
+
int32_t n_vocab,
|
1684
|
+
llama_token special_eos_id,
|
1685
|
+
llama_token linefeed_id,
|
1686
|
+
int32_t penalty_last_n,
|
1687
|
+
float penalty_repeat,
|
1688
|
+
float penalty_freq,
|
1689
|
+
float penalty_present,
|
1690
|
+
bool penalize_nl,
|
1691
|
+
bool ignore_eos) {
|
1692
|
+
if (linefeed_id == LLAMA_TOKEN_NULL) {
|
1693
|
+
penalize_nl = true;
|
1694
|
+
}
|
609
1695
|
|
610
|
-
|
611
|
-
|
612
|
-
float sum_ti_bi = 0.0;
|
613
|
-
float sum_ti_sq = 0.0;
|
614
|
-
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
|
615
|
-
float t_i = logf(float(i + 2) / float(i + 1));
|
616
|
-
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
|
617
|
-
sum_ti_bi += t_i * b_i;
|
618
|
-
sum_ti_sq += t_i * t_i;
|
1696
|
+
if (special_eos_id == LLAMA_TOKEN_NULL) {
|
1697
|
+
ignore_eos = false;
|
619
1698
|
}
|
620
|
-
s_hat = sum_ti_bi / sum_ti_sq;
|
621
1699
|
|
622
|
-
|
623
|
-
|
624
|
-
|
1700
|
+
penalty_last_n = std::max(penalty_last_n, 0);
|
1701
|
+
|
1702
|
+
return new llama_sampler {
|
1703
|
+
/* .iface = */ &llama_sampler_penalties_i,
|
1704
|
+
/* .ctx = */ new llama_sampler_penalties {
|
1705
|
+
/* .n_vocab = */ n_vocab,
|
1706
|
+
/* .special_eos_id = */ special_eos_id,
|
1707
|
+
/* .linefeed_id = */ linefeed_id,
|
1708
|
+
/* .penalty_last_n = */ penalty_last_n,
|
1709
|
+
/* .penalty_repeat = */ penalty_repeat,
|
1710
|
+
/* .penalty_freq = */ penalty_freq,
|
1711
|
+
/* .penalty_present = */ penalty_present,
|
1712
|
+
/* .penalize_nl = */ penalize_nl,
|
1713
|
+
/* .ignore_eos = */ ignore_eos,
|
1714
|
+
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
|
1715
|
+
},
|
1716
|
+
};
|
1717
|
+
}
|
625
1718
|
|
626
|
-
|
627
|
-
llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
|
628
|
-
smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
|
629
|
-
llama_token X = llama_sample_token_impl(smpl, candidates);
|
630
|
-
t_start_sample_us = lm_ggml_time_us();
|
1719
|
+
// logit-bias
|
631
1720
|
|
632
|
-
|
633
|
-
|
634
|
-
return candidate.id == X;
|
635
|
-
}));
|
636
|
-
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
637
|
-
float e = observed_surprise - tau;
|
1721
|
+
struct llama_sampler_logit_bias {
|
1722
|
+
const int32_t n_vocab;
|
638
1723
|
|
639
|
-
|
640
|
-
|
1724
|
+
const std::vector<llama_logit_bias> logit_bias;
|
1725
|
+
|
1726
|
+
std::vector<llama_logit_bias> to_search;
|
1727
|
+
};
|
641
1728
|
|
642
|
-
|
643
|
-
return
|
1729
|
+
// Returns the display name of the logit-bias sampler. The sampler argument is not used.
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * sampler_name = "logit-bias";
    return sampler_name;
}
|
645
1732
|
|
646
|
-
|
647
|
-
|
648
|
-
t_start_sample_us = lm_ggml_time_us();
|
1733
|
+
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
1734
|
+
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
|
649
1735
|
|
650
|
-
|
1736
|
+
if (ctx->logit_bias.empty()) {
|
1737
|
+
return;
|
1738
|
+
}
|
651
1739
|
|
652
|
-
|
653
|
-
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
654
|
-
return -log2f(candidate.p) > *mu;
|
655
|
-
}));
|
1740
|
+
ctx->to_search.clear();
|
656
1741
|
|
657
|
-
|
658
|
-
|
1742
|
+
// update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
|
1743
|
+
for (const auto & lb : ctx->logit_bias) {
|
1744
|
+
if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
|
1745
|
+
cur_p->data[lb.token].logit += lb.bias;
|
1746
|
+
} else {
|
1747
|
+
ctx->to_search.push_back(lb);
|
1748
|
+
}
|
659
1749
|
}
|
660
1750
|
|
661
|
-
if (
|
662
|
-
|
1751
|
+
if (ctx->to_search.empty()) {
|
1752
|
+
return;
|
663
1753
|
}
|
664
1754
|
|
665
|
-
//
|
666
|
-
|
1755
|
+
// search for the remaining candidates that were not found in the previous step
|
1756
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1757
|
+
for (const auto & lb : ctx->to_search) {
|
1758
|
+
if (cur_p->data[i].id == lb.token) {
|
1759
|
+
cur_p->data[i].logit += lb.bias;
|
1760
|
+
break;
|
1761
|
+
}
|
1762
|
+
}
|
1763
|
+
}
|
1764
|
+
}
|
667
1765
|
|
668
|
-
|
669
|
-
|
670
|
-
|
1766
|
+
static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
|
1767
|
+
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
|
1768
|
+
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
|
1769
|
+
}
|
671
1770
|
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
}));
|
676
|
-
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
677
|
-
float e = observed_surprise - tau;
|
1771
|
+
static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
|
1772
|
+
delete (llama_sampler_logit_bias *) smpl->ctx;
|
1773
|
+
}
|
678
1774
|
|
679
|
-
|
680
|
-
|
1775
|
+
// Interface table of the logit-bias sampler (entries are positional).
// There are no accept and reset callbacks.
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    llama_sampler_logit_bias_name,  // name
    nullptr,                        // accept
    llama_sampler_logit_bias_apply, // apply
    nullptr,                        // reset
    llama_sampler_logit_bias_clone, // clone
    llama_sampler_logit_bias_free,  // free
};
|
1783
|
+
|
1784
|
+
struct llama_sampler * llama_sampler_init_logit_bias(
|
1785
|
+
int32_t n_vocab,
|
1786
|
+
int32_t n_logit_bias,
|
1787
|
+
const llama_logit_bias * logit_bias) {
|
1788
|
+
return new llama_sampler {
|
1789
|
+
/* .iface = */ &llama_sampler_logit_bias_i,
|
1790
|
+
/* .ctx = */ new llama_sampler_logit_bias {
|
1791
|
+
/* .n_vocab = */ n_vocab,
|
1792
|
+
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
1793
|
+
/* .to_search = */ {},
|
1794
|
+
},
|
1795
|
+
};
|
1796
|
+
}
|
1797
|
+
|
1798
|
+
// utils
|
681
1799
|
|
682
|
-
|
683
|
-
|
1800
|
+
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
1801
|
+
if (smpl->iface == &llama_sampler_dist_i) {
|
1802
|
+
return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
|
684
1803
|
}
|
685
|
-
return X;
|
686
|
-
}
|
687
1804
|
|
688
|
-
|
689
|
-
|
1805
|
+
if (smpl->iface == &llama_sampler_mirostat_i) {
|
1806
|
+
return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
|
1807
|
+
}
|
690
1808
|
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
});
|
1809
|
+
if (smpl->iface == &llama_sampler_mirostat_v2_i) {
|
1810
|
+
return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
|
1811
|
+
}
|
695
1812
|
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
1813
|
+
if (smpl->iface == &llama_sampler_chain_i) {
|
1814
|
+
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
1815
|
+
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
1816
|
+
const uint32_t seed = llama_sampler_get_seed(*it);
|
1817
|
+
if (seed != LLAMA_DEFAULT_SEED) {
|
1818
|
+
return seed;
|
1819
|
+
}
|
1820
|
+
}
|
700
1821
|
}
|
701
|
-
|
1822
|
+
|
1823
|
+
return LLAMA_DEFAULT_SEED;
|
702
1824
|
}
|
703
1825
|
|
704
|
-
|
705
|
-
LM_GGML_ASSERT(smpl);
|
1826
|
+
// perf
|
706
1827
|
|
707
|
-
|
708
|
-
|
1828
|
+
struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
|
1829
|
+
struct llama_perf_sampler_data data = {};
|
709
1830
|
|
710
|
-
|
711
|
-
|
712
|
-
for (size_t i = 0; i < candidates->size; ++i) {
|
713
|
-
probs.push_back(candidates->data[i].p);
|
1831
|
+
if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
|
1832
|
+
LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
|
714
1833
|
}
|
715
1834
|
|
716
|
-
|
717
|
-
int idx = dist(rng);
|
1835
|
+
const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
|
718
1836
|
|
719
|
-
|
1837
|
+
data.t_sample_ms = 1e-3 * ctx->t_sample_us;
|
1838
|
+
data.n_sample = std::max(0, ctx->n_sample);
|
720
1839
|
|
721
|
-
|
722
|
-
|
1840
|
+
return data;
|
1841
|
+
}
|
723
1842
|
|
724
|
-
|
1843
|
+
void llama_perf_sampler_print(const struct llama_sampler * chain) {
|
1844
|
+
const auto data = llama_perf_sampler(chain);
|
1845
|
+
|
1846
|
+
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
1847
|
+
__func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
|
725
1848
|
}
|
726
1849
|
|
727
|
-
|
728
|
-
|
1850
|
+
void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
1851
|
+
if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
|
1852
|
+
LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
|
1853
|
+
}
|
1854
|
+
|
1855
|
+
auto * ctx = (struct llama_sampler_chain *) chain->ctx;
|
1856
|
+
|
1857
|
+
ctx->t_sample_us = ctx->n_sample = 0;
|
729
1858
|
}
|