cui-llama.rn 1.1.4 → 1.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,49 +8,45 @@
8
8
  #include <cstring>
9
9
  #include <ctime>
10
10
  #include <cfloat>
11
+ #include <chrono>
12
+ #include <cmath>
11
13
  #include <numeric>
12
14
  #include <random>
13
15
  #include <unordered_map>
14
16
 
15
- static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
16
- #if 1
17
- probs.resize(cur_p->size);
18
- for (size_t i = 0; i < cur_p->size; ++i) {
19
- probs[i] = cur_p->data[i].p;
20
- }
21
-
22
- std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
23
- #else
24
- // avoid the copy with a custom iterator
17
+ static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
18
+ // iterator for the probabilities
19
+ #ifdef __GNUC__
25
20
  #pragma GCC diagnostic push
26
21
  #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
22
+ #endif
27
23
 
28
24
  struct probs_iterator {
29
25
  typedef std::input_iterator_tag iterator_category;
30
26
  typedef float value_type;
31
27
  typedef float * pointer;
32
28
  typedef float & reference;
33
- typedef size_t difference_type;
29
+ typedef ptrdiff_t difference_type;
34
30
 
35
- const llama_token_data_array * data;
36
- size_t i;
31
+ const llama_token_data * data;
37
32
 
38
- bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; }
39
- bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; }
40
- float operator*() const { return data->data[i].p; }
41
- probs_iterator & operator++() { ++i; return *this; }
42
- probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; }
33
+ bool operator==(const probs_iterator & other) const { return data == other.data; }
34
+ bool operator!=(const probs_iterator & other) const { return data != other.data; }
35
+ const float & operator*() const { return data->p; }
36
+ probs_iterator & operator++() { ++data; return *this; }
37
+ probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
43
38
  };
44
- #pragma GCC diagnostic pop
45
-
46
- std::discrete_distribution<size_t> dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size});
47
39
 
48
- LM_GGML_UNUSED(probs);
40
+ #ifdef __GNUC__
41
+ #pragma GCC diagnostic pop
49
42
  #endif
50
43
 
44
+ std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
45
+
51
46
  return dist(rng);
52
47
  }
53
48
 
49
+ /*
54
50
  static void llama_log_softmax(float * array, size_t size) {
55
51
  float max_l = *std::max_element(array, array + size);
56
52
  float sum = 0.f;
@@ -64,6 +60,7 @@ static void llama_log_softmax(float * array, size_t size) {
64
60
  array[i] = logf(array[i] / sum);
65
61
  }
66
62
  }
63
+ */
67
64
 
68
65
  static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
69
66
  LM_GGML_ASSERT(cur_p->size > 0);
@@ -166,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
166
163
  cur_p->size = k;
167
164
  }
168
165
 
166
+ static uint32_t get_rng_seed(uint32_t seed) {
167
+ if (seed == LLAMA_DEFAULT_SEED) {
168
+ // use system clock if std::random_device is not a true RNG
169
+ static bool is_rd_prng = std::random_device().entropy() == 0;
170
+ if (is_rd_prng) {
171
+ return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
172
+ }
173
+ std::random_device rd;
174
+ return rd();
175
+ }
176
+ return seed;
177
+ }
178
+
169
179
  // llama_sampler API
170
180
 
171
181
  const char * llama_sampler_name(const struct llama_sampler * smpl) {
@@ -231,67 +241,92 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
231
241
  cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
232
242
  }
233
243
 
234
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
244
+ llama_token_data_array cur_p = {
245
+ /* .data = */ cur.data(),
246
+ /* .size = */ cur.size(),
247
+ /* .selected = */ -1,
248
+ /* .sorted = */ false,
249
+ };
235
250
 
236
251
  llama_sampler_apply(smpl, &cur_p);
237
252
 
238
- return cur_p.data[cur_p.selected].id;
253
+ LM_GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
254
+
255
+ auto token = cur_p.data[cur_p.selected].id;
256
+
257
+ llama_sampler_accept(smpl, token);
258
+
259
+ return token;
239
260
  }
240
261
 
241
262
  // sampler chain
242
263
 
243
- static struct llama_sampler_i llama_sampler_chain_i = {
244
- /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
245
- /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
246
- auto * chain = (llama_sampler_chain *) smpl->ctx;
264
+ static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
265
+ return "chain";
266
+ }
247
267
 
248
- time_meas tm(chain->t_sample_us, chain->params.no_perf);
268
+ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
269
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
249
270
 
250
- for (auto * smpl : chain->samplers) {
251
- llama_sampler_accept(smpl, token);
252
- }
271
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
253
272
 
