cui-llama.rn 1.1.2 → 1.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,52 @@
1
1
  #include "llama-sampling.h"
2
2
 
3
+ #include "llama-vocab.h"
4
+ #include "llama-grammar.h"
5
+
6
+ #include <cassert>
3
7
  #include <algorithm>
4
8
  #include <cstring>
5
9
  #include <ctime>
6
10
  #include <cfloat>
11
+ #include <chrono>
12
+ #include <cmath>
7
13
  #include <numeric>
14
+ #include <random>
8
15
  #include <unordered_map>
9
16
 
17
+ static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
18
+ // iterator for the probabilities
19
+ #ifdef __GNUC__
20
+ #pragma GCC diagnostic push
21
+ #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
22
+ #endif
23
+
24
+ struct probs_iterator {
25
+ typedef std::input_iterator_tag iterator_category;
26
+ typedef float value_type;
27
+ typedef float * pointer;
28
+ typedef float & reference;
29
+ typedef ptrdiff_t difference_type;
30
+
31
+ const llama_token_data * data;
32
+
33
+ bool operator==(const probs_iterator & other) const { return data == other.data; }
34
+ bool operator!=(const probs_iterator & other) const { return data != other.data; }
35
+ const float & operator*() const { return data->p; }
36
+ probs_iterator & operator++() { ++data; return *this; }
37
+ probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
38
+ };
39
+
40
+ #ifdef __GNUC__
41
+ #pragma GCC diagnostic pop
42
+ #endif
43
+
44
+ std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
45
+
46
+ return dist(rng);
47
+ }
48
+
49
+ /*
10
50
  static void llama_log_softmax(float * array, size_t size) {
11
51
  float max_l = *std::max_element(array, array + size);
12
52
  float sum = 0.f;
@@ -20,66 +60,52 @@ static void llama_log_softmax(float * array, size_t size) {
20
60
  array[i] = logf(array[i] / sum);
21
61
  }
22
62
  }
63
+ */
23
64
 
24
- void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
25
- if (seed == LLAMA_DEFAULT_SEED) {
26
- seed = time(NULL);
27
- }
28
-
29
- smpl->rng.seed(seed);
30
- }
31
-
32
- void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
33
- LM_GGML_ASSERT(candidates->size > 0);
34
-
35
- const int64_t t_start_sample_us = lm_ggml_time_us();
65
+ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
66
+ LM_GGML_ASSERT(cur_p->size > 0);
36
67
 
37
68
  // Sort the logits in descending order
38
- if (!candidates->sorted) {
39
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
69
+ if (!cur_p->sorted) {
70
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
40
71
  return a.logit > b.logit;
41
72
  });
42
- candidates->sorted = true;
73
+ cur_p->sorted = true;
43
74
  }
44
75
 
45
- float max_l = candidates->data[0].logit;
76
+ float max_l = cur_p->data[0].logit;
46
77
  float cum_sum = 0.0f;
