cui-llama.rn 1.1.2 → 1.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,56 @@
 #include "llama-sampling.h"
 
+#include "llama-vocab.h"
+#include "llama-grammar.h"
+
+#include <cassert>
 #include <algorithm>
 #include <cstring>
 #include <ctime>
 #include <cfloat>
 #include <numeric>
+#include <random>
 #include <unordered_map>
 
+static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
+#if 1
+    probs.resize(cur_p->size);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        probs[i] = cur_p->data[i].p;
+    }
+
+    std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
+#else
+    // avoid the copy with a custom iterator
+    #pragma GCC diagnostic push
+    #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+
+    struct probs_iterator {
+        typedef std::input_iterator_tag iterator_category;
+        typedef float value_type;
+        typedef float * pointer;
+        typedef float & reference;
+        typedef size_t difference_type;
+
+        const llama_token_data_array * data;
+        size_t i;
+
+        bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; }
+        bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; }
+        float operator*() const { return data->data[i].p; }
+        probs_iterator & operator++() { ++i; return *this; }
+        probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; }
+    };
+    #pragma GCC diagnostic pop
+
+    std::discrete_distribution<size_t> dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size});
+
+    LM_GGML_UNUSED(probs);
+#endif
+
+    return dist(rng);
+}
+
 static void llama_log_softmax(float * array, size_t size) {
     float max_l = *std::max_element(array, array + size);
     float sum = 0.f;
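The new llama_sample_dist helper draws an index from the already-normalized candidate probabilities with std::discrete_distribution, so index i comes back with probability probs[i]. A minimal standalone sketch of the same mechanism, using plain floats instead of llama_token_data (toy values, hypothetical names):

    #include <cstdio>
    #include <random>
    #include <vector>

    int main() {
        // toy probabilities, already normalized the way cur_p->data[i].p is after softmax
        std::vector<float> probs = {0.5f, 0.3f, 0.2f};

        std::mt19937 rng(42);
        std::discrete_distribution<size_t> dist(probs.begin(), probs.end());

        // index i is returned with probability probs[i]
        printf("sampled index: %zu\n", dist(rng));
        return 0;
    }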
@@ -21,65 +65,50 @@ static void llama_log_softmax(float * array, size_t size) {
     }
 }
 
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = time(NULL);
-    }
-
-    smpl->rng.seed(seed);
-}
-
-void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    LM_GGML_ASSERT(candidates->size > 0);
-
-    const int64_t t_start_sample_us = lm_ggml_time_us();
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+    LM_GGML_ASSERT(cur_p->size > 0);
 
     // Sort the logits in descending order
-    if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+    if (!cur_p->sorted) {
+        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
            return a.logit > b.logit;
        });
-        candidates->sorted = true;
+        cur_p->sorted = true;
    }
 
-    float max_l = candidates->data[0].logit;
+    float max_l = cur_p->data[0].logit;
    float cum_sum = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float p = expf(candidates->data[i].logit - max_l);
-        candidates->data[i].p = p;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
        cum_sum += p;
    }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum;
-    }
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= cum_sum;
    }
 }
 
-void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
-    // if (k >= (int32_t)candidates->size) {
+    // if (k >= (int32_t)cur_p->size) {
    // return;
    // }
 
-    const int64_t t_start_sample_us = lm_ggml_time_us();
-
    if (k <= 0) {
-        k = candidates->size;
+        k = cur_p->size;
    }
 
-    k = std::max(k, (int) min_keep);
-    k = std::min(k, (int) candidates->size);
+    k = std::min(k, (int) cur_p->size);
 
    // Sort scores in descending order
-    if (!candidates->sorted) {
+    if (!cur_p->sorted) {
        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
            return a.logit > b.logit;
        };
        if (k <= 128) {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+            std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
        } else {
            constexpr int nbuckets = 128;
            constexpr float bucket_low = -10.0f;
@@ -87,11 +116,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
            constexpr float bucket_inter = -bucket_low * bucket_scale;
 
-            std::vector<int> bucket_idx(candidates->size);
+            std::vector<int> bucket_idx(cur_p->size);
            std::vector<int> histo(nbuckets, 0);
 
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                const float val = candidates->data[i].logit;
+            for (int i = 0; i < (int)cur_p->size; ++i) {
+                const float val = cur_p->data[i].logit;
                int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
                ib = std::max(0, std::min(nbuckets-1, ib));
                bucket_idx[i] = ib;
@@ -101,20 +130,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
            int ib = nbuckets - 1;
            for ( ; ib >= 0; --ib) {
                nhave += histo[ib];
-                if (nhave >= k) break;
+                if (nhave >= k) {
+                    break;
+                }
            }
            std::vector<llama_token_data> tmp_tokens(nhave);
-            auto ptr = tmp_tokens.data();
+            auto * ptr = tmp_tokens.data();
            std::vector<llama_token_data*> bucket_ptrs;
            bucket_ptrs.reserve(nbuckets - ib);
            for (int j = nbuckets - 1; j >= ib; --j) {
                bucket_ptrs.push_back(ptr);
                ptr += histo[j];
            }
-            for (int i = 0; i < (int)candidates->size; ++i) {
+            for (int i = 0; i < (int)cur_p->size; ++i) {
                int j = bucket_idx[i];
                if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+                    *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
                }
            }
 
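For k > 128 the code above avoids a full sort: it histograms the logits into 128 coarse buckets over [-10, 10], walks the buckets from the top until they cover at least k entries, and only sorts the tokens that landed in those top buckets. A hedged standalone sketch of that idea on plain floats (all names and values illustrative):

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    int main() {
        const std::vector<float> logits = {0.3f, -2.0f, 5.1f, 1.7f, -0.4f, 3.3f};
        const int k = 3;

        constexpr int   nbuckets = 128;
        constexpr float lo = -10.0f, hi = 10.0f;
        constexpr float scale = nbuckets / (hi - lo);

        // histogram the values into coarse buckets
        std::vector<int> histo(nbuckets, 0);
        for (float v : logits) {
            int ib = std::max(0, std::min(nbuckets - 1, int((v - lo) * scale)));
            histo[ib]++;
        }

        // walk buckets from the top until they cover at least k values
        int nhave = 0, ib = nbuckets - 1;
        for (; ib >= 0; --ib) {
            nhave += histo[ib];
            if (nhave >= k) break;
        }

        // only the values in the top buckets need sorting
        std::vector<float> top;
        for (float v : logits) {
            if (int((v - lo) * scale) >= ib) top.push_back(v);
        }
        std::sort(top.begin(), top.end(), std::greater<float>());
        top.resize(k);

        for (float v : top) printf("%.1f\n", v); // 5.1 3.3 1.7
        return 0;
    }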
@@ -127,69 +158,554 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
            }
            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
 
-            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+            std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
 
        }
-        candidates->sorted = true;
+        cur_p->sorted = true;
    }
-    candidates->size = k;
+    cur_p->size = k;
+}
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+// llama_sampler API
+
+const char * llama_sampler_name(const struct llama_sampler * smpl) {
+    if (!smpl->iface) {
+        return "(null)";
    }
+
+    return smpl->iface->name(smpl);
 }
 
-void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p >= 1.0f) {
+void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+    if (smpl->iface->accept) {
+        smpl->iface->accept(smpl, token);
+    }
+}
+
+void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+    LM_GGML_ASSERT(smpl->iface->apply);
+    smpl->iface->apply(smpl, cur_p);
+}
+
+void llama_sampler_reset(struct llama_sampler * smpl) {
+    if (smpl->iface->reset) {
+        smpl->iface->reset(smpl);
+    }
+}
+
+struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+    if (smpl->iface->clone) {
+        return smpl->iface->clone(smpl);
+    }
+
+    if (smpl->ctx == nullptr) {
+        return new llama_sampler {
+            /* .iface = */ smpl->iface,
+            /* .ctx = */ nullptr,
+        };
+    }
+
+    LM_GGML_ABORT("the sampler does not support cloning");
+}
+
+void llama_sampler_free(struct llama_sampler * smpl) {
+    if (smpl == nullptr) {
        return;
    }
 
-    llama_sample_softmax_impl(smpl, candidates);
+    if (smpl->iface->free) {
+        smpl->iface->free(smpl);
+    }
 
-    const int64_t t_start_sample_us = lm_ggml_time_us();
+    delete smpl;
+}
+
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    // TODO: do not allocate each time
+    std::vector<llama_token_data> cur(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    return cur_p.data[cur_p.selected].id;
+}
+
+// sampler chain
+
+static struct llama_sampler_i llama_sampler_chain_i = {
+    /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
+    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+        time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_accept(smpl, token);
+        }
+
+        chain->n_sample++;
+    },
+    /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+        time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_apply(smpl, cur_p);
+        }
+    },
+    /* .reset = */ [](struct llama_sampler * smpl) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_reset(smpl);
+        }
+
+        chain->t_sample_us = 0;
+        chain->n_sample = 0;
+    },
+    /* .clone = */ [](const struct llama_sampler * smpl) {
+        const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
+
+        auto * result = llama_sampler_chain_init(chain_src->params);
+
+        for (auto * smpl : chain_src->samplers) {
+            llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+        }
+
+        return result;
+    },
+    /* .free = */ [](struct llama_sampler * smpl) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_free(smpl);
+        }
+
+        delete chain;
+    },
+};
+
+struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_chain_i,
+        /* .ctx = */ new llama_sampler_chain {
+            /* .params = */ params,
+            /* .samplers = */ {},
+            /* .t_sample_us = */ 0,
+            /* .n_sample = */ 0,
+        },
+    };
+}
+
+void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+    p->samplers.push_back(smpl);
+}
+
+llama_sampler_timings llama_sampler_chain_timings(struct llama_sampler * chain) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+    struct llama_sampler_timings result = {
+        p -> t_sample_us,
+        p -> n_sample
+    };
+    return result;
+}
+
+struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+    const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+    if (i < 0 || i >= (int32_t) p->samplers.size()) {
+        return nullptr;
+    }
+
+    return p->samplers[i];
+}
+
+int llama_sampler_chain_n(const struct llama_sampler * chain) {
+    const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+    return p->samplers.size();
+}
+
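Taken together, the interface and chain functions above replace the old per-sampler `*_impl` entry points. A hedged usage sketch, assuming a `llama_context * lctx` has already been created elsewhere; `llama_sampler_chain_default_params()` and `LLAMA_DEFAULT_SEED` come from the public header, not from this file:

    // build a chain: top-k -> top-p -> temperature -> seeded final draw
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.9f, 1));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // the chain applies each sampler in order to the candidate array that
    // llama_sampler_sample() builds from the logits at position idx (-1 = last)
    llama_token tok = llama_sampler_sample(chain, lctx, -1);
    llama_sampler_accept(chain, tok);

    llama_sampler_free(chain); // also frees the samplers added to it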
+//
+// samplers
+//
+
+// greedy
+
+static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
+    return "greedy";
+}
+
+static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+    cur_p->selected = 0;
+    for (size_t i = 1; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
+            cur_p->selected = i;
+        }
+    }
+}
+
+static struct llama_sampler_i llama_sampler_greedy_i = {
+    /* .name = */ llama_sampler_greedy_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_greedy_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ nullptr,
+    /* .free = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_greedy() {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_greedy_i,
+        /* .ctx = */ nullptr,
+    };
+}
+
+// dist
+
+struct llama_sampler_dist {
+    const uint32_t seed;
+
+    std::mt19937 rng;
+
+    std::vector<float> probs; // work array
+};
+
+static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
+    return "dist";
+}
+
+static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+}
+
+static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
+    auto * result = llama_sampler_init_dist(ctx->seed);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_dist *) result->ctx;
+
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    ctx->rng = std::mt19937(ctx->seed);
+}
+
+static void llama_sampler_dist_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_dist *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_dist_i = {
+    /* .name = */ llama_sampler_dist_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_dist_apply,
+    /* .reset = */ llama_sampler_dist_reset,
+    /* .clone = */ llama_sampler_dist_clone,
+    /* .free = */ llama_sampler_dist_free,
+};
+
+struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_dist_i,
+        /* .ctx = */ new llama_sampler_dist {
+            /* .seed = */ seed,
+            /* .rng = */ std::mt19937(seed),
+            /* .probs = */ {},
+        },
+    };
+}
+
+// softmax
+
+static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
+    return "softmax";
+}
+
+static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+    llama_sampler_softmax_impl(cur_p);
+}
+
+static struct llama_sampler_i llama_sampler_softmax_i = {
+    /* .name = */ llama_sampler_softmax_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_softmax_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ nullptr,
+    /* .free = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_softmax() {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_softmax_i,
+        /* .ctx = */ nullptr,
+    };
+}
+
+// top-k
+
+struct llama_sampler_top_k {
+    const int32_t k;
+};
+
+static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
+    return "top-k";
+}
+
+static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+    llama_sampler_top_k_impl(cur_p, ctx->k);
+}
+
+static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
+    return llama_sampler_init_top_k(ctx->k);
+}
+
+static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_k *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_k_i = {
+    /* .name = */ llama_sampler_top_k_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_top_k_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ llama_sampler_top_k_clone,
+    /* .free = */ llama_sampler_top_k_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_k_i,
+        /* .ctx = */ new llama_sampler_top_k {
+            /* .k = */ k,
+        },
+    };
+}
+
+// top-p
+
+struct llama_sampler_top_p {
+    const float p;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
+    return "top-p";
+}
+
+static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+
+    if (ctx->p >= 1.0f) {
+        return;
+    }
+
+    llama_sampler_softmax_impl(cur_p);
 
    // Compute the cumulative probabilities
    float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
+    size_t last_idx = cur_p->size;
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        cum_sum += candidates->data[i].p;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cum_sum += cur_p->data[i].p;
 
        // Check if the running sum is at least p or if we have kept at least min_keep tokens
        // we set the last index to i+1 to indicate that the current iterate should be included in the set
-        if (cum_sum >= p && i + 1 >= min_keep) {
+        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
            last_idx = i + 1;
            break;
        }
    }
 
    // Resize the output vector to keep only the top-p tokens
-    candidates->size = last_idx;
+    cur_p->size = last_idx;
+}
+
+static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
+    return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
+}
+
+static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_p *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_p_i = {
+    /* .name = */ llama_sampler_top_p_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_top_p_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ llama_sampler_top_p_clone,
+    /* .free = */ llama_sampler_top_p_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_p_i,
+        /* .ctx = */ new llama_sampler_top_p {
+            /* .p = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
+
+// min-p
+
+struct llama_sampler_min_p {
+    const float p;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
+    return "min-p";
+}
+
+static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
+
+    if (ctx->p <= 0.0f || !cur_p->size) {
+        return;
+    }
+
+    bool min_p_applied = false;
+
+    // if the cur_p aren't sorted, try the unsorted implementation first
+    if (!cur_p->sorted) {
+        std::vector<llama_token_data> filtered_tokens;
+
+        float max_logit = -FLT_MAX;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            max_logit = std::max(max_logit, cur_p->data[i].logit);
+        }
+        const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
+
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit >= min_logit) {
+                filtered_tokens.push_back(cur_p->data[i]);
+            }
+        }
+
+        // if we have enough values the operation was a success
+        if (filtered_tokens.size() >= ctx->min_keep) {
+            memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+            cur_p->size = filtered_tokens.size();
+            min_p_applied = true;
+        }
+    }
+
+    // if the cur_p are sorted or the unsorted implementation failed, use this implementation
+    if (!min_p_applied) {
+        // Sort the logits in descending order
+        if (!cur_p->sorted) {
+            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
+                return a.logit > b.logit;
+            });
+            cur_p->sorted = true;
+        }
+
+        const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
+        size_t i = 1; // first token always matches
+
+        for (; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
+                break; // prob too small
+            }
+        }
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+        // Resize the output vector to keep only the matching tokens
+        cur_p->size = i;
    }
 }
 
-void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, size_t min_keep, std::mt19937 & rng) {
-    if(xtc_threshold <= 0.0f || !candidates-> size) {
+static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
+    return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
+}
+
+static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_min_p *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_min_p_i = {
+    /* .name = */ llama_sampler_min_p_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_min_p_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ llama_sampler_min_p_clone,
+    /* .free = */ llama_sampler_min_p_free,
+};
+
+struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_min_p_i,
+        /* .ctx = */ new llama_sampler_min_p {
+            /* .p = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
+
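The min-p filter keeps tokens with p_i >= p * p_max. Because softmax probabilities are monotone in the logits and share one normalizer, that condition is equivalent to logit_i >= logit_max + log(p), which is why the code compares raw logits against `max_logit + logf(ctx->p)` without computing a softmax first. A quick check of the identity (toy values, hypothetical names):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float l_max = 3.2f, l_i = 1.1f, p = 0.05f;

        // both sides differ only by the shared softmax normalizer, so the
        // comparison can be done on raw logits
        bool on_probs  = std::exp(l_i) >= p * std::exp(l_max);
        bool on_logits = l_i >= l_max + std::log(p);

        printf("%d %d\n", on_probs, on_logits); // always equal
        return 0;
    }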
+// xtc
+
+struct llama_sampler_xtc {
+    const uint32_t seed;
+    std::mt19937 rng;
+    const float xtc_p;
+    const float xtc_t;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_xtc_name(const struct llama_sampler * /* smpl */) {
+    return "xtc";
+}
+
+static void llama_sampler_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
+
+    size_t min_keep = ctx -> min_keep;
+    std::mt19937 rng = ctx -> rng;
+
+    float xtc_threshold = ctx -> xtc_t;
+    float xtc_probability = ctx -> xtc_p;
+
+
+    if(xtc_threshold <= 0.0f || !cur_p-> size) {
        return;
    }
 
    bool xtc_applied = false;
    const int64_t t_start_sample_us = lm_ggml_time_us();
-    llama_sample_softmax(nullptr, candidates);
+    llama_sampler_softmax_impl(cur_p);
 
    // unsorted iteration
-    if (!candidates->sorted) {
+    if (!cur_p->sorted) {
        std::vector<llama_token_data> top_tokens, low_tokens;
 
        // split candidates into two arrays for low and high tokens
-        for (size_t i = 0; i < candidates->size; ++i) {
-            if (candidates->data[i].logit >= xtc_threshold) {
-                top_tokens.push_back(candidates->data[i]);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit >= xtc_threshold) {
+                top_tokens.push_back(cur_p->data[i]);
            } else {
-                low_tokens.push_back(candidates-> data[i]);
+                low_tokens.push_back(cur_p-> data[i]);
            }
        }
        // if there is only one or no top_tokens, do not truncate
@@ -212,28 +728,28 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
        }
    }
    if(low_tokens.size() >= min_keep) {
-        memcpy(candidates->data, low_tokens.data(), low_tokens.size()*sizeof(llama_token_data));
-        candidates->size = low_tokens.size();
+        memcpy(cur_p->data, low_tokens.data(), low_tokens.size()*sizeof(llama_token_data));
+        cur_p->size = low_tokens.size();
        xtc_applied = true;
    }
    }
    // sorted iteration
-
+
    if (!xtc_applied) {
        // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+        if (!cur_p->sorted) {
+            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
                return a.logit > b.logit;
            });
-            candidates->sorted = true;
+            cur_p->sorted = true;
        }
 
        // find last token over threshold
 
        size_t last_index = 0;
 
-        for (; last_index < candidates -> size; ++last_index) {
-            if(candidates -> data[last_index].p < xtc_threshold) {
+        for (; last_index < cur_p -> size; ++last_index) {
+            if(cur_p -> data[last_index].p < xtc_threshold) {
                break;
            }
        }
@@ -244,102 +760,80 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
        }
        last_index--;
        // items beyond safe index will be ignored
-        size_t safe_index = candidates -> size;
+        size_t safe_index = cur_p -> size;
 
        // remove tokens until last threshold item
        std::uniform_real_distribution<float> random_float(0.0 , 1.0);
        for (size_t i = 0; i < last_index; i++) {
            if(random_float(rng) < xtc_probability) {
-                std::swap(candidates-> data[i], candidates->data[safe_index - 1]);
+                std::swap(cur_p-> data[i], cur_p->data[safe_index - 1]);
                safe_index--;
-                if (candidates-> sorted) {
-                    candidates -> sorted = false;
+                if (cur_p-> sorted) {
+                    cur_p -> sorted = false;
                }
            }
        }
-        candidates -> size = safe_index;
-    }
-
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+        cur_p -> size = safe_index;
    }
 }
 
-void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p <= 0.0f || !candidates->size) {
-        return;
-    }
-
-    const int64_t t_start_sample_us = lm_ggml_time_us();
-
-    bool min_p_applied = false;
-
-    // if the candidates aren't sorted, try the unsorted implementation first
-    if (!candidates->sorted) {
-        std::vector<llama_token_data> filtered_tokens;
-
-        float max_logit = -FLT_MAX;
-        for (size_t i = 0; i < candidates->size; ++i) {
-            max_logit = std::max(max_logit, candidates->data[i].logit);
-        }
-        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
-
-        for (size_t i = 0; i < candidates->size; ++i) {
-            if (candidates->data[i].logit >= min_logit) {
-                filtered_tokens.push_back(candidates->data[i]);
-            }
-        }
-
-        // if we have enough values the operation was a success
-        if (filtered_tokens.size() >= min_keep) {
-            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
-            candidates->size = filtered_tokens.size();
-            min_p_applied = true;
-        }
-    }
+static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
+    return llama_sampler_init_xtc(ctx->xtc_p, ctx->xtc_t, ctx->min_keep, ctx->seed);
+}
 
-    // if the candidates are sorted or the unsorted implementation failed, use this implementation
-    if (!min_p_applied) {
-        // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
-            candidates->sorted = true;
-        }
+static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
+    delete (const llama_sampler_xtc *) smpl->ctx;
+}
 
-        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
-        size_t i = 1; // first token always matches
+static struct llama_sampler_i llama_sampler_xtc_i = {
+    /* .name = */ llama_sampler_xtc_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_xtc_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ llama_sampler_xtc_clone,
+    /* .free = */ llama_sampler_xtc_free,
+};
+
+struct llama_sampler * llama_sampler_init_xtc(float xtc_p, float xtc_t, size_t min_keep, uint32_t seed) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_xtc_i,
+        /* .ctx = */ new llama_sampler_xtc {
+            /* .seed = */ seed,
+            /* .rng = */ std::mt19937(seed),
+            /* .xtc_p = */ xtc_p,
+            /* .xtc_t = */ xtc_t,
+            /* .min_keep = */ min_keep
+        },
+    };
+}
 
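In this fork's XTC ("exclude top choices") sampler, every above-threshold candidate except the last one is dropped with probability xtc_p, so sampling occasionally falls through to the "second choices". A hedged standalone sketch of that per-token coin flip on toy values (all names illustrative; the real code also honors min_keep):

    #include <cstdio>
    #include <random>
    #include <vector>

    int main() {
        // sorted candidate probabilities (toy values) and XTC parameters
        std::vector<float> probs = {0.40f, 0.25f, 0.15f, 0.12f, 0.08f};
        const float xtc_t = 0.10f; // threshold
        const float xtc_p = 1.0f;  // removal probability (1.0 so the demo is deterministic)

        // count the candidates above the threshold
        size_t above = 0;
        while (above < probs.size() && probs[above] >= xtc_t) {
            above++;
        }

        std::mt19937 rng(123);
        std::uniform_real_distribution<float> coin(0.0f, 1.0f);

        // drop each above-threshold token except the last with probability xtc_p
        std::vector<float> kept;
        for (size_t i = 0; i < probs.size(); ++i) {
            if (i + 1 < above && coin(rng) < xtc_p) {
                continue;
            }
            kept.push_back(probs[i]);
        }

        for (float p : kept) printf("%.2f ", p); // 0.12 0.08
        return 0;
    }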
-    for (; i < candidates->size; ++i) {
-        if (candidates->data[i].logit < min_logit && i >= min_keep) {
-            break; // prob too small
-        }
-    }
+// tail-free
 
-    // Resize the output vector to keep only the matching tokens
-    candidates->size = i;
-    }
+struct llama_sampler_tail_free {
+    const float z;
+    const size_t min_keep;
+};
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    }
+static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
+    return "tail-free";
 }
 
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
-    if (z >= 1.0f || candidates->size <= 2) {
+static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
+
+    if (ctx->z >= 1.0f || cur_p->size <= 2) {
        return;
    }
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-    const int64_t t_start_sample_us = lm_ggml_time_us();
+    llama_sampler_softmax_impl(cur_p);
 
    // Compute the first and second derivatives
-    std::vector<float> first_derivatives(candidates->size - 1);
-    std::vector<float> second_derivatives(candidates->size - 2);
+    std::vector<float> first_derivatives(cur_p->size - 1);
+    std::vector<float> second_derivatives(cur_p->size - 2);
 
    for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
+        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
    }
    for (size_t i = 0; i < second_derivatives.size(); ++i) {
        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
@@ -366,51 +860,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
    }
 
    float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
+    size_t last_idx = cur_p->size;
    for (size_t i = 0; i < second_derivatives.size(); ++i) {
        cum_sum += second_derivatives[i];
 
        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > z && i >= min_keep) {
+        if (cum_sum > ctx->z && i >= ctx->min_keep) {
            last_idx = i;
            break;
        }
    }
 
    // Resize the output vector to keep only the tokens above the tail location
-    candidates->size = last_idx;
+    cur_p->size = last_idx;
+}
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    }
+static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
+    return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
+}
+
+static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_tail_free *) smpl->ctx;
 }
 
-void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+static struct llama_sampler_i llama_sampler_tail_free_i = {
+    /* .name = */ llama_sampler_tail_free_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_tail_free_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ llama_sampler_tail_free_clone,
+    /* .free = */ llama_sampler_tail_free_free,
+};
+
+struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_tail_free_i,
+        /* .ctx = */ new llama_sampler_tail_free {
+            /* .z = */ z,
+            /*. min_keep = */ min_keep,
+        },
+    };
+}
+
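Tail-free sampling locates the "knee" of the sorted probability curve by taking discrete second differences, normalizing their absolute values into a distribution over positions, and cutting where their cumulative mass exceeds z. A toy illustration of that computation (values invented; the real sampler also honors min_keep):

    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> p = {0.5f, 0.25f, 0.12f, 0.08f, 0.05f};
        const float z = 0.95f;

        std::vector<float> d1(p.size() - 1), d2(p.size() - 2);
        for (size_t i = 0; i < d1.size(); ++i) d1[i] = p[i] - p[i + 1];
        for (size_t i = 0; i < d2.size(); ++i) d2[i] = d1[i] - d1[i + 1];

        // normalize |d2| so it can be treated as a distribution over positions
        float sum = 0.0f;
        for (float & v : d2) { v = v < 0 ? -v : v; sum += v; }
        for (float & v : d2) v /= sum;

        // keep tokens until the cumulative second-derivative mass exceeds z
        size_t last_idx = p.size();
        float cum = 0.0f;
        for (size_t i = 0; i < d2.size(); ++i) {
            cum += d2[i];
            if (cum > z) { last_idx = i; break; }
        }

        printf("cut at index %zu\n", last_idx);
        return 0;
    }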
+// typical
+
+struct llama_sampler_typical {
+    const float p;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
+    return "typical";
+}
+
+static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+
    // Reference implementation:
    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
-    if (p >= 1.0f) {
+    if (ctx->p >= 1.0f) {
        return;
    }
 
    // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = lm_ggml_time_us();
+    llama_sampler_softmax_impl(cur_p);
 
    float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
    }
 
    // Compute the absolute difference between negative log probability and entropy for each candidate
    std::vector<float> shifted_scores;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
        shifted_scores.push_back(shifted_score);
    }
 
    // Sort tokens based on the shifted_scores and their corresponding indices
-    std::vector<size_t> indices(candidates->size);
+    std::vector<size_t> indices(cur_p->size);
    std::iota(indices.begin(), indices.end(), 0);
 
    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
@@ -423,197 +952,250 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
 
    for (size_t i = 0; i < indices.size(); ++i) {
        size_t idx = indices[i];
-        cum_sum += candidates->data[idx].p;
+        cum_sum += cur_p->data[idx].p;
 
        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
-        if (cum_sum > p && i >= min_keep - 1) {
+        if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
            last_idx = i + 1;
            break;
        }
    }
 
    // Resize the output vector to keep only the locally typical tokens
-    std::vector<llama_token_data> new_candidates;
+    std::vector<llama_token_data> cur_p_new;
    for (size_t i = 0; i < last_idx; ++i) {
        size_t idx = indices[i];
-        new_candidates.push_back(candidates->data[idx]);
+        cur_p_new.push_back(cur_p->data[idx]);
    }
 
-    // Replace the data in candidates with the new_candidates data
-    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
-    candidates->size = new_candidates.size();
-    candidates->sorted = false;
+    // Replace the data in cur_p with the cur_p_new data
+    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
+    cur_p->size = cur_p_new.size();
+    cur_p->sorted = false;
+}
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    }
+static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
+    return llama_sampler_init_typical(ctx->p, ctx->min_keep);
 }
 
-void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
-    const int64_t t_start_sample_us = lm_ggml_time_us();
+static void llama_sampler_typical_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_typical *) smpl->ctx;
+}
 
-    // no need to do anything if there is only one (or zero) candidates
-    if(candidates->size <= 1) {
-        return;
-    }
+static struct llama_sampler_i llama_sampler_typical_i = {
+    /* .name = */ llama_sampler_typical_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_typical_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ llama_sampler_typical_clone,
+    /* .free = */ llama_sampler_typical_free,
+};
+
+struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_typical_i,
+        /* .ctx = */ new llama_sampler_typical {
+            /* .p = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
 
-    // Calculate maximum possible entropy
-    float max_entropy = -logf(1.0f / candidates->size);
+// temp
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+struct llama_sampler_temp {
+    const float temp;
+};
 
-    // Calculate entropy of the softmax probabilities
-    float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float prob = candidates->data[i].p;
-        if (prob > 0.0f) { // Ensure no log(0)
-            entropy -= prob * logf(prob);
-        }
-    }
+static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
+    return "temp";
+}
 
-    // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
-    float normalized_entropy = entropy / max_entropy;
+static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_temp *) smpl->ctx;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= ctx->temp;
+    }
+}
 
-    // Map the normalized entropy to the desired temperature range using the power function
-    float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
+static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
+    return llama_sampler_init_temp(ctx->temp);
+}
 
-#ifdef DEBUG
-    LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
-    LLAMA_LOG_INFO("Entropy: %f\n", entropy);
-    LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
-    LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
-    LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
-    LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
-#endif
+static void llama_sampler_temp_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_temp *) smpl->ctx;
+}
 
-    // Apply the dynamically calculated temperature scaling
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].logit /= dyn_temp;
-    }
+static struct llama_sampler_i llama_sampler_temp_i = {
+    /* .name = */ llama_sampler_temp_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_temp_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ llama_sampler_temp_clone,
+    /* .free = */ llama_sampler_temp_free,
+};
+
+struct llama_sampler * llama_sampler_init_temp(float temp) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_temp_i,
+        /* .ctx = */ new llama_sampler_temp {
+            /*.temp = */ temp,
+        },
+    };
+}
 
-    // Re-compute softmax probabilities after scaling logits with dynamic temperature
-    double max_l_double = candidates->data[0].logit;
-    double cum_sum_double = 0.0;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        double p = exp(candidates->data[i].logit - max_l_double);
-        candidates->data[i].p = p; // Store the scaled probability
-        cum_sum_double += p;
-    }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
-    }
+// temp-ext
 
-#ifdef DEBUG
-    // Print the updated top 25 probabilities after temperature scaling
-    LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
-    for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
-        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
-    }
-#endif
+struct llama_sampler_temp_ext {
+    const float temp;
+    const float delta;
+    const float exponent;
+};
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    }
+static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
+    return "temp-ext";
 }
 
-void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
-    const int64_t t_start_sample_us = lm_ggml_time_us();
+static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
+    if (ctx->delta > 0) {
+        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
+        const float max_temp = ctx->temp + ctx->delta;
+        float exponent_val = ctx->exponent;
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].logit /= temp;
-    }
+        // no need to do anything if there is only one (or zero) candidates
+        if (cur_p->size <= 1) {
+            return;
+        }
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    }
-}
+        // Calculate maximum possible entropy
+        float max_entropy = -logf(1.0f / cur_p->size);
 
-void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
-        llama_token_data_array * candidates,
-        const llama_token * last_tokens,
-        size_t penalty_last_n,
-        float penalty_repeat,
-        float penalty_freq,
-        float penalty_present) {
-    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-        return;
-    }
+        llama_sampler_softmax_impl(cur_p);
 
-    const int64_t t_start_sample_us = lm_ggml_time_us();
+        // Calculate entropy of the softmax probabilities
+        float entropy = 0.0f;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            float prob = cur_p->data[i].p;
+            if (prob > 0.0f) { // Ensure no log(0)
+                entropy -= prob * logf(prob);
+            }
+        }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    std::unordered_map<llama_token, int> token_count;
-    for (size_t i = 0; i < penalty_last_n; ++i) {
-        token_count[last_tokens[i]]++;
-    }
+        // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
+        float normalized_entropy = entropy / max_entropy;
 
-    // Apply frequency and presence penalties to the candidates
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const auto token_iter = token_count.find(candidates->data[i].id);
-        if (token_iter == token_count.end()) {
-            continue;
+        // Map the normalized entropy to the desired temperature range using the power function
+        float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
+
+#ifdef DEBUG
+        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
+        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
+        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
+        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
+        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
+        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
+#endif
+
+        // Apply the dynamically calculated temperature scaling
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].logit /= dyn_temp;
        }
 
-        const int count = token_iter->second;
+        // Re-compute softmax probabilities after scaling logits with dynamic temperature
+        const double max_l_double = cur_p->data[0].logit;
 
-        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
-        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
-        if (candidates->data[i].logit <= 0) {
-            candidates->data[i].logit *= penalty_repeat;
-        } else {
-            candidates->data[i].logit /= penalty_repeat;
+        double cum_sum_double = 0.0;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            double p = exp(cur_p->data[i].logit - max_l_double);
+            cur_p->data[i].p = p; // Store the scaled probability
+            cum_sum_double += p;
+        }
+
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
        }
 
-        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
+#ifdef DEBUG
+        // Print the updated top 25 probabilities after temperature scaling
+        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
+        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
+            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
+        }
+#endif
+    } else {
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].logit /= ctx->temp;
+        }
    }
+}
 
-    candidates->sorted = false;
+static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
+    return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
+}
 
-    if (smpl) {
-        smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    }
+static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_temp_ext *) smpl->ctx;
 }
 
-void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
-        float * logits,
-        float * logits_guidance,
-        float scale) {
-    LM_GGML_ASSERT(smpl);
+static struct llama_sampler_i llama_sampler_temp_ext_i = {
+    /* .name = */ llama_sampler_temp_ext_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_temp_ext_apply,
+    /* .reset = */ nullptr,
+    /* .clone = */ llama_sampler_temp_ext_clone,
+    /* .free = */ llama_sampler_temp_ext_free,
+};
+
+struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_temp_ext_i,
+        /* .ctx = */ new llama_sampler_temp_ext {
+            /* .temp = */ temp,
+            /* .delta = */ delta,
+            /* .exponent = */ exponent,
+        },
+    };
+}
 
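When delta > 0, the extended temperature sampler interpolates the effective temperature between temp - delta and temp + delta based on the normalized entropy of the candidate distribution: confident (low-entropy) distributions get a lower temperature, uncertain ones a higher temperature. A small numeric sketch of the mapping (values illustrative only):

    // dyn_temp = min_temp + (max_temp - min_temp) * normalized_entropy^exponent
    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float temp = 0.8f, delta = 0.5f, exponent = 1.0f;
        const float min_temp = std::max(0.0f, temp - delta); // 0.3
        const float max_temp = temp + delta;                 // 1.3

        // entropy here is already normalized to [0, 1]
        for (float norm_entropy : {0.0f, 0.5f, 1.0f}) {
            float dyn_temp = min_temp + (max_temp - min_temp) * std::pow(norm_entropy, exponent);
            printf("normalized entropy %.1f -> dyn_temp %.2f\n", norm_entropy, dyn_temp);
        }
        return 0;
    }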
- const auto t_start_sample_us = lm_ggml_time_us();
586
- const auto n_vocab = smpl->n_vocab;
1164
+ // mirostat
587
1165
 
588
- llama_log_softmax(logits, n_vocab);
589
- llama_log_softmax(logits_guidance, n_vocab);
1166
+ struct llama_sampler_mirostat {
1167
+ const int32_t n_vocab;
590
1168
 
591
- for (int i = 0; i < n_vocab; ++i) {
592
- auto & l = logits[i];
593
- const auto & g = logits_guidance[i];
1169
+ const uint32_t seed;
594
1170
 
595
- l = scale * (l - g) + g;
596
- }
1171
+ const float tau;
1172
+ const float eta;
597
1173
 
598
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
599
- }
1174
+ const int32_t m;
600
1175
 
601
- llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
602
- LM_GGML_ASSERT(smpl);
1176
+ float mu;
603
1177
 
604
- const int32_t n_vocab = float(smpl->n_vocab);
1178
+ std::mt19937 rng;
605
1179
 
606
- int64_t t_start_sample_us = lm_ggml_time_us();
1180
+ std::vector<float> probs;
1181
+ };
1182
+
1183
+ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
1184
+ return "mirostat";
1185
+ }
607
1186
 
608
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
1187
+ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1188
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1189
+
1190
+ llama_sampler_softmax_impl(cur_p);
609
1191
 
610
1192
  // Estimate s_hat using the most probable m tokens
611
1193
  float s_hat = 0.0;
612
1194
  float sum_ti_bi = 0.0;
613
1195
  float sum_ti_sq = 0.0;
614
- for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
1196
+ for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
615
1197
  float t_i = logf(float(i + 2) / float(i + 1));
616
- float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
1198
+ float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
617
1199
  sum_ti_bi += t_i * b_i;
618
1200
  sum_ti_sq += t_i * t_i;
619
1201
  }
@@ -621,109 +1203,532 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
 
    // Compute k from the estimated s_hat and target surprise value
    float epsilon_hat = s_hat - 1;
-    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
+    float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
 
-    // Sample the next word X using top-k sampling
-    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
-    smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = lm_ggml_time_us();
+    llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
+    llama_sampler_softmax_impl(cur_p);
 
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
+    const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+
+    cur_p->selected = idx;
+
+    float observed_surprise = -log2f(cur_p->data[idx].p);
+    float e = observed_surprise - ctx->tau;
 
    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
+    ctx->mu = ctx->mu - ctx->eta * e;
+}
+
+static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
+    auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+        result_ctx->mu = ctx->mu;
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+    ctx->mu = 2.0f*ctx->tau;
+    ctx->rng = std::mt19937(ctx->seed);
+}
+
+static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_mirostat *) smpl->ctx;
+}
 
-    smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-    return X;
+static struct llama_sampler_i llama_sampler_mirostat_i = {
+    /* .name = */ llama_sampler_mirostat_name,
+    /* .accept = */ nullptr,
+    /* .apply = */ llama_sampler_mirostat_apply,
+    /* .reset = */ llama_sampler_mirostat_reset,
+    /* .clone = */ llama_sampler_mirostat_clone,
+    /* .free = */ llama_sampler_mirostat_free,
+};
+
+struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_mirostat_i,
+        /* .ctx = */ new llama_sampler_mirostat {
+            /* .n_vocab = */ n_vocab,
+            /* .seed = */ seed,
+            /* .tau = */ tau,
+            /* .eta = */ eta,
+            /* .m = */ m,
+            /* .mu = */ 2.0f*tau,
+            /* .rng = */ std::mt19937(seed),
+            /* .probs = */ {},
+        },
+    };
 }
 
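Mirostat is a feedback controller: after each draw it measures the observed surprise -log2(p) of the sampled token and nudges mu toward keeping that surprise near the target tau. A toy sketch of the update loop on invented probabilities (hypothetical values, same update rule as above):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float tau = 5.0f, eta = 0.1f;
        float mu = 2.0f * tau; // same initialization as llama_sampler_init_mirostat

        // pretend probabilities of tokens sampled on successive steps
        const float sampled_p[] = {0.02f, 0.10f, 0.001f, 0.05f};

        for (float p : sampled_p) {
            float observed_surprise = -std::log2(p);
            float e = observed_surprise - tau;
            mu = mu - eta * e; // steers future truncation toward surprise ~ tau
            printf("p=%.3f surprise=%.2f mu=%.2f\n", p, observed_surprise, mu);
        }
        return 0;
    }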
646
- llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-     int64_t t_start_sample_us;
-     t_start_sample_us = lm_ggml_time_us();
+ // mirostat v2
+
+ struct llama_sampler_mirostat_v2 {
+     const uint32_t seed;

-     llama_sample_softmax_impl(smpl, candidates);
+     const float tau;
+     const float eta;
+
+     float mu;
+
+     std::mt19937 rng;
+
+     std::vector<float> probs;
+ };
+
+ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
+     return "mirostat-v2";
+ }
+
+ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+
+     llama_sampler_softmax_impl(cur_p);

      // Truncate the words with surprise values greater than mu
-     candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-         return -log2f(candidate.p) > *mu;
+     cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
+         return -log2f(candidate.p) > ctx->mu;
      }));

-     if (candidates->size == 0) {
-         candidates->size = 1;
-     }
-
-     if (smpl) {
-         smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+     if (cur_p->size == 0) {
+         cur_p->size = 1;
      }

      // Normalize the probabilities of the remaining words
-     llama_sample_softmax_impl(smpl, candidates);
+     llama_sampler_softmax_impl(cur_p);

-     // Sample the next word X from the remaining words
-     llama_token X = llama_sample_token_impl(smpl, candidates);
-     t_start_sample_us = lm_ggml_time_us();
+     const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);

-     // Compute error as the difference between observed surprise and target surprise value
-     size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-         return candidate.id == X;
-     }));
-     float observed_surprise = -log2f(candidates->data[X_idx].p);
-     float e = observed_surprise - tau;
+     cur_p->selected = idx;
+
+     float observed_surprise = -log2f(cur_p->data[idx].p);
+     float e = observed_surprise - ctx->tau;

      // Update mu using the learning rate and error
-     *mu = *mu - eta * e;
+     ctx->mu = ctx->mu - ctx->eta * e;
+ }
+
+ static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
+     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+     ctx->mu = 2.0f*ctx->tau;
+     ctx->rng = std::mt19937(ctx->seed);
+ }
+
+ static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
+     const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
+
+     auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
+
+     // copy the state
+     {
+         auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;

-     if (smpl) {
-         smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+         result_ctx->mu  = ctx->mu;
+         result_ctx->rng = ctx->rng;
      }
-     return X;
+
+     return result;
  }

- llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-     const int64_t t_start_sample_us = lm_ggml_time_us();
+ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
+     delete (llama_sampler_mirostat_v2 *) smpl->ctx;
+ }

-     // Find max element
-     auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-         return a.logit < b.logit;
-     });
+ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
+     /* .name   = */ llama_sampler_mirostat_v2_name,
+     /* .accept = */ nullptr,
+     /* .apply  = */ llama_sampler_mirostat_v2_apply,
+     /* .reset  = */ llama_sampler_mirostat_v2_reset,
+     /* .clone  = */ llama_sampler_mirostat_v2_clone,
+     /* .free   = */ llama_sampler_mirostat_v2_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+     return new llama_sampler {
+         /* .iface = */ &llama_sampler_mirostat_v2_i,
+         /* .ctx   = */ new llama_sampler_mirostat_v2 {
+             /* .seed  = */ seed,
+             /* .tau   = */ tau,
+             /* .eta   = */ eta,
+             /* .mu    = */ 2.0f*tau,
+             /* .rng   = */ std::mt19937(seed),
+             /* .probs = */ {},
+         },
+     };
+ }

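A minimal usage sketch for the v2 sampler defined above, assuming the public llama_sampler_apply/llama_sampler_free wrappers that front llama_sampler_i in the same refactor; the seed/tau/eta values are illustrative, and the logits buffer comes from the caller.

#include <vector>
#include "llama.h"

// Sketch: draw one token from raw logits with mirostat v2.
static llama_token sample_mirostat_v2_once(const float * logits, int32_t n_vocab) {
    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        cur.push_back(llama_token_data{ id, logits[id], 0.0f });
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), /* .selected = */ -1, /* .sorted = */ false };

    llama_sampler * smpl = llama_sampler_init_mirostat_v2(/* seed */ 42, /* tau */ 5.0f, /* eta */ 0.1f);
    llama_sampler_apply(smpl, &cur_p);                      // truncate by surprise, renormalize, draw
    const llama_token tok = cur_p.data[cur_p.selected].id;  // .selected is set by the apply above
    llama_sampler_free(smpl);

    return tok;
}

In real use the sampler object should persist across decode steps, since mu carries state from draw to draw; it is created and freed here only to keep the sketch self-contained.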
-     llama_token result = max_iter->id;
-     if (smpl) {
-         smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
-         smpl->n_sample++;
+ // grammar
+
+ struct llama_sampler_grammar {
+     const struct llama_vocab * vocab;
+
+     std::string grammar_str;
+     std::string grammar_root;
+
+     struct llama_grammar * grammar;
+ };
+
+ static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
+     return "grammar";
+ }
+
+ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
+     auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+     if (ctx->grammar) {
+         llama_grammar_accept_impl(*ctx->grammar, token);
      }
+ }
+
+ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+     auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+     if (ctx->grammar) {
+         llama_grammar_apply_impl(*ctx->grammar, cur_p);
+     }
+ }
+
+ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
+     auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+     if (!ctx->grammar) {
+         return;
+     }
+
+     auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+
+     llama_grammar_free_impl(ctx->grammar);
+     ctx->grammar = grammar_new;
+ }
+
+ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
+     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
+
+     auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+
+     // copy the state
+     {
+         auto * result_ctx = (llama_sampler_grammar *) result->ctx;
+
+         if (ctx->grammar) {
+             result_ctx->grammar_str  = ctx->grammar_str;
+             result_ctx->grammar_root = ctx->grammar_root;
+
+             result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
+         }
+     }
+
      return result;
  }

- llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
-     LM_GGML_ASSERT(smpl);
+ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
+     const auto * ctx = (llama_sampler_grammar *) smpl->ctx;

-     const int64_t t_start_sample_us = lm_ggml_time_us();
-     llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+     if (ctx->grammar) {
+         llama_grammar_free_impl(ctx->grammar);
+     }

-     std::vector<float> probs;
-     probs.reserve(candidates->size);
-     for (size_t i = 0; i < candidates->size; ++i) {
-         probs.push_back(candidates->data[i].p);
+     delete ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_grammar_i = {
+     /* .name   = */ llama_sampler_grammar_name,
+     /* .accept = */ llama_sampler_grammar_accept_impl,
+     /* .apply  = */ llama_sampler_grammar_apply,
+     /* .reset  = */ llama_sampler_grammar_reset,
+     /* .clone  = */ llama_sampler_grammar_clone,
+     /* .free   = */ llama_sampler_grammar_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+     auto * ctx = new llama_sampler_grammar;
+
+     if (grammar_str != nullptr && grammar_str[0] != '\0') {
+         *ctx = {
+             /* .vocab        = */ &vocab,
+             /* .grammar_str  = */ grammar_str,
+             /* .grammar_root = */ grammar_root,
+             /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+         };
+     } else {
+         *ctx = {
+             /* .vocab        = */ &vocab,
+             /* .grammar_str  = */ {},
+             /* .grammar_root = */ {},
+             /* .grammar      = */ nullptr,
+         };
+     }
+
+     return new llama_sampler {
+         /* .iface = */ &llama_sampler_grammar_i,
+         /* .ctx   = */ ctx,
+     };
+ }
+
1474
+
1475
+ struct llama_sampler_penalties {
1476
+ const int32_t n_vocab;
1477
+ const llama_token special_eos_id;
1478
+ const llama_token linefeed_id;
1479
+
1480
+ const int32_t penalty_last_n;
1481
+ const float penalty_repeat;
1482
+ const float penalty_freq;
1483
+ const float penalty_present;
1484
+
1485
+ const bool penalize_nl;
1486
+ const bool ignore_eos;
1487
+
1488
+ ring_buffer<llama_token> prev;
1489
+ };
1490
+
1491
+ static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
1492
+ return "penalties";
1493
+ }
1494
+
1495
+ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
1496
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1497
+ if (ctx->penalty_last_n == 0) {
1498
+ return;
1499
+ }
1500
+
1501
+ ctx->prev.push_back(token);
1502
+ }
1503
+
1504
+ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1505
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1506
+
1507
+ if (ctx->ignore_eos) {
1508
+ assert(ctx->special_eos_id >= 0);
1509
+
1510
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
1511
+ if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
1512
+ cur_p->data[ctx->special_eos_id].logit = -INFINITY;
1513
+ } else {
1514
+ // else, search for the special EOS token
1515
+ for (size_t i = 0; i < cur_p->size; ++i) {
1516
+ if (cur_p->data[i].id == ctx->special_eos_id) {
1517
+ cur_p->data[i].logit = -INFINITY;
1518
+ break;
1519
+ }
1520
+ }
1521
+ }
1522
+ }
1523
+
1524
+ if ((ctx->penalty_last_n == 0) ||
1525
+ (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1526
+ return;
1527
+ }
1528
+
1529
+ bool nl_found = false;
1530
+ size_t nl_idx = 0;
1531
+ float nl_logit = -INFINITY;
1532
+ if (!ctx->penalize_nl) {
1533
+ assert(ctx->linefeed_id >= 0);
1534
+
1535
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
1536
+ if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
1537
+ nl_found = true;
1538
+ nl_idx = ctx->linefeed_id;
1539
+ nl_logit = cur_p->data[ctx->linefeed_id].logit;
1540
+ } else {
1541
+ // else, search for the linefeed token
1542
+ for (size_t i = 0; i < cur_p->size; ++i) {
1543
+ if (cur_p->data[i].id == ctx->linefeed_id) {
1544
+ nl_found = true;
1545
+ nl_idx = i;
1546
+ nl_logit = cur_p->data[i].logit;
1547
+ break;
1548
+ }
1549
+ }
1550
+ }
714
1551
  }
715
1552
 
716
- std::discrete_distribution<> dist(probs.begin(), probs.end());
717
- int idx = dist(rng);
1553
+ // Create a frequency map to count occurrences of each token in last_tokens
1554
+ // TODO: optimize this by maintaining the token count in the sampler context
1555
+ using llama_token_cnt = std::unordered_map<llama_token, int>;
1556
+ llama_token_cnt token_count;
718
1557
 
719
- llama_token result = candidates->data[idx].id;
1558
+ for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1559
+ token_count[ctx->prev.rat(i)]++;
1560
+ }
720
1561
 
721
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
722
- smpl->n_sample++;
1562
+ // Apply frequency and presence penalties to the cur_p
1563
+ for (size_t i = 0; i < cur_p->size; ++i) {
1564
+ const auto token_iter = token_count.find(cur_p->data[i].id);
1565
+ if (token_iter == token_count.end()) {
1566
+ continue;
1567
+ }
1568
+
1569
+ const int count = token_iter->second;
1570
+
1571
+         // The academic publication that described this technique just divided the logit by the penalty,
+         // but that would make tokens with negative logits more likely, which is obviously wrong.
+         // The common fix is to multiply by the penalty instead when the logit is negative.
+         if (cur_p->data[i].logit <= 0) {
+             cur_p->data[i].logit *= ctx->penalty_repeat;
+         } else {
+             cur_p->data[i].logit /= ctx->penalty_repeat;
+         }
+
+         cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
+     }
+
+     cur_p->sorted = false;
+
+     if (!ctx->penalize_nl && nl_found) {
+         // restore the logit of the newline token if it was penalized
+         cur_p->data[nl_idx].logit = nl_logit;
+     }
+ }
+
+ static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
+     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+     ctx->prev.clear();
+ }
+
+ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
+     const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
+     auto * result = llama_sampler_init_penalties(
+         ctx->n_vocab,
+         ctx->special_eos_id,
+         ctx->linefeed_id,
+         ctx->penalty_last_n,
+         ctx->penalty_repeat,
+         ctx->penalty_freq,
+         ctx->penalty_present,
+         ctx->penalize_nl,
+         ctx->ignore_eos);
+
+     // copy the state
+     {
+         auto * result_ctx = (llama_sampler_penalties *) result->ctx;
+
+         result_ctx->prev = ctx->prev;
+     }

      return result;
  }

- llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-     return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
+     delete (llama_sampler_penalties *) smpl->ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_penalties_i = {
+     /* .name   = */ llama_sampler_penalties_name,
+     /* .accept = */ llama_sampler_penalties_accept,
+     /* .apply  = */ llama_sampler_penalties_apply,
+     /* .reset  = */ llama_sampler_penalties_reset,
+     /* .clone  = */ llama_sampler_penalties_clone,
+     /* .free   = */ llama_sampler_penalties_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_penalties(
+         int32_t n_vocab,
+         llama_token special_eos_id,
+         llama_token linefeed_id,
+         int32_t penalty_last_n,
+         float penalty_repeat,
+         float penalty_freq,
+         float penalty_present,
+         bool penalize_nl,
+         bool ignore_eos) {
+     if (linefeed_id == LLAMA_TOKEN_NULL) {
+         penalize_nl = true;
+     }
+
+     if (special_eos_id == LLAMA_TOKEN_NULL) {
+         ignore_eos = false;
+     }
+
+     return new llama_sampler {
+         /* .iface = */ &llama_sampler_penalties_i,
+         /* .ctx   = */ new llama_sampler_penalties {
+             /* .n_vocab         = */ n_vocab,
+             /* .special_eos_id  = */ special_eos_id,
+             /* .linefeed_id     = */ linefeed_id,
+             /* .penalty_last_n  = */ penalty_last_n,
+             /* .penalty_repeat  = */ penalty_repeat,
+             /* .penalty_freq    = */ penalty_freq,
+             /* .penalty_present = */ penalty_present,
+             /* .penalize_nl     = */ penalize_nl,
+             /* .ignore_eos      = */ ignore_eos,
+             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+         },
+     };
+ }
+
+
1668
+ struct llama_sampler_logit_bias {
1669
+ const int32_t n_vocab;
1670
+
1671
+ const std::vector<llama_logit_bias> logit_bias;
1672
+
1673
+ std::vector<llama_logit_bias> to_search;
1674
+ };
1675
+
1676
+ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
1677
+ return "logit-bias";
1678
+ }
1679
+
1680
+ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1681
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
1682
+
1683
+ ctx->to_search.clear();
1684
+
1685
+ // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
1686
+ for (const auto & lb : ctx->logit_bias) {
1687
+ if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
1688
+ cur_p->data[lb.token].logit += lb.bias;
1689
+ } else {
1690
+ ctx->to_search.push_back(lb);
1691
+ }
1692
+ }
1693
+
1694
+ // search for the remaining candidates that were not found in the previous step
1695
+ for (size_t i = 0; i < cur_p->size; ++i) {
1696
+ for (const auto & lb : ctx->to_search) {
1697
+ if (cur_p->data[i].id == lb.token) {
1698
+ cur_p->data[i].logit += lb.bias;
1699
+ break;
1700
+ }
1701
+ }
1702
+ }
1703
+ }
1704
+ static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
1705
+ const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
1706
+ return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
1707
+ }
1708
+
1709
+ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
1710
+ delete (llama_sampler_logit_bias *) smpl->ctx;
1711
+ }
1712
+
1713
+ static struct llama_sampler_i llama_sampler_logit_bias_i = {
1714
+ /* .name = */ llama_sampler_logit_bias_name,
1715
+ /* .accept = */ nullptr,
1716
+ /* .apply = */ llama_sampler_logit_bias_apply,
1717
+ /* .reset = */ nullptr,
1718
+ /* .clone = */ llama_sampler_logit_bias_clone,
1719
+ /* .free = */ llama_sampler_logit_bias_free,
1720
+ };
1721
+
1722
+ struct llama_sampler * llama_sampler_init_logit_bias(
1723
+ int32_t n_vocab,
1724
+ int32_t n_logit_bias,
1725
+ const llama_logit_bias * logit_bias) {
1726
+ return new llama_sampler {
1727
+ /* .iface = */ &llama_sampler_logit_bias_i,
1728
+ /* .ctx = */ new llama_sampler_logit_bias {
1729
+ /* .n_vocab = */ n_vocab,
1730
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
1731
+ /* .to_search = */ {},
1732
+ },
1733
+ };
729
1734
  }
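Finally, a short sketch of constructing the logit-bias sampler defined above; the token ids and bias values are hypothetical, and a bias of -INFINITY effectively bans a token outright. The caller applies the result per step with llama_sampler_apply and releases it with llama_sampler_free.

#include <cmath>
#include "llama.h"

// Sketch: build a bias sampler; real token ids come from the model's vocabulary.
static llama_sampler * make_bias_sampler(int32_t n_vocab) {
    static const llama_logit_bias biases[] = {
        { /* .token = */ 15043, /* .bias = */  2.5f },      // nudge one token upward
        { /* .token = */    13, /* .bias = */ -INFINITY },  // ban another outright
    };
    return llama_sampler_init_logit_bias(n_vocab, 2, biases);
}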