254
- chain->n_sample++;
255
- },
256
- /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
257
- auto * chain = (llama_sampler_chain *) smpl->ctx;
273
+ for (auto * smpl : chain->samplers) {
274
+ llama_sampler_accept(smpl, token);
275
+ }
258
276
 
259
- time_meas tm(chain->t_sample_us, chain->params.no_perf);
277
+ chain->n_sample++;
278
+ }
260
279
 
261
- for (auto * smpl : chain->samplers) {
262
- llama_sampler_apply(smpl, cur_p);
263
- }
264
- },
265
- /* .reset = */ [](struct llama_sampler * smpl) {
266
- auto * chain = (llama_sampler_chain *) smpl->ctx;
280
+ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
281
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
267
282
 
268
- for (auto * smpl : chain->samplers) {
269
- llama_sampler_reset(smpl);
270
- }
283
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
271
284
 
272
- chain->t_sample_us = 0;
273
- chain->n_sample = 0;
274
- },
275
- /* .clone = */ [](const struct llama_sampler * smpl) {
276
- const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
285
+ for (auto * smpl : chain->samplers) {
286
+ llama_sampler_apply(smpl, cur_p);
287
+ }
288
+ }
277
289
 
278
- auto * result = llama_sampler_chain_init(chain_src->params);
290
+ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
291
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
279
292
 
280
- for (auto * smpl : chain_src->samplers) {
281
- llama_sampler_chain_add(result, llama_sampler_clone(smpl));
282
- }
293
+ for (auto * smpl : chain->samplers) {
294
+ llama_sampler_reset(smpl);
295
+ }
283
296
 
284
- return result;
285
- },
286
- /* .free = */ [](struct llama_sampler * smpl) {
287
- auto * chain = (llama_sampler_chain *) smpl->ctx;
297
+ chain->t_sample_us = 0;
298
+ chain->n_sample = 0;
299
+ }
288
300
 
289
- for (auto * smpl : chain->samplers) {
290
- llama_sampler_free(smpl);
291
- }
301
+ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
302
+ const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
303
+
304
+ auto * result = llama_sampler_chain_init(chain_src->params);
305
+
306
+ for (auto * smpl : chain_src->samplers) {
307
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl));
308
+ }
309
+
310
+ return result;
311
+ }
312
+
313
+ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
314
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
315
+
316
+ for (auto * smpl : chain->samplers) {
317
+ llama_sampler_free(smpl);
318
+ }
319
+
320
+ delete chain;
321
+ }
292
322
 
293
- delete chain;
294
- },
323
+ static struct llama_sampler_i llama_sampler_chain_i = {
324
+ /* .name = */ llama_sampler_chain_name,
325
+ /* .accept = */ llama_sampler_chain_accept,
326
+ /* .apply = */ llama_sampler_chain_apply,
327
+ /* .reset = */ llama_sampler_chain_reset,
328
+ /* .clone = */ llama_sampler_chain_clone,
329
+ /* .free = */ llama_sampler_chain_free,
295
330
  };
296
331
 
297
332
  struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
@@ -323,13 +358,26 @@ llama_sampler_timings llama_sampler_chain_timings(struct llama_sampler * chain)
323
358
  struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
324
359
  const auto * p = (const llama_sampler_chain *) chain->ctx;
325
360
 
326
- if (i < 0 || i >= (int32_t) p->samplers.size()) {
361
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
327
362
  return nullptr;
328
363
  }
329
364
 
330
365
  return p->samplers[i];
331
366
  }
332
367
 
368
+ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
369
+ auto * p = (llama_sampler_chain *) chain->ctx;
370
+
371
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
372
+ return nullptr;
373
+ }
374
+
375
+ auto * result = p->samplers[i];
376
+ p->samplers.erase(p->samplers.begin() + i);
377
+
378
+ return result;
379
+ }
380
+
333
381
  int llama_sampler_chain_n(const struct llama_sampler * chain) {
334
382
  const auto * p = (const llama_sampler_chain *) chain->ctx;
335
383
 
@@ -375,10 +423,9 @@ struct llama_sampler * llama_sampler_init_greedy() {
375
423
 
376
424
  struct llama_sampler_dist {
377
425
  const uint32_t seed;
426
+ uint32_t seed_cur;
378
427
 
379
428
  std::mt19937 rng;
380
-
381
- std::vector<float> probs; // work array
382
429
  };
383
430
 
384
431
  static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
@@ -387,7 +434,7 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
387
434
 
388
435
  static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
389
436
  auto * ctx = (llama_sampler_dist *) smpl->ctx;
390
- cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
437
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
391
438
  }
392
439
 
393
440
  static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
@@ -406,7 +453,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
406
453
 
407
454
  static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
408
455
  auto * ctx = (llama_sampler_dist *) smpl->ctx;
409
- ctx->rng = std::mt19937(ctx->seed);
456
+ ctx->seed_cur = get_rng_seed(ctx->seed);
457
+ ctx->rng.seed(ctx->seed_cur);
410
458
  }