47
- for (size_t i = 0; i < candidates->size; ++i) {
48
- float p = expf(candidates->data[i].logit - max_l);
49
- candidates->data[i].p = p;
78
+
79
+ for (size_t i = 0; i < cur_p->size; ++i) {
80
+ float p = expf(cur_p->data[i].logit - max_l);
81
+ cur_p->data[i].p = p;
50
82
  cum_sum += p;
51
83
  }
52
- for (size_t i = 0; i < candidates->size; ++i) {
53
- candidates->data[i].p /= cum_sum;
54
- }
55
84
 
56
- if (smpl) {
57
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
85
+ for (size_t i = 0; i < cur_p->size; ++i) {
86
+ cur_p->data[i].p /= cum_sum;
58
87
  }
59
88
  }
60
89
 
61
- void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
90
+ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
62
91
  // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
63
- // if (k >= (int32_t)candidates->size) {
92
+ // if (k >= (int32_t)cur_p->size) {
64
93
  // return;
65
94
  // }
66
95
 
67
- const int64_t t_start_sample_us = lm_ggml_time_us();
68
-
69
96
  if (k <= 0) {
70
- k = candidates->size;
97
+ k = cur_p->size;
71
98
  }
72
99
 
73
- k = std::max(k, (int) min_keep);
74
- k = std::min(k, (int) candidates->size);
100
+ k = std::min(k, (int) cur_p->size);
75
101
 
76
102
  // Sort scores in descending order
77
- if (!candidates->sorted) {
103
+ if (!cur_p->sorted) {
78
104
  auto comp = [](const llama_token_data & a, const llama_token_data & b) {
79
105
  return a.logit > b.logit;
80
106
  };
81
107
  if (k <= 128) {
82
- std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
108
+ std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
83
109
  } else {
84
110
  constexpr int nbuckets = 128;
85
111
  constexpr float bucket_low = -10.0f;
@@ -87,11 +113,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
87
113
  constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
88
114
  constexpr float bucket_inter = -bucket_low * bucket_scale;
89
115
 
90
- std::vector<int> bucket_idx(candidates->size);
116
+ std::vector<int> bucket_idx(cur_p->size);
91
117
  std::vector<int> histo(nbuckets, 0);
92
118
 
93
- for (int i = 0; i < (int)candidates->size; ++i) {
94
- const float val = candidates->data[i].logit;
119
+ for (int i = 0; i < (int)cur_p->size; ++i) {
120
+ const float val = cur_p->data[i].logit;
95
121
  int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
96
122
  ib = std::max(0, std::min(nbuckets-1, ib));
97
123
  bucket_idx[i] = ib;
@@ -101,20 +127,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
101
127
  int ib = nbuckets - 1;
102
128
  for ( ; ib >= 0; --ib) {
103
129
  nhave += histo[ib];
104
- if (nhave >= k) break;
130
+ if (nhave >= k) {
131
+ break;
132
+ }
105
133
  }
106
134
  std::vector<llama_token_data> tmp_tokens(nhave);
107
- auto ptr = tmp_tokens.data();
135
+ auto * ptr = tmp_tokens.data();
108
136
  std::vector<llama_token_data*> bucket_ptrs;
109
137
  bucket_ptrs.reserve(nbuckets - ib);
110
138
  for (int j = nbuckets - 1; j >= ib; --j) {
111
139
  bucket_ptrs.push_back(ptr);
112
140
  ptr += histo[j];
113
141
  }
114
- for (int i = 0; i < (int)candidates->size; ++i) {
142
+ for (int i = 0; i < (int)cur_p->size; ++i) {
115
143
  int j = bucket_idx[i];
116
144
  if (j >= ib) {
117
- *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
145
+ *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
118
146
  }
119
147
  }
120
148
 
@@ -127,69 +155,606 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
127
155
  }
128
156
  std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
129
157
 
130
- std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
158
+ std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
159
+
160
+ }
161
+ cur_p->sorted = true;
162
+ }
163
+ cur_p->size = k;
164
+ }
131
165
 
166
+ static uint32_t get_rng_seed(uint32_t seed) {
167
+ if (seed == LLAMA_DEFAULT_SEED) {
168
+ // use system clock if std::random_device is not a true RNG
169
+ static bool is_rd_prng = std::random_device().entropy() == 0;
170
+ if (is_rd_prng) {
171
+ return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
132
172
  }
133
- candidates->sorted = true;
173
+ std::random_device rd;
174
+ return rd();
175
+ }
176
+ return seed;
177
+ }
178
+
179
+ // llama_sampler API
180
+
181
+ const char * llama_sampler_name(const struct llama_sampler * smpl) {
182
+ if (!smpl->iface) {
183
+ return "(null)";
134
184
  }
135
- candidates->size = k;
136
185
 
137
- if (smpl) {
138
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
186
+ return smpl->iface->name(smpl);
187
+ }
188
+
189
+ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
190
+ if (smpl->iface->accept) {
191
+ smpl->iface->accept(smpl, token);
192
+ }
193
+ }
194
+
195
+ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
196
+ LM_GGML_ASSERT(smpl->iface->apply);
197
+ smpl->iface->apply(smpl, cur_p);
198
+ }
199
+
200
+ void llama_sampler_reset(struct llama_sampler * smpl) {
201
+ if (smpl->iface->reset) {
202
+ smpl->iface->reset(smpl);
203
+ }
204
+ }
205
+
206
+ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
207
+ if (smpl->iface->clone) {
208
+ return smpl->iface->clone(smpl);
209
+ }
210
+
211
+ if (smpl->ctx == nullptr) {
212
+ return new llama_sampler {
213
+ /* .iface = */ smpl->iface,
214
+ /* .ctx = */ nullptr,
215
+ };
139
216
  }
217
+
218
+ LM_GGML_ABORT("the sampler does not support cloning");
140
219
  }
141
220
 
142
- void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
143
- if (p >= 1.0f) {
221
+ void llama_sampler_free(struct llama_sampler * smpl) {
222
+ if (smpl == nullptr) {
144
223
  return;
145
224
  }
146
225
 
147
- llama_sample_softmax_impl(smpl, candidates);
226
+ if (smpl->iface->free) {
227
+ smpl->iface->free(smpl);
228
+ }
148
229
 
149
- const int64_t t_start_sample_us = lm_ggml_time_us();
230
+ delete smpl;
231
+ }
232
+
233
+ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
234
+ const auto * logits = llama_get_logits_ith(ctx, idx);
235
+
236
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
237
+
238
+ // TODO: do not allocate each time
239
+ std::vector<llama_token_data> cur(n_vocab);
240
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
241
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
242
+ }
243
+
244
+ llama_token_data_array cur_p = {
245
+ /* .data = */ cur.data(),
246
+ /* .size = */ cur.size(),
247
+ /* .selected = */ -1,
248
+ /* .sorted = */ false,
249
+ };
250
+
251
+ llama_sampler_apply(smpl, &cur_p);
252
+
253
+ LM_GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
254
+
255
+ auto token = cur_p.data[cur_p.selected].id;
256
+
257
+ llama_sampler_accept(smpl, token);
258
+
259
+ return token;
260
+ }
261
+
262
+ // sampler chain
263
+
264
+ static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
265
+ return "chain";
266
+ }
267
+
268
+ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
269
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
270
+
271
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
272
+
273
+ for (auto * smpl : chain->samplers) {
274
+ llama_sampler_accept(smpl, token);
275
+ }
276
+
277
+ chain->n_sample++;
278
+ }
279
+
280
+ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
281
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
282
+
283
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
284
+
285
+ for (auto * smpl : chain->samplers) {
286
+ llama_sampler_apply(smpl, cur_p);
287
+ }
288
+ }
289
+
290
+ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
291
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
292
+
293
+ for (auto * smpl : chain->samplers) {
294
+ llama_sampler_reset(smpl);
295
+ }
296
+
297
+ chain->t_sample_us = 0;
298
+ chain->n_sample = 0;
299
+ }
300
+
301
+ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
302
+ const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
303
+
304
+ auto * result = llama_sampler_chain_init(chain_src->params);
305
+
306
+ for (auto * smpl : chain_src->samplers) {
307
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl));
308
+ }
309
+
310
+ return result;
311
+ }
312
+
313
+ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
314
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
315
+
316
+ for (auto * smpl : chain->samplers) {
317
+ llama_sampler_free(smpl);
318
+ }
319
+
320
+ delete chain;
321
+ }
322
+
323
+ static struct llama_sampler_i llama_sampler_chain_i = {
324
+ /* .name = */ llama_sampler_chain_name,
325
+ /* .accept = */ llama_sampler_chain_accept,
326
+ /* .apply = */ llama_sampler_chain_apply,
327
+ /* .reset = */ llama_sampler_chain_reset,
328
+ /* .clone = */ llama_sampler_chain_clone,
329
+ /* .free = */ llama_sampler_chain_free,
330
+ };
331
+
332
+ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
333
+ return new llama_sampler {
334
+ /* .iface = */ &llama_sampler_chain_i,
335
+ /* .ctx = */ new llama_sampler_chain {
336
+ /* .params = */ params,
337
+ /* .samplers = */ {},
338
+ /* .t_sample_us = */ 0,
339
+ /* .n_sample = */ 0,
340
+ },
341
+ };
342
+ }
343
+
344
+ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
345
+ auto * p = (llama_sampler_chain *) chain->ctx;
346
+ p->samplers.push_back(smpl);
347
+ }
348
+
349
+ llama_sampler_timings llama_sampler_chain_timings(struct llama_sampler * chain) {
350
+ auto * p = (llama_sampler_chain *) chain->ctx;
351
+ struct llama_sampler_timings result = {
352
+ p -> t_sample_us,
353
+ p -> n_sample
354
+ };
355
+ return result;
356
+ }
357
+
358
+ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
359
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
360
+
361
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
362
+ return nullptr;
363
+ }
364
+
365
+ return p->samplers[i];
366
+ }
367
+
368
+ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
369
+ auto * p = (llama_sampler_chain *) chain->ctx;
370
+
371
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
372
+ return nullptr;
373
+ }
374
+
375
+ auto * result = p->samplers[i];
376
+ p->samplers.erase(p->samplers.begin() + i);
377
+
378
+ return result;
379
+ }
380
+
381
+ int llama_sampler_chain_n(const struct llama_sampler * chain) {
382
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
383
+
384
+ return p->samplers.size();
385
+ }
386
+
387
+ //
388
+ // samplers
389
+ //
390
+
391
+ // greedy
392
+
393
+ static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
394
+ return "greedy";
395
+ }
396
+
397
+ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
398
+ cur_p->selected = 0;
399
+ for (size_t i = 1; i < cur_p->size; ++i) {
400
+ if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
401
+ cur_p->selected = i;
402
+ }
403
+ }
404
+ }
405
+
406
+ static struct llama_sampler_i llama_sampler_greedy_i = {
407
+ /* .name = */ llama_sampler_greedy_name,
408
+ /* .accept = */ nullptr,
409
+ /* .apply = */ llama_sampler_greedy_apply,
410
+ /* .reset = */ nullptr,
411
+ /* .clone = */ nullptr,
412
+ /* .free = */ nullptr,
413
+ };
414
+
415
+ struct llama_sampler * llama_sampler_init_greedy() {
416
+ return new llama_sampler {
417
+ /* .iface = */ &llama_sampler_greedy_i,
418
+ /* .ctx = */ nullptr,
419
+ };
420
+ }
421
+
422
+ // dist
423
+
424
+ struct llama_sampler_dist {
425
+ const uint32_t seed;
426
+ uint32_t seed_cur;
427
+
428
+ std::mt19937 rng;
429
+ };
430
+
431
+ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
432
+ return "dist";
433
+ }
434
+
435
+ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
436
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
437
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
438
+ }
439
+
440
+ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
441
+ const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
442
+ auto * result = llama_sampler_init_dist(ctx->seed);
443
+
444
+ // copy the state
445
+ {
446
+ auto * result_ctx = (llama_sampler_dist *) result->ctx;
447
+
448
+ result_ctx->rng = ctx->rng;
449
+ }
450
+
451
+ return result;
452
+ }
453
+
454
+ static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
455
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
456
+ ctx->seed_cur = get_rng_seed(ctx->seed);
457
+ ctx->rng.seed(ctx->seed_cur);
458
+ }
459
+
460
+ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
461
+ delete (llama_sampler_dist *) smpl->ctx;
462
+ }
463
+
464
+ static struct llama_sampler_i llama_sampler_dist_i = {
465
+ /* .name = */ llama_sampler_dist_name,
466
+ /* .accept = */ nullptr,
467
+ /* .apply = */ llama_sampler_dist_apply,
468
+ /* .reset = */ llama_sampler_dist_reset,
469
+ /* .clone = */ llama_sampler_dist_clone,
470
+ /* .free = */ llama_sampler_dist_free,
471
+ };
472
+
473
+ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
474
+ auto seed_cur = get_rng_seed(seed);
475
+ return new llama_sampler {
476
+ /* .iface = */ &llama_sampler_dist_i,
477
+ /* .ctx = */ new llama_sampler_dist {
478
+ /* .seed = */ seed,
479
+ /* .seed_cur = */ seed_cur,
480
+ /* .rng = */ std::mt19937(seed_cur),
481
+ },
482
+ };
483
+ }
484
+
485
+ // softmax
486
+
487
+ static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
488
+ return "softmax";
489
+ }
490
+
491
+ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
492
+ llama_sampler_softmax_impl(cur_p);
493
+ }
494
+
495
+ static struct llama_sampler_i llama_sampler_softmax_i = {
496
+ /* .name = */ llama_sampler_softmax_name,
497
+ /* .accept = */ nullptr,
498
+ /* .apply = */ llama_sampler_softmax_apply,
499
+ /* .reset = */ nullptr,
500
+ /* .clone = */ nullptr,
501
+ /* .free = */ nullptr,
502
+ };
503
+
504
+ struct llama_sampler * llama_sampler_init_softmax() {
505
+ return new llama_sampler {
506
+ /* .iface = */ &llama_sampler_softmax_i,
507
+ /* .ctx = */ nullptr,
508
+ };
509
+ }
510
+
511
+ // top-k
512
+
513
+ struct llama_sampler_top_k {
514
+ const int32_t k;
515
+ };
516
+
517
+ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
518
+ return "top-k";
519
+ }
520
+
521
+ static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
522
+ const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
523
+ llama_sampler_top_k_impl(cur_p, ctx->k);
524
+ }
525
+
526
+ static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
527
+ const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
528
+ return llama_sampler_init_top_k(ctx->k);
529
+ }
530
+
531
+ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
532
+ delete (llama_sampler_top_k *) smpl->ctx;
533
+ }
534
+
535
+ static struct llama_sampler_i llama_sampler_top_k_i = {
536
+ /* .name = */ llama_sampler_top_k_name,
537
+ /* .accept = */ nullptr,
538
+ /* .apply = */ llama_sampler_top_k_apply,
539
+ /* .reset = */ nullptr,
540
+ /* .clone = */ llama_sampler_top_k_clone,
541
+ /* .free = */ llama_sampler_top_k_free,
542
+ };
543
+
544
+ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
545
+ return new llama_sampler {
546
+ /* .iface = */ &llama_sampler_top_k_i,
547
+ /* .ctx = */ new llama_sampler_top_k {
548
+ /* .k = */ k,
549
+ },
550
+ };
551
+ }
552
+
553
+ // top-p
554
+
555
+ struct llama_sampler_top_p {
556
+ const float p;
557
+ const size_t min_keep;
558
+ };
559
+
560
+ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
561
+ return "top-p";
562
+ }
563
+
564
+ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
565
+ const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
566
+
567
+ if (ctx->p >= 1.0f) {
568
+ return;
569
+ }
570
+
571
+ llama_sampler_softmax_impl(cur_p);
150
572
 
