cui-llama.rn 1.0.3 → 1.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. package/README.md +35 -39
  2. package/android/src/main/CMakeLists.txt +12 -2
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +29 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +33 -1
  5. package/android/src/main/jni.cpp +62 -8
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  8. package/cpp/common.cpp +3237 -3231
  9. package/cpp/common.h +469 -468
  10. package/cpp/ggml-aarch64.c +2193 -2193
  11. package/cpp/ggml-aarch64.h +39 -39
  12. package/cpp/ggml-alloc.c +1036 -1042
  13. package/cpp/ggml-backend-impl.h +153 -153
  14. package/cpp/ggml-backend.c +2240 -2234
  15. package/cpp/ggml-backend.h +238 -238
  16. package/cpp/ggml-common.h +1833 -1829
  17. package/cpp/ggml-impl.h +755 -655
  18. package/cpp/ggml-metal.h +65 -65
  19. package/cpp/ggml-metal.m +3269 -3269
  20. package/cpp/ggml-quants.c +14872 -14860
  21. package/cpp/ggml-quants.h +132 -132
  22. package/cpp/ggml.c +22055 -22044
  23. package/cpp/ggml.h +2453 -2447
  24. package/cpp/llama-grammar.cpp +539 -0
  25. package/cpp/llama-grammar.h +39 -0
  26. package/cpp/llama-impl.h +26 -0
  27. package/cpp/llama-sampling.cpp +635 -0
  28. package/cpp/llama-sampling.h +56 -0
  29. package/cpp/llama-vocab.cpp +1721 -0
  30. package/cpp/llama-vocab.h +130 -0
  31. package/cpp/llama.cpp +19171 -21892
  32. package/cpp/llama.h +1240 -1217
  33. package/cpp/log.h +737 -737
  34. package/cpp/rn-llama.hpp +207 -29
  35. package/cpp/sampling.cpp +460 -460
  36. package/cpp/sgemm.cpp +1027 -1027
  37. package/cpp/sgemm.h +14 -14
  38. package/cpp/unicode.cpp +6 -0
  39. package/cpp/unicode.h +3 -0
  40. package/ios/RNLlama.mm +15 -6
  41. package/ios/RNLlamaContext.h +2 -8
  42. package/ios/RNLlamaContext.mm +41 -34
  43. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  44. package/lib/commonjs/chat.js +37 -0
  45. package/lib/commonjs/chat.js.map +1 -0
  46. package/lib/commonjs/index.js +14 -1
  47. package/lib/commonjs/index.js.map +1 -1
  48. package/lib/module/NativeRNLlama.js.map +1 -1
  49. package/lib/module/chat.js +31 -0
  50. package/lib/module/chat.js.map +1 -0
  51. package/lib/module/index.js +14 -1
  52. package/lib/module/index.js.map +1 -1
  53. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  54. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  55. package/lib/typescript/chat.d.ts +10 -0
  56. package/lib/typescript/chat.d.ts.map +1 -0
  57. package/lib/typescript/index.d.ts +9 -2
  58. package/lib/typescript/index.d.ts.map +1 -1
  59. package/package.json +1 -1
  60. package/src/NativeRNLlama.ts +10 -1
  61. package/src/chat.ts +44 -0
  62. package/src/index.ts +31 -4