411
459
 
412
460
  static void llama_sampler_dist_free(struct llama_sampler * smpl) {
@@ -423,12 +471,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
423
471
  };
424
472
 
425
473
  struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
474
+ auto seed_cur = get_rng_seed(seed);
426
475
  return new llama_sampler {
427
476
  /* .iface = */ &llama_sampler_dist_i,
428
477
  /* .ctx = */ new llama_sampler_dist {
429
- /* .seed = */ seed,
430
- /* .rng = */ std::mt19937(seed),
431
- /* .probs = */ {},
478
+ /* .seed = */ seed,
479
+ /* .seed_cur = */ seed_cur,
480
+ /* .rng = */ std::mt19937(seed_cur),
432
481
  },
433
482
  };
434
483
  }
@@ -1167,6 +1216,7 @@ struct llama_sampler_mirostat {
1167
1216
  const int32_t n_vocab;
1168
1217
 
1169
1218
  const uint32_t seed;
1219
+ uint32_t seed_cur;
1170
1220
 
1171
1221
  const float tau;
1172
1222
  const float eta;
@@ -1176,8 +1226,6 @@ struct llama_sampler_mirostat {
1176
1226
  float mu;
1177
1227
 
1178
1228
  std::mt19937 rng;
1179
-
1180
- std::vector<float> probs;
1181
1229
  };
1182
1230
 
1183
1231
  static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
@@ -1208,7 +1256,7 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke
1208
1256
  llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
1209
1257
  llama_sampler_softmax_impl(cur_p);
1210
1258
 
1211
- const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
1259
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1212
1260
 
1213
1261
  cur_p->selected = idx;
1214
1262
 
@@ -1237,7 +1285,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
1237
1285
  static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
1238
1286
  auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1239
1287
  ctx->mu = 2.0f*ctx->tau;
1240
- ctx->rng = std::mt19937(ctx->seed);
1288
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1289
+ ctx->rng.seed(ctx->seed_cur);
1241
1290
  }
1242
1291
 
1243
1292
  static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
@@ -1254,17 +1303,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
1254
1303
  };
1255
1304
 
1256
1305
  struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1306
+ auto seed_cur = get_rng_seed(seed);
1257
1307
  return new llama_sampler {
1258
1308
  /* .iface = */ &llama_sampler_mirostat_i,
1259
1309
  /* .ctx = */ new llama_sampler_mirostat {
1260
- /* .n_vocab = */ n_vocab,
1261
- /* .seed = */ seed,
1262
- /* .tau = */ tau,
1263
- /* .eta = */ eta,
1264
- /* .m = */ m,
1265
- /* .mu = */ 2.0f*tau,
1266
- /* .rng = */ std::mt19937(seed),
1267
- /* .probs = */ {},
1310
+ /* .n_vocab = */ n_vocab,
1311
+ /* .seed = */ seed,
1312
+ /* .seed_cur = */ seed_cur,
1313
+ /* .tau = */ tau,
1314
+ /* .eta = */ eta,
1315
+ /* .m = */ m,
1316
+ /* .mu = */ 2.0f*tau,
1317
+ /* .rng = */ std::mt19937(seed_cur),
1268
1318
  },
1269
1319
  };
1270
1320
  }
@@ -1273,6 +1323,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
1273
1323
 
1274
1324
  struct llama_sampler_mirostat_v2 {
1275
1325
  const uint32_t seed;
1326
+ uint32_t seed_cur;
1276
1327
 
1277
1328
  const float tau;
1278
1329
  const float eta;
@@ -1280,8 +1331,6 @@ struct llama_sampler_mirostat_v2 {
1280
1331
  float mu;
1281
1332
 
1282
1333
  std::mt19937 rng;
1283
-
1284
- std::vector<float> probs;
1285
1334
  };
1286
1335
 
1287
1336
  static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
@@ -1305,7 +1354,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
1305
1354
  // Normalize the probabilities of the remaining words
1306
1355
  llama_sampler_softmax_impl(cur_p);
1307
1356
 
1308
- const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
1357
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1309
1358
 
1310
1359
  cur_p->selected = idx;
1311
1360
 
@@ -1319,7 +1368,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
1319
1368
  static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
1320
1369
  auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1321
1370
  ctx->mu = 2.0f*ctx->tau;
1322
- ctx->rng = std::mt19937(ctx->seed);
1371
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1372
+ ctx->rng.seed(ctx->seed_cur);
1323
1373
  }