151
573
  // Compute the cumulative probabilities
152
574
  float cum_sum = 0.0f;
153
- size_t last_idx = candidates->size;
575
+ size_t last_idx = cur_p->size;
154
576
 
155
- for (size_t i = 0; i < candidates->size; ++i) {
156
- cum_sum += candidates->data[i].p;
577
+ for (size_t i = 0; i < cur_p->size; ++i) {
578
+ cum_sum += cur_p->data[i].p;
157
579
 
158
580
  // Check if the running sum is at least p or if we have kept at least min_keep tokens
159
581
  // we set the last index to i+1 to indicate that the current iterate should be included in the set
160
- if (cum_sum >= p && i + 1 >= min_keep) {
582
+ if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
161
583
  last_idx = i + 1;
162
584
  break;
163
585
  }
164
586
  }
165
587
 
166
588
  // Resize the output vector to keep only the top-p tokens
167
- candidates->size = last_idx;
589
+ cur_p->size = last_idx;
590
+ }
591
+
592
+ static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
593
+ const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
594
+ return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
595
+ }
596
+
597
+ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
598
+ delete (llama_sampler_top_p *) smpl->ctx;
599
+ }
600
+
601
+ static struct llama_sampler_i llama_sampler_top_p_i = {
602
+ /* .name = */ llama_sampler_top_p_name,
603
+ /* .accept = */ nullptr,
604
+ /* .apply = */ llama_sampler_top_p_apply,
605
+ /* .reset = */ nullptr,
606
+ /* .clone = */ llama_sampler_top_p_clone,
607
+ /* .free = */ llama_sampler_top_p_free,
608
+ };
609
+
610
+ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
611
+ return new llama_sampler {
612
+ /* .iface = */ &llama_sampler_top_p_i,
613
+ /* .ctx = */ new llama_sampler_top_p {
614
+ /* .p = */ p,
615
+ /* .min_keep = */ min_keep,
616
+ },
617
+ };
618
+ }
619
+
620
+ // min-p
621
+
622
+ struct llama_sampler_min_p {
623
+ const float p;
624
+ const size_t min_keep;
625
+ };
626
+
627
+ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
628
+ return "min-p";
629
+ }
630
+
631
+ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
632
+ const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
633
+
634
+ if (ctx->p <= 0.0f || !cur_p->size) {
635
+ return;
636
+ }
637
+
638
+ bool min_p_applied = false;
639
+
640
+ // if the cur_p aren't sorted, try the unsorted implementation first
641
+ if (!cur_p->sorted) {
642
+ std::vector<llama_token_data> filtered_tokens;
643
+
644
+ float max_logit = -FLT_MAX;
645
+ for (size_t i = 0; i < cur_p->size; ++i) {
646
+ max_logit = std::max(max_logit, cur_p->data[i].logit);
647
+ }
648
+ const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
649
+
650
+ for (size_t i = 0; i < cur_p->size; ++i) {
651
+ if (cur_p->data[i].logit >= min_logit) {
652
+ filtered_tokens.push_back(cur_p->data[i]);
653
+ }
654
+ }
655
+
656
+ // if we have enough values the operation was a success
657
+ if (filtered_tokens.size() >= ctx->min_keep) {
658
+ memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
659
+ cur_p->size = filtered_tokens.size();
660
+ min_p_applied = true;
661
+ }
662
+ }
663
+
664
+ // if the cur_p are sorted or the unsorted implementation failed, use this implementation
665
+ if (!min_p_applied) {
666
+ // Sort the logits in descending order
667
+ if (!cur_p->sorted) {
668
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
669
+ return a.logit > b.logit;
670
+ });
671
+ cur_p->sorted = true;
672
+ }
168
673
 
169
- if (smpl) {
170
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
674
+ const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
675
+ size_t i = 1; // first token always matches
676
+
677
+ for (; i < cur_p->size; ++i) {
678
+ if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
679
+ break; // prob too small
680
+ }
681
+ }
682
+
683
+ // Resize the output vector to keep only the matching tokens
684
+ cur_p->size = i;
171
685
  }
172
686
  }
173
687
 
174
- void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, size_t min_keep, std::mt19937 & rng) {
175
- if(xtc_threshold <= 0.0f || !candidates-> size) {
688
+ static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
689
+ const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
690
+ return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
691
+ }
692
+
693
+ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
694
+ delete (llama_sampler_min_p *) smpl->ctx;
695
+ }
696
+
697
+ static struct llama_sampler_i llama_sampler_min_p_i = {
698
+ /* .name = */ llama_sampler_min_p_name,
699
+ /* .accept = */ nullptr,
700
+ /* .apply = */ llama_sampler_min_p_apply,
701
+ /* .reset = */ nullptr,
702
+ /* .clone = */ llama_sampler_min_p_clone,
703
+ /* .free = */ llama_sampler_min_p_free,
704
+ };
705
+
706
+ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
707
+ return new llama_sampler {
708
+ /* .iface = */ &llama_sampler_min_p_i,
709
+ /* .ctx = */ new llama_sampler_min_p {
710
+ /* .p = */ p,
711
+ /* .min_keep = */ min_keep,
712
+ },
713
+ };
714
+ }
715
+
716
+ // xtc
717
+
718
// State for the "xtc" sampler.
// Member order matters: llama_sampler_init_xtc initializes this aggregate positionally.
struct llama_sampler_xtc {
    const uint32_t seed;      // seed passed to llama_sampler_init_xtc (reused by clone)
    std::mt19937   rng;       // source of the uniform draws used by apply
    const float    xtc_p;     // used by apply as xtc_probability: chance of removing an eligible token
    const float    xtc_t;     // used by apply as xtc_threshold (<= 0 disables the sampler)
    const size_t   min_keep;  // forwarded to apply's min_keep checks
};

static const char * llama_sampler_xtc_name(const struct llama_sampler * /* smpl */) {
    return "xtc";
}
729
+
730
+ static void llama_sampler_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
731
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
732
+
733
+ size_t min_keep = ctx -> min_keep;
734
+ std::mt19937 rng = ctx -> rng;
735
+
736
+ float xtc_threshold = ctx -> xtc_t;
737
+ float xtc_probability = ctx -> xtc_p;
738
+
739
+
740
+ if(xtc_threshold <= 0.0f || !cur_p-> size) {
176
741
  return;
177
742
  }
178
743
 
179
744
  bool xtc_applied = false;
180
745
  const int64_t t_start_sample_us = lm_ggml_time_us();
181
- llama_sample_softmax(nullptr, candidates);
746
+ llama_sampler_softmax_impl(cur_p);
182
747
 
183
748
  // unsorted iteration
184
- if (!candidates->sorted) {
749
+ if (!cur_p->sorted) {
185
750
  std::vector<llama_token_data> top_tokens, low_tokens;
186
751
 
187
752
  // split candidates into two arrays for low and high tokens
188
- for (size_t i = 0; i < candidates->size; ++i) {
189
- if (candidates->data[i].logit >= xtc_threshold) {
190
- top_tokens.push_back(candidates->data[i]);
753
+ for (size_t i = 0; i < cur_p->size; ++i) {
754
+ if (cur_p->data[i].logit >= xtc_threshold) {
755
+ top_tokens.push_back(cur_p->data[i]);
191
756
  } else {
192
- low_tokens.push_back(candidates-> data[i]);
757
+ low_tokens.push_back(cur_p-> data[i]);
193
758
  }
194
759
  }
195
760
  // if there is only one or no top_tokens, do not truncate
@@ -212,28 +777,28 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
212
777
  }
213
778
  }
214
779
  if(low_tokens.size() >= min_keep) {
215
- memcpy(candidates->data, low_tokens.data(), low_tokens.size()*sizeof(llama_token_data));
216
- candidates->size = low_tokens.size();
780
+ memcpy(cur_p->data, low_tokens.data(), low_tokens.size()*sizeof(llama_token_data));
781
+ cur_p->size = low_tokens.size();
217
782
  xtc_applied = true;
218
783
  }
219
784
  }
220
785
  // sorted iteration
221
-
786
+
222
787
  if (!xtc_applied) {
223
788
  // Sort the logits in descending order
224
- if (!candidates->sorted) {
225
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
789
+ if (!cur_p->sorted) {
790
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
226
791
  return a.logit > b.logit;
227
792
  });
228
- candidates->sorted = true;
793
+ cur_p->sorted = true;
229
794
  }
230
795
 
231
796
  // find last token over threshold
232
797
 
233
798
  size_t last_index = 0;
234
799
 
235
- for (; last_index < candidates -> size; ++last_index) {
236
- if(candidates -> data[last_index].p < xtc_threshold) {
800
+ for (; last_index < cur_p -> size; ++last_index) {
801
+ if(cur_p -> data[last_index].p < xtc_threshold) {
237
802
  break;
238
803
  }
239
804
  }
@@ -244,102 +809,80 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
244
809
  }
245
810
  last_index--;
246
811
  // items beyond safe index will be ignored
247
- size_t safe_index = candidates -> size;
812
+ size_t safe_index = cur_p -> size;
248
813
 
249
814
  // remove tokens until last threshold item
250
815
  std::uniform_real_distribution<float> random_float(0.0 , 1.0);
251
816
  for (size_t i = 0; i < last_index; i++) {
252
817
  if(random_float(rng) < xtc_probability) {
253
- std::swap(candidates-> data[i], candidates->data[safe_index - 1]);
818
+ std::swap(cur_p-> data[i], cur_p->data[safe_index - 1]);
254
819
  safe_index--;
255
- if (candidates-> sorted) {
256
- candidates -> sorted = false;
820
+ if (cur_p-> sorted) {
821
+ cur_p -> sorted = false;
257
822
  }
258
823
  }
259
824
  }
260
- candidates -> size = safe_index;
825
+ cur_p -> size = safe_index;
261
826
  }
827
+ }
262
828
 
263
- if (smpl) {
264
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
265
- }
829
+ static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
830
+ const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
831
+ return llama_sampler_init_xtc(ctx->xtc_p, ctx->xtc_t, ctx->min_keep, ctx->seed);
266
832
  }