@@ -0,0 +1,635 @@
1
+ #include "llama-sampling.h"
2
+
3
+ #include <algorithm>
4
+ #include <cstring>
5
+ #include <ctime>
6
+ #include <cfloat>
7
+ #include <numeric>
8
+ #include <unordered_map>
9
+
10
// In-place log-softmax over `array` of length `size`:
// array[i] <- log(exp(array[i]) / sum_j exp(array[j])), using the usual
// max-subtraction trick for numerical stability.
static void llama_log_softmax(float * array, size_t size) {
    const float max_l = *std::max_element(array, array + size);

    // exponentiate (shifted by the max) and accumulate the normalizer
    float sum = 0.f;
    for (size_t i = 0; i < size; ++i) {
        const float e = expf(array[i] - max_l);
        array[i] = e;
        sum += e;
    }

    // convert back to log-probabilities
    for (size_t i = 0; i < size; ++i) {
        array[i] = logf(array[i] / sum);
    }
}
23
+
24
+ void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
25
+ if (seed == LLAMA_DEFAULT_SEED) {
26
+ seed = time(NULL);
27
+ }
28
+
29
+ smpl->rng.seed(seed);
30
+ }
31
+
32
+ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
33
+ LM_GGML_ASSERT(candidates->size > 0);
34
+
35
+ const int64_t t_start_sample_us = lm_ggml_time_us();
36
+
37
+ // Sort the logits in descending order
38
+ if (!candidates->sorted) {
39
+ std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
40
+ return a.logit > b.logit;
41
+ });
42
+ candidates->sorted = true;
43
+ }
44
+
45
+ float max_l = candidates->data[0].logit;
46
+ float cum_sum = 0.0f;
47
+ for (size_t i = 0; i < candidates->size; ++i) {
48
+ float p = expf(candidates->data[i].logit - max_l);
49
+ candidates->data[i].p = p;
50
+ cum_sum += p;
51
+ }
52
+ for (size_t i = 0; i < candidates->size; ++i) {
53
+ candidates->data[i].p /= cum_sum;
54
+ }
55
+
56
+ if (smpl) {
57
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
58
+ }
59
+ }
60
+
61
+ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
62
+ // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
63
+ // if (k >= (int32_t)candidates->size) {
64
+ // return;
65
+ // }
66
+
67
+ const int64_t t_start_sample_us = lm_ggml_time_us();
68
+
69
+ if (k <= 0) {
70
+ k = candidates->size;
71
+ }
72
+
73
+ k = std::max(k, (int) min_keep);
74
+ k = std::min(k, (int) candidates->size);
75
+
76
+ // Sort scores in descending order
77
+ if (!candidates->sorted) {
78
+ auto comp = [](const llama_token_data & a, const llama_token_data & b) {
79
+ return a.logit > b.logit;
80
+ };
81
+ if (k <= 128) {
82
+ std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
83
+ } else {
84
+ constexpr int nbuckets = 128;
85
+ constexpr float bucket_low = -10.0f;
86
+ constexpr float bucket_high = 10.0f;
87
+ constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
88
+ constexpr float bucker_inter = -bucket_low * bucket_scale;
89
+
90
+ std::vector<int> bucket_idx(candidates->size);
91
+ std::vector<int> histo(nbuckets, 0);
92
+
93
+ for (int i = 0; i < (int)candidates->size; ++i) {
94
+ const float val = candidates->data[i].logit;
95
+ int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
96
+ ib = std::max(0, std::min(nbuckets-1, ib));
97
+ bucket_idx[i] = ib;
98
+ ++histo[ib];
99
+ }
100
+ int nhave = 0;
101
+ int ib = nbuckets - 1;
102
+ for ( ; ib >= 0; --ib) {
103
+ nhave += histo[ib];
104
+ if (nhave >= k) break;
105
+ }
106
+ std::vector<llama_token_data> tmp_tokens(nhave);
107
+ auto ptr = tmp_tokens.data();
108
+ std::vector<llama_token_data*> bucket_ptrs;
109
+ bucket_ptrs.reserve(nbuckets - ib);
110
+ for (int j = nbuckets - 1; j >= ib; --j) {
111
+ bucket_ptrs.push_back(ptr);
112
+ ptr += histo[j];
113
+ }
114
+ for (int i = 0; i < (int)candidates->size; ++i) {
115
+ int j = bucket_idx[i];
116
+ if (j >= ib) {
117
+ *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
118
+ }
119
+ }
120
+
121
+ ptr = tmp_tokens.data();
122
+ int ndone = 0;
123
+ for (int j = nbuckets-1; j > ib; --j) {
124
+ std::sort(ptr, ptr + histo[j], comp);
125
+ ptr += histo[j];
126
+ ndone += histo[j];
127
+ }
128
+ std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
129
+
130
+ std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
131
+
132
+ }
133
+ candidates->sorted = true;
134
+ }
135
+ candidates->size = k;
136
+
137
+ if (smpl) {
138
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
139
+ }
140
+ }
141
+
142
+ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
143
+ if (p >= 1.0f) {
144
+ return;
145
+ }
146
+
147
+ llama_sample_softmax_impl(smpl, candidates);
148
+
149
+ const int64_t t_start_sample_us = lm_ggml_time_us();
150
+
151
+ // Compute the cumulative probabilities
152
+ float cum_sum = 0.0f;
153
+ size_t last_idx = candidates->size;
154
+
155
+ for (size_t i = 0; i < candidates->size; ++i) {
156
+ cum_sum += candidates->data[i].p;
157
+
158
+ // Check if the running sum is at least p or if we have kept at least min_keep tokens
159
+ // we set the last index to i+1 to indicate that the current iterate should be included in the set
160
+ if (cum_sum >= p && i + 1 >= min_keep) {
161
+ last_idx = i + 1;
162
+ break;
163
+ }
164
+ }
165
+
166
+ // Resize the output vector to keep only the top-p tokens
167
+ candidates->size = last_idx;
168
+
169
+ if (smpl) {
170
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
171
+ }
172
+ }
173
+
174
+ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
175
+ if (p <= 0.0f || !candidates->size) {
176
+ return;
177
+ }
178
+
179
+ const int64_t t_start_sample_us = lm_ggml_time_us();
180
+
181
+ bool min_p_applied = false;
182
+
183
+ // if the candidates aren't sorted, try the unsorted implementation first
184
+ if (!candidates->sorted) {
185
+ std::vector<llama_token_data> filtered_tokens;
186
+
187
+ float max_logit = -FLT_MAX;
188
+ for (size_t i = 0; i < candidates->size; ++i) {
189
+ max_logit = std::max(max_logit, candidates->data[i].logit);
190
+ }
191
+ const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
192
+
193
+ for (size_t i = 0; i < candidates->size; ++i) {
194
+ if (candidates->data[i].logit >= min_logit) {
195
+ filtered_tokens.push_back(candidates->data[i]);
196
+ }
197
+ }
198
+
199
+ // if we have enough values the operation was a success
200
+ if (filtered_tokens.size() >= min_keep) {
201
+ memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
202
+ candidates->size = filtered_tokens.size();
203
+ min_p_applied = true;
204
+ }
205
+ }
206
+
207
+ // if the candidates are sorted or the unsorted implementation failed, use this implementation
208
+ if (!min_p_applied) {
209
+ // Sort the logits in descending order
210
+ if (!candidates->sorted) {
211
+ std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
212
+ return a.logit > b.logit;
213
+ });
214
+ candidates->sorted = true;
215
+ }
216
+
217
+ const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
218
+ size_t i = 1; // first token always matches
219
+
220
+ for (; i < candidates->size; ++i) {
221
+ if (candidates->data[i].logit < min_logit && i >= min_keep) {
222
+ break; // prob too small
223
+ }
224
+ }
225
+
226
+ // Resize the output vector to keep only the matching tokens
227
+ candidates->size = i;
228
+ }
229
+
230
+ if (smpl) {
231
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
232
+ }
233
+ }
234
+
235
+ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
236
+ if (z >= 1.0f || candidates->size <= 2) {
237
+ return;
238
+ }
239
+
240
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
241
+ const int64_t t_start_sample_us = lm_ggml_time_us();
242
+
243
+ // Compute the first and second derivatives
244
+ std::vector<float> first_derivatives(candidates->size - 1);
245
+ std::vector<float> second_derivatives(candidates->size - 2);
246
+
247
+ for (size_t i = 0; i < first_derivatives.size(); ++i) {
248
+ first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
249
+ }
250
+ for (size_t i = 0; i < second_derivatives.size(); ++i) {
251
+ second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
252
+ }
253
+
254
+ // Calculate absolute value of second derivatives
255
+ for (size_t i = 0; i < second_derivatives.size(); ++i) {
256
+ second_derivatives[i] = std::abs(second_derivatives[i]);
257
+ }
258
+
259
+ // Normalize the second derivatives
260
+ {
261
+ const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
262
+
263
+ if (second_derivatives_sum > 1e-6f) {
264
+ for (float & value : second_derivatives) {
265
+ value /= second_derivatives_sum;
266
+ }
267
+ } else {
268
+ for (float & value : second_derivatives) {
269
+ value = 1.0f / second_derivatives.size();
270
+ }
271
+ }
272
+ }
273
+
274
+ float cum_sum = 0.0f;
275
+ size_t last_idx = candidates->size;
276
+ for (size_t i = 0; i < second_derivatives.size(); ++i) {
277
+ cum_sum += second_derivatives[i];
278
+
279
+ // Check if the running sum is greater than z or if we have kept at least min_keep tokens
280
+ if (cum_sum > z && i >= min_keep) {
281
+ last_idx = i;
282
+ break;
283
+ }
284
+ }
285
+
286
+ // Resize the output vector to keep only the tokens above the tail location
287
+ candidates->size = last_idx;
288
+
289
+ if (smpl) {
290
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
291
+ }
292
+ }
293
+
294
+ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
295
+ // Reference implementation:
296
+ // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
297
+ if (p >= 1.0f) {
298
+ return;
299
+ }
300
+
301
+ // Compute the softmax of logits and calculate entropy
302
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
303
+
304
+ const int64_t t_start_sample_us = lm_ggml_time_us();
305
+
306
+ float entropy = 0.0f;
307
+ for (size_t i = 0; i < candidates->size; ++i) {
308
+ entropy += -candidates->data[i].p * logf(candidates->data[i].p);
309
+ }
310
+
311
+ // Compute the absolute difference between negative log probability and entropy for each candidate
312
+ std::vector<float> shifted_scores;
313
+ for (size_t i = 0; i < candidates->size; ++i) {
314
+ float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
315
+ shifted_scores.push_back(shifted_score);
316
+ }
317
+
318
+ // Sort tokens based on the shifted_scores and their corresponding indices
319
+ std::vector<size_t> indices(candidates->size);
320
+ std::iota(indices.begin(), indices.end(), 0);
321
+
322
+ std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
323
+ return shifted_scores[a] < shifted_scores[b];
324
+ });
325
+
326
+ // Compute the cumulative probabilities
327
+ float cum_sum = 0.0f;
328
+ size_t last_idx = indices.size();
329
+
330
+ for (size_t i = 0; i < indices.size(); ++i) {
331
+ size_t idx = indices[i];
332
+ cum_sum += candidates->data[idx].p;
333
+
334
+ // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
335
+ if (cum_sum > p && i >= min_keep - 1) {
336
+ last_idx = i + 1;
337
+ break;
338
+ }
339
+ }
340
+
341
+ // Resize the output vector to keep only the locally typical tokens
342
+ std::vector<llama_token_data> new_candidates;
343
+ for (size_t i = 0; i < last_idx; ++i) {
344
+ size_t idx = indices[i];
345
+ new_candidates.push_back(candidates->data[idx]);
346
+ }
347
+
348
+ // Replace the data in candidates with the new_candidates data
349
+ std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
350
+ candidates->size = new_candidates.size();
351
+ candidates->sorted = false;
352
+
353
+ if (smpl) {
354
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
355
+ }
356
+ }
357
+
358
+ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
359
+ const int64_t t_start_sample_us = lm_ggml_time_us();
360
+
361
+ // no need to do anything if there is only one (or zero) candidates
362
+ if(candidates->size <= 1) {
363
+ return;
364
+ }
365
+
366
+ // Calculate maximum possible entropy
367
+ float max_entropy = -logf(1.0f / candidates->size);
368
+
369
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
370
+
371
+ // Calculate entropy of the softmax probabilities
372
+ float entropy = 0.0f;
373
+ for (size_t i = 0; i < candidates->size; ++i) {
374
+ float prob = candidates->data[i].p;
375
+ if (prob > 0.0f) { // Ensure no log(0)
376
+ entropy -= prob * logf(prob);
377
+ }
378
+ }
379
+
380
+ // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
381
+ float normalized_entropy = entropy / max_entropy;
382
+
383
+ // Map the normalized entropy to the desired temperature range using the power function
384
+ float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
385
+
386
+ #ifdef DEBUG
387
+ LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
388
+ LLAMA_LOG_INFO("Entropy: %f\n", entropy);
389
+ LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
390
+ LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
391
+ LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
392
+ LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
393
+ #endif
394
+
395
+ // Apply the dynamically calculated temperature scaling
396
+ for (size_t i = 0; i < candidates->size; ++i) {
397
+ candidates->data[i].logit /= dyn_temp;
398
+ }
399
+
400
+ // Re-compute softmax probabilities after scaling logits with dynamic temperature
401
+ double max_l_double = candidates->data[0].logit;
402
+ double cum_sum_double = 0.0;
403
+ for (size_t i = 0; i < candidates->size; ++i) {
404
+ double p = exp(candidates->data[i].logit - max_l_double);
405
+ candidates->data[i].p = p; // Store the scaled probability
406
+ cum_sum_double += p;
407
+ }
408
+ for (size_t i = 0; i < candidates->size; ++i) {
409
+ candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
410
+ }
411
+
412
+ #ifdef DEBUG
413
+ // Print the updated top 25 probabilities after temperature scaling
414
+ LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
415
+ for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
416
+ LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
417
+ }
418
+ #endif
419
+
420
+ if (smpl) {
421
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
422
+ }
423
+ }
424
+
425
+ void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
426
+ const int64_t t_start_sample_us = lm_ggml_time_us();
427
+
428
+ for (size_t i = 0; i < candidates->size; ++i) {
429
+ candidates->data[i].logit /= temp;
430
+ }
431
+
432
+ if (smpl) {
433
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
434
+ }
435
+ }
436
+
437
+ void llama_sample_repetition_penalties_impl(
438
+ struct llama_sampling * smpl,
439
+ llama_token_data_array * candidates,
440
+ const llama_token * last_tokens,
441
+ size_t penalty_last_n,
442
+ float penalty_repeat,
443
+ float penalty_freq,
444
+ float penalty_present) {
445
+ if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
446
+ return;
447
+ }
448
+
449
+ const int64_t t_start_sample_us = lm_ggml_time_us();
450
+
451
+ // Create a frequency map to count occurrences of each token in last_tokens
452
+ std::unordered_map<llama_token, int> token_count;
453
+ for (size_t i = 0; i < penalty_last_n; ++i) {
454
+ token_count[last_tokens[i]]++;
455
+ }
456
+
457
+ // Apply frequency and presence penalties to the candidates
458
+ for (size_t i = 0; i < candidates->size; ++i) {
459
+ const auto token_iter = token_count.find(candidates->data[i].id);
460
+ if (token_iter == token_count.end()) {
461
+ continue;
462
+ }
463
+
464
+ const int count = token_iter->second;
465
+
466
+ // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
467
+ // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
468
+ if (candidates->data[i].logit <= 0) {
469
+ candidates->data[i].logit *= penalty_repeat;
470
+ } else {
471
+ candidates->data[i].logit /= penalty_repeat;
472
+ }
473
+
474
+ candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
475
+ }
476
+
477
+ candidates->sorted = false;
478
+
479
+ if (smpl) {
480
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
481
+ }
482
+ }
483
+
484
+ void llama_sample_apply_guidance_impl(
485
+ struct llama_sampling * smpl,
486
+ float * logits,
487
+ float * logits_guidance,
488
+ float scale) {
489
+ LM_GGML_ASSERT(smpl);
490
+
491
+ const auto t_start_sample_us = lm_ggml_time_us();
492
+ const auto n_vocab = smpl->n_vocab;
493
+
494
+ llama_log_softmax(logits, n_vocab);
495
+ llama_log_softmax(logits_guidance, n_vocab);
496
+
497
+ for (int i = 0; i < n_vocab; ++i) {
498
+ auto & l = logits[i];
499
+ const auto & g = logits_guidance[i];
500
+
501
+ l = scale * (l - g) + g;
502
+ }
503
+
504
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
505
+ }
506
+
507
+ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
508
+ LM_GGML_ASSERT(smpl);
509
+
510
+ const int32_t n_vocab = float(smpl->n_vocab);
511
+
512
+ int64_t t_start_sample_us = lm_ggml_time_us();
513
+
514
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
515
+
516
+ // Estimate s_hat using the most probable m tokens
517
+ float s_hat = 0.0;
518
+ float sum_ti_bi = 0.0;
519
+ float sum_ti_sq = 0.0;
520
+ for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
521
+ float t_i = logf(float(i + 2) / float(i + 1));
522
+ float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
523
+ sum_ti_bi += t_i * b_i;
524
+ sum_ti_sq += t_i * t_i;
525
+ }
526
+ s_hat = sum_ti_bi / sum_ti_sq;
527
+
528
+ // Compute k from the estimated s_hat and target surprise value
529
+ float epsilon_hat = s_hat - 1;
530
+ float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
531
+
532
+ // Sample the next word X using top-k sampling
533
+ llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
534
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
535
+ llama_token X = llama_sample_token_impl(smpl, candidates);
536
+ t_start_sample_us = lm_ggml_time_us();
537
+
538
+ // Compute error as the difference between observed surprise and target surprise value
539
+ size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
540
+ return candidate.id == X;
541
+ }));
542
+ float observed_surprise = -log2f(candidates->data[X_idx].p);
543
+ float e = observed_surprise - tau;
544
+
545
+ // Update mu using the learning rate and error
546
+ *mu = *mu - eta * e;
547
+
548
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
549
+ return X;
550
+ }
551
+
552
+ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
553
+ int64_t t_start_sample_us;
554
+ t_start_sample_us = lm_ggml_time_us();
555
+
556
+ llama_sample_softmax_impl(smpl, candidates);
557
+
558
+ // Truncate the words with surprise values greater than mu
559
+ candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
560
+ return -log2f(candidate.p) > *mu;
561
+ }));
562
+
563
+ if (candidates->size == 0) {
564
+ candidates->size = 1;
565
+ }
566
+
567
+ if (smpl) {
568
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
569
+ }
570
+
571
+ // Normalize the probabilities of the remaining words
572
+ llama_sample_softmax_impl(smpl, candidates);
573
+
574
+ // Sample the next word X from the remaining words
575
+ llama_token X = llama_sample_token_impl(smpl, candidates);
576
+ t_start_sample_us = lm_ggml_time_us();
577
+
578
+ // Compute error as the difference between observed surprise and target surprise value
579
+ size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
580
+ return candidate.id == X;
581
+ }));
582
+ float observed_surprise = -log2f(candidates->data[X_idx].p);
583
+ float e = observed_surprise - tau;
584
+
585
+ // Update mu using the learning rate and error
586
+ *mu = *mu - eta * e;
587
+
588
+ if (smpl) {
589
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
590
+ }
591
+ return X;
592
+ }
593
+
594
+ llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
595
+ const int64_t t_start_sample_us = lm_ggml_time_us();
596
+
597
+ // Find max element
598
+ auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
599
+ return a.logit < b.logit;
600
+ });
601
+
602
+ llama_token result = max_iter->id;
603
+ if (smpl) {
604
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
605
+ smpl->n_sample++;
606
+ }
607
+ return result;
608
+ }
609
+
610
+ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
611
+ LM_GGML_ASSERT(smpl);
612
+
613
+ const int64_t t_start_sample_us = lm_ggml_time_us();
614
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
615
+
616
+ std::vector<float> probs;
617
+ probs.reserve(candidates->size);
618
+ for (size_t i = 0; i < candidates->size; ++i) {
619
+ probs.push_back(candidates->data[i].p);
620
+ }
621
+
622
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
623
+ int idx = dist(rng);
624
+
625
+ llama_token result = candidates->data[idx].id;
626
+
627
+ smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
628
+ smpl->n_sample++;
629
+
630
+ return result;
631
+ }
632
+
633
+ llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
634
+ return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
635
+ }
@@ -0,0 +1,56 @@
1
+ #pragma once
2
+
3
+ #include "llama-impl.h"
4
+
5
+ struct llama_sampling {
6
+ llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
7
+
8
+ std::mt19937 rng;
9
+
10
+ int32_t n_vocab = 0;
11
+
12
+ mutable int64_t t_sample_us = 0;
13
+ mutable int32_t n_sample = 0;
14
+
15
+ void reset_timings() const {
16
+ t_sample_us = 0;
17
+ n_sample = 0;
18
+ }
19
+ };
20
+
21
+ //
22
+ // internal API
23
+ //
24
+
25
+ void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
26
+
27
+ void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
28
+ void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
29
+ void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
30
+ void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
31
+ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
32
+ void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
33
+ void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
34
+ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
35
+
36
+ void llama_sample_repetition_penalties_impl(
37
+ struct llama_sampling * smpl,
38
+ llama_token_data_array * candidates,
39
+ const llama_token * last_tokens,
40
+ size_t penalty_last_n,
41
+ float penalty_repeat,
42
+ float penalty_freq,
43
+ float penalty_present);
44
+
45
+ void llama_sample_apply_guidance_impl(
46
+ struct llama_sampling * smpl,
47
+ float * logits,
48
+ float * logits_guidance,
49
+ float scale);
50
+
51
+ llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
52
+ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
53
+ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
54
+ llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
55
+ llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
56
+