1324
1374
 
1325
1375
  static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
@@ -1352,15 +1402,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1352
1402
  };
1353
1403
 
1354
1404
  struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1405
+ auto seed_cur = get_rng_seed(seed);
1355
1406
  return new llama_sampler {
1356
1407
  /* .iface = */ &llama_sampler_mirostat_v2_i,
1357
1408
  /* .ctx = */ new llama_sampler_mirostat_v2 {
1358
- /* .seed = */ seed,
1359
- /* .tau = */ tau,
1360
- /* .eta = */ eta,
1361
- /* .mu = */ 2.0f*tau,
1362
- /* .rng = */ std::mt19937(seed),
1363
- /* .probs = */ {},
1409
+ /* .seed = */ seed,
1410
+ /* .seed_cur = */ seed_cur,
1411
+ /* .tau = */ tau,
1412
+ /* .eta = */ eta,
1413
+ /* .mu = */ 2.0f*tau,
1414
+ /* .rng = */ std::mt19937(seed_cur),
1364
1415
  },
1365
1416
  };
1366
1417
  }
@@ -1646,6 +1697,8 @@ struct llama_sampler * llama_sampler_init_penalties(
1646
1697
  ignore_eos = false;
1647
1698
  }
1648
1699
 
1700
+ penalty_last_n = std::max(penalty_last_n, 0);
1701
+
1649
1702
  return new llama_sampler {
1650
1703
  /* .iface = */ &llama_sampler_penalties_i,
1651
1704
  /* .ctx = */ new llama_sampler_penalties {
@@ -1680,6 +1733,10 @@ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /
1680
1733
  static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1681
1734
  auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
1682
1735
 
1736
+ if (ctx->logit_bias.empty()) {
1737
+ return;
1738
+ }
1739
+
1683
1740
  ctx->to_search.clear();
1684
1741
 
1685
1742
  // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
@@ -1691,6 +1748,10 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
1691
1748
  }
1692
1749
  }
1693
1750
 
1751
+ if (ctx->to_search.empty()) {
1752
+ return;
1753
+ }
1754
+
1694
1755
  // search for the remaining candidates that were not found in the previous step
1695
1756
  for (size_t i = 0; i < cur_p->size; ++i) {
1696
1757
  for (const auto & lb : ctx->to_search) {
@@ -1701,6 +1762,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
1701
1762
  }
1702
1763
  }
1703
1764
  }
1765
+
1704
1766
  static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
1705
1767
  const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
1706
1768
  return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
@@ -1732,3 +1794,65 @@ struct llama_sampler * llama_sampler_init_logit_bias(
1732
1794
  },
1733
1795
  };
1734
1796
  }
1797
+
1798
+ // utils
1799
+
1800
+ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
1801
+ if (smpl->iface == &llama_sampler_dist_i) {
1802
+ return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
1803
+ }
1804
+
1805
+ if (smpl->iface == &llama_sampler_mirostat_i) {
1806
+ return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
1807
+ }
1808
+
1809
+ if (smpl->iface == &llama_sampler_mirostat_v2_i) {
1810
+ return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
1811
+ }
1812
+
1813
+ if (smpl->iface == &llama_sampler_chain_i) {
1814
+ const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
1815
+ for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
1816
+ const uint32_t seed = llama_sampler_get_seed(*it);
1817
+ if (seed != LLAMA_DEFAULT_SEED) {
1818
+ return seed;
1819
+ }
1820
+ }
1821
+ }
1822
+
1823
+ return LLAMA_DEFAULT_SEED;
1824
+ }
1825
+
1826
+ // perf
1827
+
1828
+ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
1829
+ struct llama_perf_sampler_data data = {};
1830
+
1831
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
1832
+ LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
1833
+ }
1834
+
1835
+ const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
1836
+
1837
+ data.t_sample_ms = 1e-3 * ctx->t_sample_us;
1838
+ data.n_sample = std::max(0, ctx->n_sample);
1839
+
1840
+ return data;
1841
+ }
1842
+
1843
+ void llama_perf_sampler_print(const struct llama_sampler * chain) {
1844
+ const auto data = llama_perf_sampler(chain);
1845
+
1846
+ LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
1847
+ __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
1848
+ }
1849
+
1850
+ void llama_perf_sampler_reset(struct llama_sampler * chain) {
1851
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
1852
+ LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
1853
+ }
1854
+
1855
+ auto * ctx = (struct llama_sampler_chain *) chain->ctx;
1856
+
1857
+ ctx->t_sample_us = ctx->n_sample = 0;
1858
+ }