267
833
 
268
- void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
269
- if (p <= 0.0f || !candidates->size) {
270
- return;
271
- }
834
+ static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
835
+ delete (const llama_sampler_xtc *) smpl->ctx;
836
+ }
272
837
 
273
- const int64_t t_start_sample_us = lm_ggml_time_us();
838
+ static struct llama_sampler_i llama_sampler_xtc_i = {
839
+ /* .name = */ llama_sampler_xtc_name,
840
+ /* .accept = */ nullptr,
841
+ /* .apply = */ llama_sampler_xtc_apply,
842
+ /* .reset = */ nullptr,
843
+ /* .clone = */ llama_sampler_xtc_clone,
844
+ /* .free = */ llama_sampler_xtc_free,
845
+ };
846
+
847
+ struct llama_sampler * llama_sampler_init_xtc(float xtc_p, float xtc_t, size_t min_keep, uint32_t seed) {
848
+ return new llama_sampler {
849
+ /* .iface = */ &llama_sampler_xtc_i,
850
+ /* .ctx = */ new llama_sampler_xtc {
851
+ /* .seed = */ seed,
852
+ /* .rng = */ std::mt19937(seed),
853
+ /* .xtc_p = */ xtc_p,
854
+ /* .xtc_t = */ xtc_t,
855
+ /* .min_keep = */ min_keep
856
+ },
857
+ };
858
+ }
274
859
 
275
- bool min_p_applied = false;
860
+ // tail-free
276
861
 
277
- // if the candidates aren't sorted, try the unsorted implementation first
278
- if (!candidates->sorted) {
279
- std::vector<llama_token_data> filtered_tokens;
862
// Configuration for the "tail-free" sampler.
// Member order matters: llama_sampler_init_tail_free initializes this aggregate positionally.
struct llama_sampler_tail_free {
    const float  z;         // cut-off for the running sum in apply; z >= 1 disables the sampler
    const size_t min_keep;  // lower bound on the cut index checked in apply
};

static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
    return "tail-free";
}
286
870
 
287
- for (size_t i = 0; i < candidates->size; ++i) {
288
- if (candidates->data[i].logit >= min_logit) {
289
- filtered_tokens.push_back(candidates->data[i]);
290
- }
291
- }
871
+ static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
872
+ const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
292
873
 
293
- // if we have enough values the operation was a success
294
- if (filtered_tokens.size() >= min_keep) {
295
- memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
296
- candidates->size = filtered_tokens.size();
297
- min_p_applied = true;
298
- }
874
+ if (ctx->z >= 1.0f || cur_p->size <= 2) {
875
+ return;
299
876
  }
300
877
 
301
- // if the candidates are sorted or the unsorted implementation failed, use this implementation
302
- if (!min_p_applied) {
303
- // Sort the logits in descending order
304
- if (!candidates->sorted) {
305
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
306
- return a.logit > b.logit;
307
- });
308
- candidates->sorted = true;
309
- }
878
+ llama_sampler_softmax_impl(cur_p);
310
879
 
311
- const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
312
- size_t i = 1; // first token always matches
313
-
314
- for (; i < candidates->size; ++i) {
315
- if (candidates->data[i].logit < min_logit && i >= min_keep) {
316
- break; // prob too small
317
- }
318
- }
319
-
320
- // Resize the output vector to keep only the matching tokens
321
- candidates->size = i;
322
- }
323
-
324
- if (smpl) {
325
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
326
- }
327
- }
328
-
329
- void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
330
- if (z >= 1.0f || candidates->size <= 2) {
331
- return;
332
- }
333
-
334
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
335
- const int64_t t_start_sample_us = lm_ggml_time_us();
336
-
337
- // Compute the first and second derivatives
338
- std::vector<float> first_derivatives(candidates->size - 1);
339
- std::vector<float> second_derivatives(candidates->size - 2);
880
+ // Compute the first and second derivatives
881
+ std::vector<float> first_derivatives(cur_p->size - 1);
882
+ std::vector<float> second_derivatives(cur_p->size - 2);
340
883
 
341
884
  for (size_t i = 0; i < first_derivatives.size(); ++i) {
342
- first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
885
+ first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
343
886
  }
344
887
  for (size_t i = 0; i < second_derivatives.size(); ++i) {
345
888
  second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
@@ -366,51 +909,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
366
909
  }
367
910
 
368
911
  float cum_sum = 0.0f;
369
- size_t last_idx = candidates->size;
912
+ size_t last_idx = cur_p->size;
370
913
  for (size_t i = 0; i < second_derivatives.size(); ++i) {
371
914
  cum_sum += second_derivatives[i];
372
915
 
373
916
  // Check if the running sum is greater than z or if we have kept at least min_keep tokens
374
- if (cum_sum > z && i >= min_keep) {
917
+ if (cum_sum > ctx->z && i >= ctx->min_keep) {
375
918
  last_idx = i;
376
919
  break;
377
920
  }
378
921
  }
379
922
 
380
923
  // Resize the output vector to keep only the tokens above the tail location
381
- candidates->size = last_idx;
924
+ cur_p->size = last_idx;
925
+ }
382
926
 
383
- if (smpl) {
384
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
385
- }
927
+ static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
928
+ const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
929
+ return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
930
+ }
931
+
932
+ static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
933
+ delete (llama_sampler_tail_free *) smpl->ctx;
934
+ }
935
+
936
+ static struct llama_sampler_i llama_sampler_tail_free_i = {
937
+ /* .name = */ llama_sampler_tail_free_name,
938
+ /* .accept = */ nullptr,
939
+ /* .apply = */ llama_sampler_tail_free_apply,
940
+ /* .reset = */ nullptr,
941
+ /* .clone = */ llama_sampler_tail_free_clone,
942
+ /* .free = */ llama_sampler_tail_free_free,
943
+ };
944
+
945
+ struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
946
+ return new llama_sampler {
947
+ /* .iface = */ &llama_sampler_tail_free_i,
948
+ /* .ctx = */ new llama_sampler_tail_free {
949
+ /* .z = */ z,
950
+ /*. min_keep = */ min_keep,
951
+ },
952
+ };
953
+ }
954
+
955
+ // typical
956
+
957
+ struct llama_sampler_typical {
958
+ const float p;
959
+ const size_t min_keep;
960
+ };
961
+
962
+ static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
963
+ return "typical";
386
964
  }
387
965
 
388
- void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
966
+ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
967
+ const auto * ctx = (llama_sampler_typical *) smpl->ctx;
968
+
389
969
  // Reference implementation:
390
970
  // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
391
- if (p >= 1.0f) {
971
+ if (ctx->p >= 1.0f) {
392
972
  return;
393
973
  }
394
974
 
395
975
  // Compute the softmax of logits and calculate entropy
396
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
397
-
398
- const int64_t t_start_sample_us = lm_ggml_time_us();
976
+ llama_sampler_softmax_impl(cur_p);
399
977
 
400
978
  float entropy = 0.0f;
401
- for (size_t i = 0; i < candidates->size; ++i) {
402
- entropy += -candidates->data[i].p * logf(candidates->data[i].p);
979
+ for (size_t i = 0; i < cur_p->size; ++i) {
980
+ entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
403
981
  }
404
982
 
405
983
  // Compute the absolute difference between negative log probability and entropy for each candidate
406
984
  std::vector<float> shifted_scores;
407
- for (size_t i = 0; i < candidates->size; ++i) {
408
- float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
985
+ for (size_t i = 0; i < cur_p->size; ++i) {
986
+ float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
409
987
  shifted_scores.push_back(shifted_score);
410
988
  }
411
989
 
412
990
  // Sort tokens based on the shifted_scores and their corresponding indices
413
- std::vector<size_t> indices(candidates->size);
991
+ std::vector<size_t> indices(cur_p->size);
414
992
  std::iota(indices.begin(), indices.end(), 0);
415
993
 
416
994
  std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
@@ -423,134 +1001,618 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
423
1001
 
424
1002
  for (size_t i = 0; i < indices.size(); ++i) {
425
1003
  size_t idx = indices[i];
426
- cum_sum += candidates->data[idx].p;
1004
+ cum_sum += cur_p->data[idx].p;
427
1005
 
428
1006
  // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
429
- if (cum_sum > p && i >= min_keep - 1) {
1007
+ if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
430
1008
  last_idx = i + 1;
431
1009
  break;
432
1010
  }
433
1011
  }
434
1012
 
435
1013
  // Resize the output vector to keep only the locally typical tokens
436
- std::vector<llama_token_data> new_candidates;
1014
+ std::vector<llama_token_data> cur_p_new;
437
1015
  for (size_t i = 0; i < last_idx; ++i) {
438
1016
  size_t idx = indices[i];
439
- new_candidates.push_back(candidates->data[idx]);
1017
+ cur_p_new.push_back(cur_p->data[idx]);
440
1018
  }
441
1019
 
442
- // Replace the data in candidates with the new_candidates data
443
- std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
444
- candidates->size = new_candidates.size();
445
- candidates->sorted = false;
1020
+ // Replace the data in cur_p with the cur_p_new data
1021
+ std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
1022
+ cur_p->size = cur_p_new.size();
1023
+ cur_p->sorted = false;
1024
+ }
446
1025
 
447
- if (smpl) {
448
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
449
- }
1026
+ static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
1027
+ const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
1028
+ return llama_sampler_init_typical(ctx->p, ctx->min_keep);
450
1029
  }
451
1030
 
452
- void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
453
- const int64_t t_start_sample_us = lm_ggml_time_us();
1031
+ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
1032
+ delete (llama_sampler_typical *) smpl->ctx;
1033
+ }
454
1034
 
455
- // no need to do anything if there is only one (or zero) candidates
456
- if(candidates->size <= 1) {
457
- return;
1035
+ static struct llama_sampler_i llama_sampler_typical_i = {
1036
+ /* .name = */ llama_sampler_typical_name,
1037
+ /* .accept = */ nullptr,
1038
+ /* .apply = */ llama_sampler_typical_apply,
1039
+ /* .reset = */ nullptr,
1040
+ /* .clone = */ llama_sampler_typical_clone,
1041
+ /* .free = */ llama_sampler_typical_free,
1042
+ };
1043
+
1044
+ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1045
+ return new llama_sampler {
1046
+ /* .iface = */ &llama_sampler_typical_i,
1047
+ /* .ctx = */ new llama_sampler_typical {
1048
+ /* .p = */ p,
1049
+ /* .min_keep = */ min_keep,
1050
+ },
1051
+ };
1052
+ }
1053
+
1054
+ // temp
1055
+
1056
+ struct llama_sampler_temp {
1057
+ const float temp;
1058
+ };
1059
+
1060
+ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
1061
+ return "temp";
1062
+ }
1063
+
1064
+ static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1065
+ const auto * ctx = (llama_sampler_temp *) smpl->ctx;
1066
+ for (size_t i = 0; i < cur_p->size; ++i) {
1067
+ cur_p->data[i].logit /= ctx->temp;
458
1068
  }
1069
+ }
459
1070
 
460
- // Calculate maximum possible entropy
461
- float max_entropy = -logf(1.0f / candidates->size);
1071
+ static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
1072
+ const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
1073
+ return llama_sampler_init_temp(ctx->temp);
1074
+ }
462
1075
 
463
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
1076
+ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
1077
+ delete (llama_sampler_temp *) smpl->ctx;
1078
+ }
464
1079
 
465
- // Calculate entropy of the softmax probabilities
466
- float entropy = 0.0f;
467
- for (size_t i = 0; i < candidates->size; ++i) {
468
- float prob = candidates->data[i].p;
469
- if (prob > 0.0f) { // Ensure no log(0)
470
- entropy -= prob * logf(prob);
1080
+ static struct llama_sampler_i llama_sampler_temp_i = {
1081
+ /* .name = */ llama_sampler_temp_name,
1082
+ /* .accept = */ nullptr,
1083
+ /* .apply = */ llama_sampler_temp_apply,
1084
+ /* .reset = */ nullptr,
1085
+ /* .clone = */ llama_sampler_temp_clone,
1086
+ /* .free = */ llama_sampler_temp_free,
1087
+ };
1088
+
1089
+ struct llama_sampler * llama_sampler_init_temp(float temp) {
1090
+ return new llama_sampler {
1091
+ /* .iface = */ &llama_sampler_temp_i,
1092
+ /* .ctx = */ new llama_sampler_temp {
1093
+ /*.temp = */ temp,
1094
+ },
1095
+ };
1096
+ }
1097
+
1098
+ // temp-ext
1099
+
1100
+ struct llama_sampler_temp_ext {
1101
+ const float temp;
1102
+ const float delta;
1103
+ const float exponent;
1104
+ };
1105
+
1106
+ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
1107
+ return "temp-ext";
1108
+ }
1109
+
1110
+ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1111
+ const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
1112
+ if (ctx->delta > 0) {
1113
+ const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
1114
+ const float max_temp = ctx->temp + ctx->delta;
1115
+ float exponent_val = ctx->exponent;
1116
+
1117
+ // no need to do anything if there is only one (or zero) candidates
1118
+ if (cur_p->size <= 1) {
1119
+ return;
1120
+ }
1121
+
1122
+ // Calculate maximum possible entropy
1123
+ float max_entropy = -logf(1.0f / cur_p->size);
1124
+
1125
+ llama_sampler_softmax_impl(cur_p);
1126
+
1127
+ // Calculate entropy of the softmax probabilities
1128
+ float entropy = 0.0f;
1129
+ for (size_t i = 0; i < cur_p->size; ++i) {
1130
+ float prob = cur_p->data[i].p;
1131
+ if (prob > 0.0f) { // Ensure no log(0)
1132
+ entropy -= prob * logf(prob);
1133
+ }
1134
+ }
1135
+
1136
+ // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
1137
+ float normalized_entropy = entropy / max_entropy;
1138
+
1139
+ // Map the normalized entropy to the desired temperature range using the power function
1140
+ float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
1141
+
1142
+ #ifdef DEBUG
1143
+ LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
1144
+ LLAMA_LOG_INFO("Entropy: %f\n", entropy);
1145
+ LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
1146
+ LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
1147
+ LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
1148
+ LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
1149
+ #endif
1150
+
1151
+ // Apply the dynamically calculated temperature scaling
1152
+ for (size_t i = 0; i < cur_p->size; ++i) {
1153
+ cur_p->data[i].logit /= dyn_temp;
1154
+ }
1155
+
1156
+ // Re-compute softmax probabilities after scaling logits with dynamic temperature
1157
+ const double max_l_double = cur_p->data[0].logit;
1158
+
1159
+ double cum_sum_double = 0.0;
1160
+ for (size_t i = 0; i < cur_p->size; ++i) {
1161
+ double p = exp(cur_p->data[i].logit - max_l_double);
1162
+ cur_p->data[i].p = p; // Store the scaled probability
1163
+ cum_sum_double += p;
1164
+ }
1165
+
1166
+ for (size_t i = 0; i < cur_p->size; ++i) {
1167
+ cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
1168
+ }
1169
+
1170
+ #ifdef DEBUG
1171
+ // Print the updated top 25 probabilities after temperature scaling
1172
+ LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
1173
+ for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
1174
+ LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
1175
+ }
1176
+ #endif
1177
+ } else {
1178
+ for (size_t i = 0; i < cur_p->size; ++i) {
1179
+ cur_p->data[i].logit /= ctx->temp;
471
1180
  }
472
1181
  }
1182
+ }
473
1183
 
474
- // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
475
- float normalized_entropy = entropy / max_entropy;
1184
+ static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
1185
+ const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
1186
+ return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
1187
+ }
476
1188
 
477
- // Map the normalized entropy to the desired temperature range using the power function
478
- float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
1189
+ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1190
+ delete (llama_sampler_temp_ext *) smpl->ctx;
1191
+ }
479
1192
 
480
- #ifdef DEBUG
481
- LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
482
- LLAMA_LOG_INFO("Entropy: %f\n", entropy);
483
- LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
484
- LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
485
- LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
486
- LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
487
- #endif
1193
+ static struct llama_sampler_i llama_sampler_temp_ext_i = {
1194
+ /* .name = */ llama_sampler_temp_ext_name,
1195
+ /* .accept = */ nullptr,
1196
+ /* .apply = */ llama_sampler_temp_ext_apply,
1197
+ /* .reset = */ nullptr,
1198
+ /* .clone = */ llama_sampler_temp_ext_clone,
1199
+ /* .free = */ llama_sampler_temp_ext_free,
1200
+ };
1201
+
1202
+ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1203
+ return new llama_sampler {
1204
+ /* .iface = */ &llama_sampler_temp_ext_i,
1205
+ /* .ctx = */ new llama_sampler_temp_ext {
1206
+ /* .temp = */ temp,
1207
+ /* .delta = */ delta,
1208
+ /* .exponent = */ exponent,
1209
+ },
1210
+ };
1211
+ }
1212
+
1213
+ // mirostat
1214
+
1215
+ struct llama_sampler_mirostat {
1216
+ const int32_t n_vocab;
1217
+
1218
+ const uint32_t seed;
1219
+ uint32_t seed_cur;
1220
+
1221
+ const float tau;
1222
+ const float eta;
1223
+
1224
+ const int32_t m;
1225
+
1226
+ float mu;
1227
+
1228
+ std::mt19937 rng;
1229
+ };
1230
+
1231
+ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
1232
+ return "mirostat";
1233
+ }
1234
+
1235
+ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1236
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1237
+
1238
+ llama_sampler_softmax_impl(cur_p);
1239
+
1240
+ // Estimate s_hat using the most probable m tokens
1241
+ float s_hat = 0.0;
1242
+ float sum_ti_bi = 0.0;
1243
+ float sum_ti_sq = 0.0;
1244
+ for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
1245
+ float t_i = logf(float(i + 2) / float(i + 1));
1246
+ float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
1247
+ sum_ti_bi += t_i * b_i;
1248
+ sum_ti_sq += t_i * t_i;
1249
+ }
1250
+ s_hat = sum_ti_bi / sum_ti_sq;
1251
+
1252
+ // Compute k from the estimated s_hat and target surprise value
1253
+ float epsilon_hat = s_hat - 1;
1254
+ float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
1255
+
1256
+ llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
1257
+ llama_sampler_softmax_impl(cur_p);
1258
+
1259
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1260
+
1261
+ cur_p->selected = idx;
1262
+
1263
+ float observed_surprise = -log2f(cur_p->data[idx].p);
1264
+ float e = observed_surprise - ctx->tau;
1265
+
1266
+ // Update mu using the learning rate and error
1267
+ ctx->mu = ctx->mu - ctx->eta * e;
1268
+ }
1269
+
1270
+ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
1271
+ const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
1272
+ auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
1273
+
1274
+ // copy the state
1275
+ {
1276
+ auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
1277
+
1278
+ result_ctx->mu = ctx->mu;
1279
+ result_ctx->rng = ctx->rng;
1280
+ }
1281
+
1282
+ return result;
1283
+ }
1284
+
1285
+ static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
1286
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1287
+ ctx->mu = 2.0f*ctx->tau;
1288
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1289
+ ctx->rng.seed(ctx->seed_cur);
1290
+ }
1291
+
1292
+ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1293
+ delete (llama_sampler_mirostat *) smpl->ctx;
1294
+ }
1295
+
1296
+ static struct llama_sampler_i llama_sampler_mirostat_i = {
1297
+ /* .name = */ llama_sampler_mirostat_name,
1298
+ /* .accept = */ nullptr,
1299
+ /* .apply = */ llama_sampler_mirostat_apply,
1300
+ /* .reset = */ llama_sampler_mirostat_reset,
1301
+ /* .clone = */ llama_sampler_mirostat_clone,
1302
+ /* .free = */ llama_sampler_mirostat_free,
1303
+ };
1304
+
1305
+ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1306
+ auto seed_cur = get_rng_seed(seed);
1307
+ return new llama_sampler {
1308
+ /* .iface = */ &llama_sampler_mirostat_i,
1309
+ /* .ctx = */ new llama_sampler_mirostat {
1310
+ /* .n_vocab = */ n_vocab,
1311
+ /* .seed = */ seed,
1312
+ /* .seed_cur = */ seed_cur,
1313
+ /* .tau = */ tau,
1314
+ /* .eta = */ eta,
1315
+ /* .m = */ m,
1316
+ /* .mu = */ 2.0f*tau,
1317
+ /* .rng = */ std::mt19937(seed_cur),
1318
+ },
1319
+ };
1320
+ }
1321
+
1322
+ // mirostat v2
1323
+
1324
+ struct llama_sampler_mirostat_v2 {
1325
+ const uint32_t seed;
1326
+ uint32_t seed_cur;
1327
+
1328
+ const float tau;
1329
+ const float eta;
1330
+
1331
+ float mu;
1332
+
1333
+ std::mt19937 rng;
1334
+ };
1335
+
1336
+ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
1337
+ return "mirostat-v2";
1338
+ }
1339
+
1340
+ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1341
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1342
+
1343
+ llama_sampler_softmax_impl(cur_p);
1344
+
1345
+ // Truncate the words with surprise values greater than mu
1346
+ cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
1347
+ return -log2f(candidate.p) > ctx->mu;
1348
+ }));
1349
+
1350
+ if (cur_p->size == 0) {
1351
+ cur_p->size = 1;
1352
+ }
488
1353
 
489
- // Apply the dynamically calculated temperature scaling
490
- for (size_t i = 0; i < candidates->size; ++i) {
491
- candidates->data[i].logit /= dyn_temp;
1354
+ // Normalize the probabilities of the remaining words
1355
+ llama_sampler_softmax_impl(cur_p);
1356
+
1357
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1358
+
1359
+ cur_p->selected = idx;
1360
+
1361
+ float observed_surprise = -log2f(cur_p->data[idx].p);
1362
+ float e = observed_surprise - ctx->tau;
1363
+
1364
+ // Update mu using the learning rate and error
1365
+ ctx->mu = ctx->mu - ctx->eta * e;
1366
+ }
1367
+
1368
+ static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
1369
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1370
+ ctx->mu = 2.0f*ctx->tau;
1371
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1372
+ ctx->rng.seed(ctx->seed_cur);
1373
+ }
1374
+
1375
+ static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
1376
+ const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
1377
+
1378
+ auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
1379
+
1380
+ // copy the state
1381
+ {
1382
+ auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
1383
+
1384
+ result_ctx->mu = ctx->mu;
1385
+ result_ctx->rng = ctx->rng;
492
1386
  }
493
1387
 
494
- // Re-compute softmax probabilities after scaling logits with dynamic temperature
495
- double max_l_double = candidates->data[0].logit;
496
- double cum_sum_double = 0.0;
497
- for (size_t i = 0; i < candidates->size; ++i) {
498
- double p = exp(candidates->data[i].logit - max_l_double);
499
- candidates->data[i].p = p; // Store the scaled probability
500
- cum_sum_double += p;
1388
+ return result;
1389
+ }
1390
+
1391
+ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1392
+ delete (llama_sampler_mirostat_v2 *) smpl->ctx;
1393
+ }
1394
+
1395
+ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1396
+ /* .name = */ llama_sampler_mirostat_v2_name,
1397
+ /* .accept = */ nullptr,
1398
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
1399
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
1400
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
1401
+ /* .free = */ llama_sampler_mirostat_v2_free,
1402
+ };
1403
+
1404
+ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1405
+ auto seed_cur = get_rng_seed(seed);
1406
+ return new llama_sampler {
1407
+ /* .iface = */ &llama_sampler_mirostat_v2_i,
1408
+ /* .ctx = */ new llama_sampler_mirostat_v2 {
1409
+ /* .seed = */ seed,
1410
+ /* .seed_cur = */ seed_cur,
1411
+ /* .tau = */ tau,
1412
+ /* .eta = */ eta,
1413
+ /* .mu = */ 2.0f*tau,
1414
+ /* .rng = */ std::mt19937(seed_cur),
1415
+ },
1416
+ };
1417
+ }
1418
+
1419
+ // grammar
1420
+
1421
+ struct llama_sampler_grammar {
1422
+ const struct llama_vocab * vocab;
1423
+
1424
+ std::string grammar_str;
1425
+ std::string grammar_root;
1426
+
1427
+ struct llama_grammar * grammar;
1428
+ };
1429
+
1430
+ static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
1431
+ return "grammar";
1432
+ }
1433
+
1434
+ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
1435
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1436
+ if (ctx->grammar) {
1437
+ llama_grammar_accept_impl(*ctx->grammar, token);
501
1438
  }
502
- for (size_t i = 0; i < candidates->size; ++i) {
503
- candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
1439
+ }
1440
+
1441
+ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1442
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1443
+ if (ctx->grammar) {
1444
+ llama_grammar_apply_impl(*ctx->grammar, cur_p);
504
1445
  }
1446
+ }
505
1447
 
506
- #ifdef DEBUG
507
- // Print the updated top 25 probabilities after temperature scaling
508
- LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
509
- for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
510
- LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
1448
+ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
1449
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1450
+ if (!ctx->grammar) {
1451
+ return;
511
1452
  }
512
- #endif
513
1453
 
514
- if (smpl) {
515
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
1454
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
1455
+
1456
+ llama_grammar_free_impl(ctx->grammar);
1457
+ ctx->grammar = grammar_new;
1458
+ }
1459
+
1460
+ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
1461
+ const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
1462
+
1463
+ auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
1464
+
1465
+ // copy the state
1466
+ {
1467
+ auto * result_ctx = (llama_sampler_grammar *) result->ctx;
1468
+
1469
+ if (ctx->grammar) {
1470
+ result_ctx->grammar_str = ctx->grammar_str;
1471
+ result_ctx->grammar_root = ctx->grammar_root;
1472
+
1473
+ result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
1474
+ }
516
1475
  }
1476
+
1477
+ return result;
517
1478
  }
518
1479
 
519
- void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
520
- const int64_t t_start_sample_us = lm_ggml_time_us();
1480
+ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1481
+ const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
521
1482
 
522
- for (size_t i = 0; i < candidates->size; ++i) {
523
- candidates->data[i].logit /= temp;
1483
+ if (ctx->grammar) {
1484
+ llama_grammar_free_impl(ctx->grammar);
524
1485
  }
525
1486
 
526
- if (smpl) {
527
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
1487
+ delete ctx;
1488
+ }
1489
+
1490
+ static struct llama_sampler_i llama_sampler_grammar_i = {
1491
+ /* .name = */ llama_sampler_grammar_name,
1492
+ /* .accept = */ llama_sampler_grammar_accept_impl,
1493
+ /* .apply = */ llama_sampler_grammar_apply,
1494
+ /* .reset = */ llama_sampler_grammar_reset,
1495
+ /* .clone = */ llama_sampler_grammar_clone,
1496
+ /* .free = */ llama_sampler_grammar_free,
1497
+ };
1498
+
1499
+ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
1500
+ auto * ctx = new llama_sampler_grammar;
1501
+
1502
+ if (grammar_str != nullptr && grammar_str[0] != '\0') {
1503
+ *ctx = {
1504
+ /* .vocab = */ &vocab,
1505
+ /* .grammar_str = */ grammar_str,
1506
+ /* .grammar_root = */ grammar_root,
1507
+ /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
1508
+ };
1509
+ } else {
1510
+ *ctx = {
1511
+ /* .vocab = */ &vocab,
1512
+ /* .grammar_str = */ {},
1513
+ /* .grammar_root = */ {},
1514
+ /* .grammar = */ nullptr,
1515
+ };
528
1516
  }
1517
+
1518
+ return new llama_sampler {
1519
+ /* .iface = */ &llama_sampler_grammar_i,
1520
+ /* .ctx = */ ctx,
1521
+ };
529
1522
  }
530
1523
 
531
- void llama_sample_repetition_penalties_impl(
532
- struct llama_sampling * smpl,
533
- llama_token_data_array * candidates,
534
- const llama_token * last_tokens,
535
- size_t penalty_last_n,
536
- float penalty_repeat,
537
- float penalty_freq,
538
- float penalty_present) {
539
- if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
1524
+ // penalties
1525
+
1526
+ struct llama_sampler_penalties {
1527
+ const int32_t n_vocab;
1528
+ const llama_token special_eos_id;
1529
+ const llama_token linefeed_id;
1530
+
1531
+ const int32_t penalty_last_n;
1532
+ const float penalty_repeat;
1533
+ const float penalty_freq;
1534
+ const float penalty_present;
1535
+
1536
+ const bool penalize_nl;
1537
+ const bool ignore_eos;
1538
+
1539
+ ring_buffer<llama_token> prev;
1540
+ };
1541
+
1542
+ static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
1543
+ return "penalties";
1544
+ }
1545
+
1546
+ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
1547
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1548
+ if (ctx->penalty_last_n == 0) {
540
1549
  return;
541
1550
  }
542
1551
 
543
- const int64_t t_start_sample_us = lm_ggml_time_us();
1552
+ ctx->prev.push_back(token);
1553
+ }
1554
+
1555
+ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1556
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1557
+
1558
+ if (ctx->ignore_eos) {
1559
+ assert(ctx->special_eos_id >= 0);
1560
+
1561
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
1562
+ if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
1563
+ cur_p->data[ctx->special_eos_id].logit = -INFINITY;
1564
+ } else {
1565
+ // else, search for the special EOS token
1566
+ for (size_t i = 0; i < cur_p->size; ++i) {
1567
+ if (cur_p->data[i].id == ctx->special_eos_id) {
1568
+ cur_p->data[i].logit = -INFINITY;
1569
+ break;
1570
+ }
1571
+ }
1572
+ }
1573
+ }
1574
+
1575
+ if ((ctx->penalty_last_n == 0) ||
1576
+ (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1577
+ return;
1578
+ }
1579
+
1580
+ bool nl_found = false;
1581
+ size_t nl_idx = 0;
1582
+ float nl_logit = -INFINITY;
1583
+ if (!ctx->penalize_nl) {
1584
+ assert(ctx->linefeed_id >= 0);
1585
+
1586
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
1587
+ if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
1588
+ nl_found = true;
1589
+ nl_idx = ctx->linefeed_id;
1590
+ nl_logit = cur_p->data[ctx->linefeed_id].logit;
1591
+ } else {
1592
+ // else, search for the linefeed token
1593
+ for (size_t i = 0; i < cur_p->size; ++i) {
1594
+ if (cur_p->data[i].id == ctx->linefeed_id) {
1595
+ nl_found = true;
1596
+ nl_idx = i;
1597
+ nl_logit = cur_p->data[i].logit;
1598
+ break;
1599
+ }
1600
+ }
1601
+ }
1602
+ }
544
1603
 
545
1604
  // Create a frequency map to count occurrences of each token in last_tokens
546
- std::unordered_map<llama_token, int> token_count;
547
- for (size_t i = 0; i < penalty_last_n; ++i) {
548
- token_count[last_tokens[i]]++;
1605
+ // TODO: optimize this by maintaining the token count in the sampler context
1606
+ using llama_token_cnt = std::unordered_map<llama_token, int>;
1607
+ llama_token_cnt token_count;
1608
+
1609
+ for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1610
+ token_count[ctx->prev.rat(i)]++;
549
1611
  }
550
1612
 
551
- // Apply frequency and presence penalties to the candidates
552
- for (size_t i = 0; i < candidates->size; ++i) {
553
- const auto token_iter = token_count.find(candidates->data[i].id);
1613
+ // Apply frequency and presence penalties to the cur_p
1614
+ for (size_t i = 0; i < cur_p->size; ++i) {
1615
+ const auto token_iter = token_count.find(cur_p->data[i].id);
554
1616
  if (token_iter == token_count.end()) {
555
1617
  continue;
556
1618
  }
@@ -559,171 +1621,238 @@ void llama_sample_repetition_penalties_impl(
559
1621
 
560
1622
  // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
561
1623
  // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
562
- if (candidates->data[i].logit <= 0) {
563
- candidates->data[i].logit *= penalty_repeat;
1624
+ if (cur_p->data[i].logit <= 0) {
1625
+ cur_p->data[i].logit *= ctx->penalty_repeat;
564
1626
  } else {
565
- candidates->data[i].logit /= penalty_repeat;
1627
+ cur_p->data[i].logit /= ctx->penalty_repeat;
566
1628
  }
567
1629
 
568
- candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
1630
+ cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
569
1631
  }
570
1632
 
571
- candidates->sorted = false;
1633
+ cur_p->sorted = false;
572
1634
 
573
- if (smpl) {
574
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
1635
+ if (!ctx->penalize_nl && nl_found) {
1636
+ // restore the logit of the newline token if it was penalized
1637
+ cur_p->data[nl_idx].logit = nl_logit;
575
1638
  }
576
1639
  }
577
1640
 
578
- void llama_sample_apply_guidance_impl(
579
- struct llama_sampling * smpl,
580
- float * logits,
581
- float * logits_guidance,
582
- float scale) {
583
- LM_GGML_ASSERT(smpl);
584
-
585
- const auto t_start_sample_us = lm_ggml_time_us();
586
- const auto n_vocab = smpl->n_vocab;
587
-
588
- llama_log_softmax(logits, n_vocab);
589
- llama_log_softmax(logits_guidance, n_vocab);
1641
+ static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
1642
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1643
+ ctx->prev.clear();
1644
+ }
590
1645
 
591
- for (int i = 0; i < n_vocab; ++i) {
592
- auto & l = logits[i];
593
- const auto & g = logits_guidance[i];
1646
+ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
1647
+ const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
1648
+ auto * result = llama_sampler_init_penalties(
1649
+ ctx->n_vocab,
1650
+ ctx->special_eos_id,
1651
+ ctx->linefeed_id,
1652
+ ctx->penalty_last_n,
1653
+ ctx->penalty_repeat,
1654
+ ctx->penalty_freq,
1655
+ ctx->penalty_present,
1656
+ ctx->penalize_nl,
1657
+ ctx->ignore_eos);
1658
+
1659
+ // copy the state
1660
+ {
1661
+ auto * result_ctx = (llama_sampler_penalties *) result->ctx;
594
1662
 
595
- l = scale * (l - g) + g;
1663
+ result_ctx->prev = ctx->prev;
596
1664
  }
597
1665
 
598
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
1666
+ return result;
599
1667
  }
600
1668
 
601
- llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
602
- LM_GGML_ASSERT(smpl);
603
-
604
- const int32_t n_vocab = float(smpl->n_vocab);
605
-
606
- int64_t t_start_sample_us = lm_ggml_time_us();
1669
+ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1670
+ delete (llama_sampler_penalties *) smpl->ctx;
1671
+ }
607
1672
 
608
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
1673
+ static struct llama_sampler_i llama_sampler_penalties_i = {
1674
+ /* .name = */ llama_sampler_penalties_name,
1675
+ /* .accept = */ llama_sampler_penalties_accept,
1676
+ /* .apply = */ llama_sampler_penalties_apply,
1677
+ /* .reset = */ llama_sampler_penalties_reset,
1678
+ /* .clone = */ llama_sampler_penalties_clone,
1679
+ /* .free = */ llama_sampler_penalties_free,
1680
+ };
1681
+
1682
+ struct llama_sampler * llama_sampler_init_penalties(
1683
+ int32_t n_vocab,
1684
+ llama_token special_eos_id,
1685
+ llama_token linefeed_id,
1686
+ int32_t penalty_last_n,
1687
+ float penalty_repeat,
1688
+ float penalty_freq,
1689
+ float penalty_present,
1690
+ bool penalize_nl,
1691
+ bool ignore_eos) {
1692
+ if (linefeed_id == LLAMA_TOKEN_NULL) {
1693
+ penalize_nl = true;
1694
+ }
609
1695
 
610
- // Estimate s_hat using the most probable m tokens
611
- float s_hat = 0.0;
612
- float sum_ti_bi = 0.0;
613
- float sum_ti_sq = 0.0;
614
- for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
615
- float t_i = logf(float(i + 2) / float(i + 1));
616
- float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
617
- sum_ti_bi += t_i * b_i;
618
- sum_ti_sq += t_i * t_i;
1696
+ if (special_eos_id == LLAMA_TOKEN_NULL) {
1697
+ ignore_eos = false;
619
1698
  }
620
- s_hat = sum_ti_bi / sum_ti_sq;
621
1699
 
622
- // Compute k from the estimated s_hat and target surprise value
623
- float epsilon_hat = s_hat - 1;
624
- float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
1700
+ penalty_last_n = std::max(penalty_last_n, 0);
1701
+
1702
+ return new llama_sampler {
1703
+ /* .iface = */ &llama_sampler_penalties_i,
1704
+ /* .ctx = */ new llama_sampler_penalties {
1705
+ /* .n_vocab = */ n_vocab,
1706
+ /* .special_eos_id = */ special_eos_id,
1707
+ /* .linefeed_id = */ linefeed_id,
1708
+ /* .penalty_last_n = */ penalty_last_n,
1709
+ /* .penalty_repeat = */ penalty_repeat,
1710
+ /* .penalty_freq = */ penalty_freq,
1711
+ /* .penalty_present = */ penalty_present,
1712
+ /* .penalize_nl = */ penalize_nl,
1713
+ /* .ignore_eos = */ ignore_eos,
1714
+ /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1715
+ },
1716
+ };
1717
+ }
625
1718
 
626
- // Sample the next word X using top-k sampling
627
- llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
628
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
629
- llama_token X = llama_sample_token_impl(smpl, candidates);
630
- t_start_sample_us = lm_ggml_time_us();
1719
+ // logit-bias
631
1720
 
632
- // Compute error as the difference between observed surprise and target surprise value
633
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
634
- return candidate.id == X;
635
- }));
636
- float observed_surprise = -log2f(candidates->data[X_idx].p);
637
- float e = observed_surprise - tau;
1721
+ struct llama_sampler_logit_bias {
1722
+ const int32_t n_vocab;
638
1723
 
639
- // Update mu using the learning rate and error
640
- *mu = *mu - eta * e;
1724
+ const std::vector<llama_logit_bias> logit_bias;
1725
+
1726
+ std::vector<llama_logit_bias> to_search;
1727
+ };
641
1728
 
642
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
643
- return X;
1729
+ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
1730
+ return "logit-bias";
644
1731
  }
645
1732
 
646
- llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
647
- int64_t t_start_sample_us;
648
- t_start_sample_us = lm_ggml_time_us();
1733
+ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1734
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
649
1735
 
650
- llama_sample_softmax_impl(smpl, candidates);
1736
+ if (ctx->logit_bias.empty()) {
1737
+ return;
1738
+ }
651
1739
 
652
- // Truncate the words with surprise values greater than mu
653
- candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
654
- return -log2f(candidate.p) > *mu;
655
- }));
1740
+ ctx->to_search.clear();
656
1741
 
657
- if (candidates->size == 0) {
658
- candidates->size = 1;
1742
+ // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
1743
+ for (const auto & lb : ctx->logit_bias) {
1744
+ if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
1745
+ cur_p->data[lb.token].logit += lb.bias;
1746
+ } else {
1747
+ ctx->to_search.push_back(lb);
1748
+ }
659
1749
  }
660
1750
 
661
- if (smpl) {
662
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
1751
+ if (ctx->to_search.empty()) {
1752
+ return;
663
1753
  }
664
1754
 
665
- // Normalize the probabilities of the remaining words
666
- llama_sample_softmax_impl(smpl, candidates);
1755
+ // search for the remaining candidates that were not found in the previous step
1756
+ for (size_t i = 0; i < cur_p->size; ++i) {
1757
+ for (const auto & lb : ctx->to_search) {
1758
+ if (cur_p->data[i].id == lb.token) {
1759
+ cur_p->data[i].logit += lb.bias;
1760
+ break;
1761
+ }
1762
+ }
1763
+ }
1764
+ }
667
1765
 
668
- // Sample the next word X from the remaining words
669
- llama_token X = llama_sample_token_impl(smpl, candidates);
670
- t_start_sample_us = lm_ggml_time_us();
1766
+ static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
1767
+ const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
1768
+ return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
1769
+ }
671
1770
 
672
- // Compute error as the difference between observed surprise and target surprise value
673
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
674
- return candidate.id == X;
675
- }));
676
- float observed_surprise = -log2f(candidates->data[X_idx].p);
677
- float e = observed_surprise - tau;
1771
+ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
1772
+ delete (llama_sampler_logit_bias *) smpl->ctx;
1773
+ }
678
1774
 
679
- // Update mu using the learning rate and error
680
- *mu = *mu - eta * e;
1775
+ static struct llama_sampler_i llama_sampler_logit_bias_i = {
1776
+ /* .name = */ llama_sampler_logit_bias_name,
1777
+ /* .accept = */ nullptr,
1778
+ /* .apply = */ llama_sampler_logit_bias_apply,
1779
+ /* .reset = */ nullptr,
1780
+ /* .clone = */ llama_sampler_logit_bias_clone,
1781
+ /* .free = */ llama_sampler_logit_bias_free,
1782
+ };
1783
+
1784
+ struct llama_sampler * llama_sampler_init_logit_bias(
1785
+ int32_t n_vocab,
1786
+ int32_t n_logit_bias,
1787
+ const llama_logit_bias * logit_bias) {
1788
+ return new llama_sampler {
1789
+ /* .iface = */ &llama_sampler_logit_bias_i,
1790
+ /* .ctx = */ new llama_sampler_logit_bias {
1791
+ /* .n_vocab = */ n_vocab,
1792
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
1793
+ /* .to_search = */ {},
1794
+ },
1795
+ };
1796
+ }
1797
+
1798
+ // utils
681
1799
 
682
- if (smpl) {
683
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
1800
+ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
1801
+ if (smpl->iface == &llama_sampler_dist_i) {
1802
+ return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
684
1803
  }
685
- return X;
686
- }
687
1804
 
688
- llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
689
- const int64_t t_start_sample_us = lm_ggml_time_us();
1805
+ if (smpl->iface == &llama_sampler_mirostat_i) {
1806
+ return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
1807
+ }
690
1808
 
691
- // Find max element
692
- auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
693
- return a.logit < b.logit;
694
- });
1809
+ if (smpl->iface == &llama_sampler_mirostat_v2_i) {
1810
+ return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
1811
+ }
695
1812
 
696
- llama_token result = max_iter->id;
697
- if (smpl) {
698
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
699
- smpl->n_sample++;
1813
+ if (smpl->iface == &llama_sampler_chain_i) {
1814
+ const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
1815
+ for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
1816
+ const uint32_t seed = llama_sampler_get_seed(*it);
1817
+ if (seed != LLAMA_DEFAULT_SEED) {
1818
+ return seed;
1819
+ }
1820
+ }
700
1821
  }
701
- return result;
1822
+
1823
+ return LLAMA_DEFAULT_SEED;
702
1824
  }
703
1825
 
704
- llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
705
- LM_GGML_ASSERT(smpl);
1826
+ // perf
706
1827
 
707
- const int64_t t_start_sample_us = lm_ggml_time_us();
708
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
1828
+ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
1829
+ struct llama_perf_sampler_data data = {};
709
1830
 
710
- std::vector<float> probs;
711
- probs.reserve(candidates->size);
712
- for (size_t i = 0; i < candidates->size; ++i) {
713
- probs.push_back(candidates->data[i].p);
1831
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
1832
+ LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
714
1833
  }
715
1834
 
716
- std::discrete_distribution<> dist(probs.begin(), probs.end());
717
- int idx = dist(rng);
1835
+ const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
718
1836
 
719
- llama_token result = candidates->data[idx].id;
1837
+ data.t_sample_ms = 1e-3 * ctx->t_sample_us;
1838
+ data.n_sample = std::max(0, ctx->n_sample);
720
1839
 
721
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
722
- smpl->n_sample++;
1840
+ return data;
1841
+ }
723
1842
 
724
- return result;
1843
+ void llama_perf_sampler_print(const struct llama_sampler * chain) {
1844
+ const auto data = llama_perf_sampler(chain);
1845
+
1846
+ LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
1847
+ __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
725
1848
  }
726
1849
 
727
- llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
728
- return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
1850
+ void llama_perf_sampler_reset(struct llama_sampler * chain) {
1851
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
1852
+ LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
1853
+ }
1854
+
1855
+ auto * ctx = (struct llama_sampler_chain *) chain->ctx;
1856
+
1857
+ ctx->t_sample_us = ctx->n_sample = 0;
729
1858
  }