@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
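
All of the hunks shown below come from `package/src/llama.cpp/src/llama-sampling.cpp` (entry 160 above), which carries the headline change of this update: the old `llama_sampling` state object and its `llama_sample_*_impl` free functions are replaced by the composable `llama_sampler` interface, where each sampler implements a small vtable (`llama_sampler_i`) and samplers are combined into chains. As a rough orientation before reading the diff, here is a minimal sketch of driving the new API; it assumes an already-created `llama_context` and that `llama_sampler_chain_default_params()` is exposed by this build's `llama.h` (the diff itself only shows the params struct being consumed):

```cpp
#include "llama.h"

// Sketch only: build a top-k -> top-p -> temperature -> dist chain and
// draw one token from the most recent logits of `ctx`.
llama_token sample_next_token(struct llama_context * ctx) {
    // assumption: llama_sampler_chain_default_params() exists in this llama.h
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // stages run in insertion order; each one filters or reweights the candidate array
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.95f, /* min_keep */ 1));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    // the last stage must set cur_p->selected; dist samples from the softmaxed
    // probabilities via the std::discrete_distribution helper shown in the diff
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // builds the candidate array from llama_get_logits_ith(ctx, idx),
    // applies the chain, accepts the chosen token, and returns it
    const llama_token tok = llama_sampler_sample(chain, ctx, /* idx */ -1);

    llama_sampler_free(chain); // frees the chain and every sampler added to it
    return tok;
}
```

Note that the order of stages matters: truncation samplers only make sense before the final selecting stage (`dist`, `greedy`, or one of the mirostat samplers), since `llama_sampler_sample` asserts that `cur_p.selected` was set by the chain.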
@@ -1,12 +1,53 @@
  #include "llama-sampling.h"
 
+ #include "llama-vocab.h"
+ #include "llama-grammar.h"
+
  #include <algorithm>
+ #include <cassert>
+ #include <cfloat>
+ #include <chrono>
+ #include <cmath>
+ #include <cstdlib>
  #include <cstring>
  #include <ctime>
- #include <cfloat>
  #include <numeric>
+ #include <random>
  #include <unordered_map>
 
+ static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
+ // iterator for the probabilities
+ #ifdef __GNUC__
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+ #endif
+
+ struct probs_iterator {
+ typedef std::input_iterator_tag iterator_category;
+ typedef float value_type;
+ typedef float * pointer;
+ typedef float & reference;
+ typedef ptrdiff_t difference_type;
+
+ const llama_token_data * data;
+
+ bool operator==(const probs_iterator & other) const { return data == other.data; }
+ bool operator!=(const probs_iterator & other) const { return data != other.data; }
+ const float & operator*() const { return data->p; }
+ probs_iterator & operator++() { ++data; return *this; }
+ probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
+ };
+
+ #ifdef __GNUC__
+ #pragma GCC diagnostic pop
+ #endif
+
+ std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
+
+ return dist(rng);
+ }
+
+ /*
  static void llama_log_softmax(float * array, size_t size) {
  float max_l = *std::max_element(array, array + size);
  float sum = 0.f;
@@ -20,79 +61,65 @@ static void llama_log_softmax(float * array, size_t size) {
  array[i] = logf(array[i] / sum);
  }
  }
+ */
 
- void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
- if (seed == LLAMA_DEFAULT_SEED) {
- seed = time(NULL);
- }
-
- smpl->rng.seed(seed);
- }
-
- void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
- GGML_ASSERT(candidates->size > 0);
-
- const int64_t t_start_sample_us = ggml_time_us();
+ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+ GGML_ASSERT(cur_p->size > 0);
 
  // Sort the logits in descending order
- if (!candidates->sorted) {
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+ if (!cur_p->sorted) {
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
  return a.logit > b.logit;
  });
- candidates->sorted = true;
+ cur_p->sorted = true;
  }
 
- float max_l = candidates->data[0].logit;
+ float max_l = cur_p->data[0].logit;
  float cum_sum = 0.0f;
- for (size_t i = 0; i < candidates->size; ++i) {
- float p = expf(candidates->data[i].logit - max_l);
- candidates->data[i].p = p;
+
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ float p = expf(cur_p->data[i].logit - max_l);
+ cur_p->data[i].p = p;
  cum_sum += p;
  }
- for (size_t i = 0; i < candidates->size; ++i) {
- candidates->data[i].p /= cum_sum;
- }
 
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].p /= cum_sum;
  }
  }
 
- void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
  // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
- // if (k >= (int32_t)candidates->size) {
+ // if (k >= (int32_t)cur_p->size) {
  // return;
  // }
 
- const int64_t t_start_sample_us = ggml_time_us();
-
  if (k <= 0) {
- k = candidates->size;
+ k = cur_p->size;
  }
 
- k = std::max(k, (int) min_keep);
- k = std::min(k, (int) candidates->size);
+ k = std::min(k, (int) cur_p->size);
 
  // Sort scores in descending order
- if (!candidates->sorted) {
+ if (!cur_p->sorted) {
  auto comp = [](const llama_token_data & a, const llama_token_data & b) {
  return a.logit > b.logit;
  };
  if (k <= 128) {
- std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+ std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
  } else {
  constexpr int nbuckets = 128;
  constexpr float bucket_low = -10.0f;
  constexpr float bucket_high = 10.0f;
  constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
- constexpr float bucker_inter = -bucket_low * bucket_scale;
+ constexpr float bucket_inter = -bucket_low * bucket_scale;
 
- std::vector<int> bucket_idx(candidates->size);
+ std::vector<int> bucket_idx(cur_p->size);
  std::vector<int> histo(nbuckets, 0);
 
- for (int i = 0; i < (int)candidates->size; ++i) {
- const float val = candidates->data[i].logit;
- int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+ for (int i = 0; i < (int)cur_p->size; ++i) {
+ const float val = cur_p->data[i].logit;
+ int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
  ib = std::max(0, std::min(nbuckets-1, ib));
  bucket_idx[i] = ib;
  ++histo[ib];
@@ -101,20 +128,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
  int ib = nbuckets - 1;
  for ( ; ib >= 0; --ib) {
  nhave += histo[ib];
- if (nhave >= k) break;
+ if (nhave >= k) {
+ break;
+ }
  }
  std::vector<llama_token_data> tmp_tokens(nhave);
- auto ptr = tmp_tokens.data();
+ auto * ptr = tmp_tokens.data();
  std::vector<llama_token_data*> bucket_ptrs;
  bucket_ptrs.reserve(nbuckets - ib);
  for (int j = nbuckets - 1; j >= ib; --j) {
  bucket_ptrs.push_back(ptr);
  ptr += histo[j];
  }
- for (int i = 0; i < (int)candidates->size; ++i) {
+ for (int i = 0; i < (int)cur_p->size; ++i) {
  int j = bucket_idx[i];
  if (j >= ib) {
- *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+ *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
  }
  }
 
@@ -127,125 +156,582 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
  }
  std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
 
- std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+ std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
 
  }
- candidates->sorted = true;
+ cur_p->sorted = true;
  }
- candidates->size = k;
+ cur_p->size = k;
+ }
 
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ static uint32_t get_rng_seed(uint32_t seed) {
+ if (seed == LLAMA_DEFAULT_SEED) {
+ // use system clock if std::random_device is not a true RNG
+ static bool is_rd_prng = std::random_device().entropy() == 0;
+ if (is_rd_prng) {
+ return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
+ }
+ std::random_device rd;
+ return rd();
  }
+ return seed;
  }
 
- void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
- if (p >= 1.0f) {
+ // llama_sampler API
+
+ const char * llama_sampler_name(const struct llama_sampler * smpl) {
+ if (!smpl->iface) {
+ return "(null)";
+ }
+
+ return smpl->iface->name(smpl);
+ }
+
+ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+ if (smpl->iface->accept) {
+ smpl->iface->accept(smpl, token);
+ }
+ }
+
+ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+ GGML_ASSERT(smpl->iface->apply);
+ smpl->iface->apply(smpl, cur_p);
+ }
+
+ void llama_sampler_reset(struct llama_sampler * smpl) {
+ if (smpl->iface->reset) {
+ smpl->iface->reset(smpl);
+ }
+ }
+
+ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+ if (smpl->iface->clone) {
+ return smpl->iface->clone(smpl);
+ }
+
+ if (smpl->ctx == nullptr) {
+ return new llama_sampler {
+ /* .iface = */ smpl->iface,
+ /* .ctx = */ nullptr,
+ };
+ }
+
+ GGML_ABORT("the sampler does not support cloning");
+ }
+
+ void llama_sampler_free(struct llama_sampler * smpl) {
+ if (smpl == nullptr) {
  return;
  }
 
- llama_sample_softmax_impl(smpl, candidates);
+ if (smpl->iface->free) {
+ smpl->iface->free(smpl);
+ }
+
+ delete smpl;
+ }
+
+ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+ const auto * logits = llama_get_logits_ith(ctx, idx);
+
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+ // TODO: do not allocate each time
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+
+ llama_token_data_array cur_p = {
+ /* .data = */ cur.data(),
+ /* .size = */ cur.size(),
+ /* .selected = */ -1,
+ /* .sorted = */ false,
+ };
+
+ llama_sampler_apply(smpl, &cur_p);
+
+ GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+ auto token = cur_p.data[cur_p.selected].id;
+
+ llama_sampler_accept(smpl, token);
+
+ return token;
+ }
+
+ // sampler chain
+
+ static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
+ return "chain";
+ }
+
+ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_accept(smpl, token);
+ }
+
+ chain->n_sample++;
+ }
+
+ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_apply(smpl, cur_p);
+ }
+ }
+
+ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_reset(smpl);
+ }
+
+ chain->t_sample_us = 0;
+ chain->n_sample = 0;
+ }
+
+ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
+ const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
+
+ auto * result = llama_sampler_chain_init(chain_src->params);
+
+ for (auto * smpl : chain_src->samplers) {
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+ }
+
+ return result;
+ }
+
+ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_free(smpl);
+ }
+
+ delete chain;
+ }
+
+ static struct llama_sampler_i llama_sampler_chain_i = {
+ /* .name = */ llama_sampler_chain_name,
+ /* .accept = */ llama_sampler_chain_accept,
+ /* .apply = */ llama_sampler_chain_apply,
+ /* .reset = */ llama_sampler_chain_reset,
+ /* .clone = */ llama_sampler_chain_clone,
+ /* .free = */ llama_sampler_chain_free,
+ };
+
+ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_chain_i,
+ /* .ctx = */ new llama_sampler_chain {
+ /* .params = */ params,
+ /* .samplers = */ {},
+ /* .t_sample_us = */ 0,
+ /* .n_sample = */ 0,
+ },
+ };
+ }
+
+ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
+ auto * p = (llama_sampler_chain *) chain->ctx;
+ p->samplers.push_back(smpl);
+ }
+
+ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
+ return nullptr;
+ }
+
+ return p->samplers[i];
+ }
+
+ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
+ auto * p = (llama_sampler_chain *) chain->ctx;
+
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
+ return nullptr;
+ }
+
+ auto * result = p->samplers[i];
+ p->samplers.erase(p->samplers.begin() + i);
+
+ return result;
+ }
+
+ int llama_sampler_chain_n(const struct llama_sampler * chain) {
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+ return p->samplers.size();
+ }
+
+ //
+ // samplers
+ //
+
+ // greedy
+
+ static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
+ return "greedy";
+ }
+
+ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+ cur_p->selected = 0;
+ for (size_t i = 1; i < cur_p->size; ++i) {
+ if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
+ cur_p->selected = i;
+ }
+ }
+ }
+
+ static struct llama_sampler_i llama_sampler_greedy_i = {
+ /* .name = */ llama_sampler_greedy_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_greedy_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ nullptr,
+ /* .free = */ nullptr,
+ };
+
+ struct llama_sampler * llama_sampler_init_greedy() {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_greedy_i,
+ /* .ctx = */ nullptr,
+ };
+ }
+
+ // dist
+
+ struct llama_sampler_dist {
+ const uint32_t seed;
+ uint32_t seed_cur;
+
+ std::mt19937 rng;
+ };
+
+ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
+ return "dist";
+ }
+
+ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+ }
+
+ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
+ auto * result = llama_sampler_init_dist(ctx->seed);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_dist *) result->ctx;
+
+ result_ctx->rng = ctx->rng;
+ }
+
+ return result;
+ }
+
+ static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
+ ctx->seed_cur = get_rng_seed(ctx->seed);
+ ctx->rng.seed(ctx->seed_cur);
+ }
+
+ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_dist *) smpl->ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_dist_i = {
+ /* .name = */ llama_sampler_dist_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_dist_apply,
+ /* .reset = */ llama_sampler_dist_reset,
+ /* .clone = */ llama_sampler_dist_clone,
+ /* .free = */ llama_sampler_dist_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+ auto seed_cur = get_rng_seed(seed);
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_dist_i,
+ /* .ctx = */ new llama_sampler_dist {
+ /* .seed = */ seed,
+ /* .seed_cur = */ seed_cur,
+ /* .rng = */ std::mt19937(seed_cur),
+ },
+ };
+ }
+
+ // softmax
+
+ static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
+ return "softmax";
+ }
+
+ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+ llama_sampler_softmax_impl(cur_p);
+ }
+
+ static struct llama_sampler_i llama_sampler_softmax_i = {
+ /* .name = */ llama_sampler_softmax_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_softmax_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ nullptr,
+ /* .free = */ nullptr,
+ };
+
+ struct llama_sampler * llama_sampler_init_softmax() {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_softmax_i,
+ /* .ctx = */ nullptr,
+ };
+ }
+
+ // top-k
+
+ struct llama_sampler_top_k {
+ const int32_t k;
+ };
+
+ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
+ return "top-k";
+ }
+
+ static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+ llama_sampler_top_k_impl(cur_p, ctx->k);
+ }
+
+ static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
+ return llama_sampler_init_top_k(ctx->k);
+ }
+
+ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_top_k *) smpl->ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_top_k_i = {
+ /* .name = */ llama_sampler_top_k_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_top_k_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_top_k_clone,
+ /* .free = */ llama_sampler_top_k_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_top_k_i,
+ /* .ctx = */ new llama_sampler_top_k {
+ /* .k = */ k,
+ },
+ };
+ }
+
+ // top-p
+
+ struct llama_sampler_top_p {
+ const float p;
+ const size_t min_keep;
+ };
+
+ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
+ return "top-p";
+ }
 
- const int64_t t_start_sample_us = ggml_time_us();
+ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+
+ if (ctx->p >= 1.0f) {
+ return;
+ }
+
+ llama_sampler_softmax_impl(cur_p);
 
  // Compute the cumulative probabilities
  float cum_sum = 0.0f;
- size_t last_idx = candidates->size;
+ size_t last_idx = cur_p->size;
 
- for (size_t i = 0; i < candidates->size; ++i) {
- cum_sum += candidates->data[i].p;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cum_sum += cur_p->data[i].p;
 
  // Check if the running sum is at least p or if we have kept at least min_keep tokens
  // we set the last index to i+1 to indicate that the current iterate should be included in the set
- if (cum_sum >= p && i + 1 >= min_keep) {
+ if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
  last_idx = i + 1;
  break;
  }
  }
 
  // Resize the output vector to keep only the top-p tokens
- candidates->size = last_idx;
+ cur_p->size = last_idx;
+ }
 
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
+ static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
+ return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
+ }
+
+ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_top_p *) smpl->ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_top_p_i = {
+ /* .name = */ llama_sampler_top_p_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_top_p_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_top_p_clone,
+ /* .free = */ llama_sampler_top_p_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_top_p_i,
+ /* .ctx = */ new llama_sampler_top_p {
+ /* .p = */ p,
+ /* .min_keep = */ min_keep,
+ },
+ };
+ }
+
+ // min-p
+
+ struct llama_sampler_min_p {
+ const float p;
+ const size_t min_keep;
+ };
+
+ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
+ return "min-p";
  }
 
- void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
- if (p <= 0.0f || !candidates->size) {
+ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
+
+ if (ctx->p <= 0.0f || !cur_p->size) {
  return;
  }
 
- const int64_t t_start_sample_us = ggml_time_us();
-
  bool min_p_applied = false;
 
- // if the candidates aren't sorted, try the unsorted implementation first
- if (!candidates->sorted) {
+ // if the cur_p aren't sorted, try the unsorted implementation first
+ if (!cur_p->sorted) {
  std::vector<llama_token_data> filtered_tokens;
 
  float max_logit = -FLT_MAX;
- for (size_t i = 0; i < candidates->size; ++i) {
- max_logit = std::max(max_logit, candidates->data[i].logit);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ max_logit = std::max(max_logit, cur_p->data[i].logit);
  }
- const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
+ const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
 
- for (size_t i = 0; i < candidates->size; ++i) {
- if (candidates->data[i].logit >= min_logit) {
- filtered_tokens.push_back(candidates->data[i]);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (cur_p->data[i].logit >= min_logit) {
+ filtered_tokens.push_back(cur_p->data[i]);
  }
  }
 
  // if we have enough values the operation was a success
- if (filtered_tokens.size() >= min_keep) {
- memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
- candidates->size = filtered_tokens.size();
+ if (filtered_tokens.size() >= ctx->min_keep) {
+ memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+ cur_p->size = filtered_tokens.size();
  min_p_applied = true;
  }
  }
 
- // if the candidates are sorted or the unsorted implementation failed, use this implementation
+ // if the cur_p are sorted or the unsorted implementation failed, use this implementation
  if (!min_p_applied) {
  // Sort the logits in descending order
- if (!candidates->sorted) {
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+ if (!cur_p->sorted) {
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
  return a.logit > b.logit;
  });
- candidates->sorted = true;
+ cur_p->sorted = true;
  }
 
- const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
+ const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
  size_t i = 1; // first token always matches
 
- for (; i < candidates->size; ++i) {
- if (candidates->data[i].logit < min_logit && i >= min_keep) {
+ for (; i < cur_p->size; ++i) {
+ if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
  break; // prob too small
  }
  }
 
  // Resize the output vector to keep only the matching tokens
- candidates->size = i;
+ cur_p->size = i;
  }
+ }
 
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
+ static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
+ return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
  }
 
- void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
- if (z >= 1.0f || candidates->size <= 2) {
+ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_min_p *) smpl->ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_min_p_i = {
+ /* .name = */ llama_sampler_min_p_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_min_p_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_min_p_clone,
+ /* .free = */ llama_sampler_min_p_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_min_p_i,
+ /* .ctx = */ new llama_sampler_min_p {
+ /* .p = */ p,
+ /* .min_keep = */ min_keep,
+ },
+ };
+ }
+
+ // tail-free
+
+ struct llama_sampler_tail_free {
+ const float z;
+ const size_t min_keep;
+ };
+
+ static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
+ return "tail-free";
+ }
+
+ static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
+
+ if (ctx->z >= 1.0f || cur_p->size <= 2) {
  return;
  }
 
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
- const int64_t t_start_sample_us = ggml_time_us();
+ llama_sampler_softmax_impl(cur_p);
 
  // Compute the first and second derivatives
- std::vector<float> first_derivatives(candidates->size - 1);
- std::vector<float> second_derivatives(candidates->size - 2);
+ std::vector<float> first_derivatives(cur_p->size - 1);
+ std::vector<float> second_derivatives(cur_p->size - 2);
 
  for (size_t i = 0; i < first_derivatives.size(); ++i) {
- first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
+ first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
  }
  for (size_t i = 0; i < second_derivatives.size(); ++i) {
  second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
@@ -272,51 +758,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
  }
 
  float cum_sum = 0.0f;
- size_t last_idx = candidates->size;
+ size_t last_idx = cur_p->size;
  for (size_t i = 0; i < second_derivatives.size(); ++i) {
  cum_sum += second_derivatives[i];
 
  // Check if the running sum is greater than z or if we have kept at least min_keep tokens
- if (cum_sum > z && i >= min_keep) {
+ if (cum_sum > ctx->z && i >= ctx->min_keep) {
  last_idx = i;
  break;
  }
  }
 
  // Resize the output vector to keep only the tokens above the tail location
- candidates->size = last_idx;
+ cur_p->size = last_idx;
+ }
 
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
+ static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
+ return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
+ }
+
+ static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_tail_free *) smpl->ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_tail_free_i = {
+ /* .name = */ llama_sampler_tail_free_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_tail_free_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_tail_free_clone,
+ /* .free = */ llama_sampler_tail_free_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_tail_free_i,
+ /* .ctx = */ new llama_sampler_tail_free {
+ /* .z = */ z,
+ /*. min_keep = */ min_keep,
+ },
+ };
+ }
+
+ // typical
+
+ struct llama_sampler_typical {
+ const float p;
+ const size_t min_keep;
+ };
+
+ static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
+ return "typical";
  }
 
- void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+
  // Reference implementation:
  // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
- if (p >= 1.0f) {
+ if (ctx->p >= 1.0f) {
  return;
  }
 
  // Compute the softmax of logits and calculate entropy
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
- const int64_t t_start_sample_us = ggml_time_us();
+ llama_sampler_softmax_impl(cur_p);
 
  float entropy = 0.0f;
- for (size_t i = 0; i < candidates->size; ++i) {
- entropy += -candidates->data[i].p * logf(candidates->data[i].p);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
  }
 
  // Compute the absolute difference between negative log probability and entropy for each candidate
  std::vector<float> shifted_scores;
- for (size_t i = 0; i < candidates->size; ++i) {
- float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
  shifted_scores.push_back(shifted_score);
  }
 
  // Sort tokens based on the shifted_scores and their corresponding indices
- std::vector<size_t> indices(candidates->size);
+ std::vector<size_t> indices(cur_p->size);
  std::iota(indices.begin(), indices.end(), 0);
 
  std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
@@ -329,134 +850,618 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
329
850
 
330
851
  for (size_t i = 0; i < indices.size(); ++i) {
331
852
  size_t idx = indices[i];
332
- cum_sum += candidates->data[idx].p;
853
+ cum_sum += cur_p->data[idx].p;
333
854
 
334
855
  // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
335
- if (cum_sum > p && i >= min_keep - 1) {
856
+ if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
336
857
  last_idx = i + 1;
337
858
  break;
338
859
  }
339
860
  }
340
861
 
341
862
  // Resize the output vector to keep only the locally typical tokens
342
- std::vector<llama_token_data> new_candidates;
863
+ std::vector<llama_token_data> cur_p_new;
343
864
  for (size_t i = 0; i < last_idx; ++i) {
344
865
  size_t idx = indices[i];
345
- new_candidates.push_back(candidates->data[idx]);
866
+ cur_p_new.push_back(cur_p->data[idx]);
346
867
  }
347
868
 
348
- // Replace the data in candidates with the new_candidates data
349
- std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
350
- candidates->size = new_candidates.size();
351
- candidates->sorted = false;
869
+ // Replace the data in cur_p with the cur_p_new data
870
+ std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
871
+ cur_p->size = cur_p_new.size();
872
+ cur_p->sorted = false;
873
+ }
352
874
 
353
- if (smpl) {
354
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
355
- }
875
+ static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
876
+ const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
877
+ return llama_sampler_init_typical(ctx->p, ctx->min_keep);
356
878
  }
357
879
 
358
- void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
359
- const int64_t t_start_sample_us = ggml_time_us();
880
+ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
881
+ delete (llama_sampler_typical *) smpl->ctx;
882
+ }
360
883
 
361
- // no need to do anything if there is only one (or zero) candidates
362
- if(candidates->size <= 1) {
363
- return;
884
+ static struct llama_sampler_i llama_sampler_typical_i = {
885
+ /* .name = */ llama_sampler_typical_name,
886
+ /* .accept = */ nullptr,
887
+ /* .apply = */ llama_sampler_typical_apply,
888
+ /* .reset = */ nullptr,
889
+ /* .clone = */ llama_sampler_typical_clone,
890
+ /* .free = */ llama_sampler_typical_free,
891
+ };
892
+
893
+ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
894
+ return new llama_sampler {
895
+ /* .iface = */ &llama_sampler_typical_i,
896
+ /* .ctx = */ new llama_sampler_typical {
897
+ /* .p = */ p,
898
+ /* .min_keep = */ min_keep,
899
+ },
900
+ };
901
+ }
902
+
903
+ // temp
904
+
905
+ struct llama_sampler_temp {
906
+ const float temp;
907
+ };
908
+
909
+ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
910
+ return "temp";
911
+ }
912
+
913
+ static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
914
+ const auto * ctx = (llama_sampler_temp *) smpl->ctx;
915
+ for (size_t i = 0; i < cur_p->size; ++i) {
916
+ cur_p->data[i].logit /= ctx->temp;
364
917
  }
918
+ }
365
919
 
366
- // Calculate maximum possible entropy
367
- float max_entropy = -logf(1.0f / candidates->size);
920
+ static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
921
+ const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
922
+ return llama_sampler_init_temp(ctx->temp);
923
+ }
368
924
 
369
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
925
+ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
926
+ delete (llama_sampler_temp *) smpl->ctx;
927
+ }
370
928
 
371
- // Calculate entropy of the softmax probabilities
372
- float entropy = 0.0f;
373
- for (size_t i = 0; i < candidates->size; ++i) {
374
- float prob = candidates->data[i].p;
375
- if (prob > 0.0f) { // Ensure no log(0)
376
- entropy -= prob * logf(prob);
929
+ static struct llama_sampler_i llama_sampler_temp_i = {
930
+ /* .name = */ llama_sampler_temp_name,
931
+ /* .accept = */ nullptr,
932
+ /* .apply = */ llama_sampler_temp_apply,
933
+ /* .reset = */ nullptr,
934
+ /* .clone = */ llama_sampler_temp_clone,
935
+ /* .free = */ llama_sampler_temp_free,
936
+ };
937
+
938
+ struct llama_sampler * llama_sampler_init_temp(float temp) {
939
+ return new llama_sampler {
940
+ /* .iface = */ &llama_sampler_temp_i,
941
+ /* .ctx = */ new llama_sampler_temp {
942
+ /*.temp = */ temp,
943
+ },
944
+ };
945
+ }
946
+
947
+ // temp-ext
948
+
949
+ struct llama_sampler_temp_ext {
950
+ const float temp;
951
+ const float delta;
952
+ const float exponent;
953
+ };
954
+
955
+ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
956
+ return "temp-ext";
957
+ }
958
+
959
+ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
960
+ const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
961
+ if (ctx->delta > 0) {
962
+ const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
963
+ const float max_temp = ctx->temp + ctx->delta;
964
+ float exponent_val = ctx->exponent;
965
+
966
+ // no need to do anything if there is only one (or zero) candidates
967
+ if (cur_p->size <= 1) {
968
+ return;
969
+ }
970
+
971
+ // Calculate maximum possible entropy
972
+ float max_entropy = -logf(1.0f / cur_p->size);
973
+
974
+ llama_sampler_softmax_impl(cur_p);
975
+
976
+ // Calculate entropy of the softmax probabilities
977
+ float entropy = 0.0f;
978
+ for (size_t i = 0; i < cur_p->size; ++i) {
979
+ float prob = cur_p->data[i].p;
980
+ if (prob > 0.0f) { // Ensure no log(0)
981
+ entropy -= prob * logf(prob);
982
+ }
983
+ }
984
+
985
+ // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
986
+ float normalized_entropy = entropy / max_entropy;
987
+
988
+ // Map the normalized entropy to the desired temperature range using the power function
989
+ float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
990
+
991
+ #ifdef DEBUG
992
+ LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
993
+ LLAMA_LOG_INFO("Entropy: %f\n", entropy);
994
+ LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
995
+ LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
996
+ LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
997
+ LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
998
+ #endif
999
+
1000
+ // Apply the dynamically calculated temperature scaling
1001
+ for (size_t i = 0; i < cur_p->size; ++i) {
1002
+ cur_p->data[i].logit /= dyn_temp;
1003
+ }
1004
+
1005
+ // Re-compute softmax probabilities after scaling logits with dynamic temperature
1006
+ const double max_l_double = cur_p->data[0].logit;
1007
+
1008
+ double cum_sum_double = 0.0;
1009
+ for (size_t i = 0; i < cur_p->size; ++i) {
1010
+ double p = exp(cur_p->data[i].logit - max_l_double);
1011
+ cur_p->data[i].p = p; // Store the scaled probability
1012
+ cum_sum_double += p;
1013
+ }
1014
+
1015
+ for (size_t i = 0; i < cur_p->size; ++i) {
1016
+ cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
1017
+ }
1018
+
1019
+ #ifdef DEBUG
1020
+ // Print the updated top 25 probabilities after temperature scaling
1021
+ LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
1022
+ for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
1023
+ LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
1024
+ }
1025
+ #endif
1026
+ } else {
1027
+ for (size_t i = 0; i < cur_p->size; ++i) {
1028
+ cur_p->data[i].logit /= ctx->temp;
377
1029
  }
378
1030
  }
1031
+ }
379
1032
 
380
- // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
381
- float normalized_entropy = entropy / max_entropy;
1033
+ static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
1034
+ const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
1035
+ return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
1036
+ }
382
1037
 
383
- // Map the normalized entropy to the desired temperature range using the power function
384
- float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
1038
+ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1039
+ delete (llama_sampler_temp_ext *) smpl->ctx;
1040
+ }
385
1041
 
386
- #ifdef DEBUG
387
- LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
388
- LLAMA_LOG_INFO("Entropy: %f\n", entropy);
389
- LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
390
- LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
391
- LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
392
- LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
393
- #endif
1042
+ static struct llama_sampler_i llama_sampler_temp_ext_i = {
1043
+ /* .name = */ llama_sampler_temp_ext_name,
1044
+ /* .accept = */ nullptr,
1045
+ /* .apply = */ llama_sampler_temp_ext_apply,
1046
+ /* .reset = */ nullptr,
1047
+ /* .clone = */ llama_sampler_temp_ext_clone,
1048
+ /* .free = */ llama_sampler_temp_ext_free,
1049
+ };
1050
+
1051
+ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1052
+ return new llama_sampler {
1053
+ /* .iface = */ &llama_sampler_temp_ext_i,
1054
+ /* .ctx = */ new llama_sampler_temp_ext {
1055
+ /* .temp = */ temp,
1056
+ /* .delta = */ delta,
1057
+ /* .exponent = */ exponent,
1058
+ },
1059
+ };
1060
+ }
1061
+
+ // mirostat
+
+ struct llama_sampler_mirostat {
+ const int32_t n_vocab;
+
+ const uint32_t seed;
+ uint32_t seed_cur;
+
+ const float tau;
+ const float eta;
+
+ const int32_t m;
+
+ float mu;

- // Apply the dynamically calculated temperature scaling
- for (size_t i = 0; i < candidates->size; ++i) {
- candidates->data[i].logit /= dyn_temp;
+ std::mt19937 rng;
+ };
+
+ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
+ return "mirostat";
+ }
+
+ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+ llama_sampler_softmax_impl(cur_p);
+
+ // Estimate s_hat using the most probable m tokens
+ float s_hat = 0.0;
+ float sum_ti_bi = 0.0;
+ float sum_ti_sq = 0.0;
+ for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
+ float t_i = logf(float(i + 2) / float(i + 1));
+ float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
+ sum_ti_bi += t_i * b_i;
+ sum_ti_sq += t_i * t_i;
  }
+ s_hat = sum_ti_bi / sum_ti_sq;
+
+ // Compute k from the estimated s_hat and target surprise value
+ float epsilon_hat = s_hat - 1;
+ float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
+
+ llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
+ llama_sampler_softmax_impl(cur_p);
+
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
+
+ cur_p->selected = idx;
+
+ float observed_surprise = -log2f(cur_p->data[idx].p);
+ float e = observed_surprise - ctx->tau;

- // Re-compute softmax probabilities after scaling logits with dynamic temperature
- double max_l_double = candidates->data[0].logit;
- double cum_sum_double = 0.0;
- for (size_t i = 0; i < candidates->size; ++i) {
- double p = exp(candidates->data[i].logit - max_l_double);
- candidates->data[i].p = p; // Store the scaled probability
- cum_sum_double += p;
+ // Update mu using the learning rate and error
+ ctx->mu = ctx->mu - ctx->eta * e;
+ }
+
+ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
+ auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_mirostat *) result->ctx;
+
+ result_ctx->mu = ctx->mu;
+ result_ctx->rng = ctx->rng;
  }
- for (size_t i = 0; i < candidates->size; ++i) {
- candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
+
+ return result;
+ }
+
+ static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+ ctx->mu = 2.0f*ctx->tau;
+ ctx->seed_cur = get_rng_seed(ctx->seed);
+ ctx->rng.seed(ctx->seed_cur);
+ }
+
+ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_mirostat *) smpl->ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_mirostat_i = {
+ /* .name = */ llama_sampler_mirostat_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_mirostat_apply,
+ /* .reset = */ llama_sampler_mirostat_reset,
+ /* .clone = */ llama_sampler_mirostat_clone,
+ /* .free = */ llama_sampler_mirostat_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+ auto seed_cur = get_rng_seed(seed);
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_mirostat_i,
+ /* .ctx = */ new llama_sampler_mirostat {
+ /* .n_vocab = */ n_vocab,
+ /* .seed = */ seed,
+ /* .seed_cur = */ seed_cur,
+ /* .tau = */ tau,
+ /* .eta = */ eta,
+ /* .m = */ m,
+ /* .mu = */ 2.0f*tau,
+ /* .rng = */ std::mt19937(seed_cur),
+ },
+ };
+ }
+
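
Mirostat state (mu and the RNG) now lives inside the sampler context instead of caller-managed pointers, and each apply() updates mu <- mu - eta * (-log2 p(sampled) - tau). A construction sketch, hedged: tau = 5.0, eta = 0.1 and m = 100 follow the Mirostat paper's suggestions, not necessarily this package's defaults:

#include "llama.h"

// v1 needs the vocabulary size for the Zipf-exponent estimate; tau is the
// target surprise in bits, eta the learning rate, m the number of top tokens
// used to estimate s_hat.
llama_sampler * make_mirostat(const llama_model * model) {
    return llama_sampler_init_mirostat(llama_n_vocab(model), LLAMA_DEFAULT_SEED,
                                       5.0f, 0.1f, 100);
}
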
+ // mirostat v2
+
+ struct llama_sampler_mirostat_v2 {
+ const uint32_t seed;
+ uint32_t seed_cur;
+
+ const float tau;
+ const float eta;
+
+ float mu;
+
+ std::mt19937 rng;
+ };
+
+ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
+ return "mirostat-v2";
+ }
+
+ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+
+ llama_sampler_softmax_impl(cur_p);
+
+ // Truncate the words with surprise values greater than mu
+ cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
+ return -log2f(candidate.p) > ctx->mu;
+ }));
+
+ if (cur_p->size == 0) {
+ cur_p->size = 1;
  }

- #ifdef DEBUG
- // Print the updated top 25 probabilities after temperature scaling
- LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
- for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
- LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
+ // Normalize the probabilities of the remaining words
+ llama_sampler_softmax_impl(cur_p);
+
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
+
+ cur_p->selected = idx;
+
+ float observed_surprise = -log2f(cur_p->data[idx].p);
+ float e = observed_surprise - ctx->tau;
+
+ // Update mu using the learning rate and error
+ ctx->mu = ctx->mu - ctx->eta * e;
+ }
+
+ static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+ ctx->mu = 2.0f*ctx->tau;
+ ctx->seed_cur = get_rng_seed(ctx->seed);
+ ctx->rng.seed(ctx->seed_cur);
+ }
+
+ static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
+
+ auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
+
+ result_ctx->mu = ctx->mu;
+ result_ctx->rng = ctx->rng;
  }
- #endif

- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ return result;
+ }
+
+ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_mirostat_v2 *) smpl->ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
+ /* .name = */ llama_sampler_mirostat_v2_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
+ /* .free = */ llama_sampler_mirostat_v2_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+ auto seed_cur = get_rng_seed(seed);
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_mirostat_v2_i,
+ /* .ctx = */ new llama_sampler_mirostat_v2 {
+ /* .seed = */ seed,
+ /* .seed_cur = */ seed_cur,
+ /* .tau = */ tau,
+ /* .eta = */ eta,
+ /* .mu = */ 2.0f*tau,
+ /* .rng = */ std::mt19937(seed_cur),
+ },
+ };
+ }
+
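
Mirostat v2 drops the Zipf estimate and simply truncates every candidate whose surprise -log2(p) exceeds mu. Since apply() selects the token itself, it normally sits at the end of a chain; a sketch with illustrative tau/eta values:

#include "llama.h"

llama_sampler * make_mirostat_v2_chain() {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_mirostat_v2(LLAMA_DEFAULT_SEED, 5.0f, 0.1f));
    return chain; // release with llama_sampler_free(chain)
}
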
+ // grammar
+
+ struct llama_sampler_grammar {
+ const struct llama_vocab * vocab;
+
+ std::string grammar_str;
+ std::string grammar_root;
+
+ struct llama_grammar * grammar;
+ };
+
+ static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
+ return "grammar";
+ }
+
+ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+ if (ctx->grammar) {
+ llama_grammar_accept_impl(*ctx->grammar, token);
+ }
+ }
+
+ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+ if (ctx->grammar) {
+ llama_grammar_apply_impl(*ctx->grammar, cur_p);
+ }
+ }
+
+ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+ if (!ctx->grammar) {
+ return;
  }
+
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+
+ llama_grammar_free_impl(ctx->grammar);
+ ctx->grammar = grammar_new;
  }

- void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
- const int64_t t_start_sample_us = ggml_time_us();
+ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
+
+ auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_grammar *) result->ctx;

- for (size_t i = 0; i < candidates->size; ++i) {
- candidates->data[i].logit /= temp;
+ if (ctx->grammar) {
+ result_ctx->grammar_str = ctx->grammar_str;
+ result_ctx->grammar_root = ctx->grammar_root;
+
+ result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
+ }
+ }
+
+ return result;
+ }
+
+ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
+ const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+
+ if (ctx->grammar) {
+ llama_grammar_free_impl(ctx->grammar);
+ }
+
+ delete ctx;
+ }
+
+ static struct llama_sampler_i llama_sampler_grammar_i = {
+ /* .name = */ llama_sampler_grammar_name,
+ /* .accept = */ llama_sampler_grammar_accept_impl,
+ /* .apply = */ llama_sampler_grammar_apply,
+ /* .reset = */ llama_sampler_grammar_reset,
+ /* .clone = */ llama_sampler_grammar_clone,
+ /* .free = */ llama_sampler_grammar_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+ auto * ctx = new llama_sampler_grammar;
+
+ if (grammar_str != nullptr && grammar_str[0] != '\0') {
+ *ctx = {
+ /* .vocab = */ &vocab,
+ /* .grammar_str = */ grammar_str,
+ /* .grammar_root = */ grammar_root,
+ /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+ };
+ } else {
+ *ctx = {
+ /* .vocab = */ &vocab,
+ /* .grammar_str = */ {},
+ /* .grammar_root = */ {},
+ /* .grammar = */ nullptr,
+ };
  }

- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_grammar_i,
+ /* .ctx = */ ctx,
+ };
+ }
+
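
The grammar sampler masks non-conforming candidates in apply() and advances the grammar state in accept(). A sketch, assuming the public wrapper llama_sampler_init_grammar(model, ...) that forwards to the _impl constructor above; the grammar itself is a made-up example:

#include "llama.h"

// Constrain output to "yes" or "no"; "root" is the conventional GBNF start rule.
llama_sampler * make_grammar_sampler(const llama_model * model) {
    const char * gbnf = "root ::= \"yes\" | \"no\"";
    return llama_sampler_init_grammar(model, gbnf, "root");
}
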
+ // penalties
+
+ struct llama_sampler_penalties {
+ const int32_t n_vocab;
+ const llama_token special_eos_id;
+ const llama_token linefeed_id;
+
+ const int32_t penalty_last_n;
+ const float penalty_repeat;
+ const float penalty_freq;
+ const float penalty_present;
+
+ const bool penalize_nl;
+ const bool ignore_eos;
+
+ ring_buffer<llama_token> prev;
+ };
+
+ static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
+ return "penalties";
+ }
+
+ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+ if (ctx->penalty_last_n == 0) {
+ return;
  }
+
+ ctx->prev.push_back(token);
  }

- void llama_sample_repetition_penalties_impl(
- struct llama_sampling * smpl,
- llama_token_data_array * candidates,
- const llama_token * last_tokens,
- size_t penalty_last_n,
- float penalty_repeat,
- float penalty_freq,
- float penalty_present) {
- if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
+ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+
+ if (ctx->ignore_eos) {
+ assert(ctx->special_eos_id >= 0);
+
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
+ if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
+ cur_p->data[ctx->special_eos_id].logit = -INFINITY;
+ } else {
+ // else, search for the special EOS token
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (cur_p->data[i].id == ctx->special_eos_id) {
+ cur_p->data[i].logit = -INFINITY;
+ break;
+ }
+ }
+ }
+ }
+
+ if ((ctx->penalty_last_n == 0) ||
+ (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
  return;
  }

- const int64_t t_start_sample_us = ggml_time_us();
+ bool nl_found = false;
+ size_t nl_idx = 0;
+ float nl_logit = -INFINITY;
+ if (!ctx->penalize_nl) {
+ assert(ctx->linefeed_id >= 0);
+
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
+ if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
+ nl_found = true;
+ nl_idx = ctx->linefeed_id;
+ nl_logit = cur_p->data[ctx->linefeed_id].logit;
+ } else {
+ // else, search for the linefeed token
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (cur_p->data[i].id == ctx->linefeed_id) {
+ nl_found = true;
+ nl_idx = i;
+ nl_logit = cur_p->data[i].logit;
+ break;
+ }
+ }
+ }
+ }

  // Create a frequency map to count occurrences of each token in last_tokens
- std::unordered_map<llama_token, int> token_count;
- for (size_t i = 0; i < penalty_last_n; ++i) {
- token_count[last_tokens[i]]++;
+ // TODO: optimize this by maintaining the token count in the sampler context
+ using llama_token_cnt = std::unordered_map<llama_token, int>;
+ llama_token_cnt token_count;
+
+ for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+ token_count[ctx->prev.rat(i)]++;
  }

- // Apply frequency and presence penalties to the candidates
- for (size_t i = 0; i < candidates->size; ++i) {
- const auto token_iter = token_count.find(candidates->data[i].id);
+ // Apply frequency and presence penalties to the cur_p
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ const auto token_iter = token_count.find(cur_p->data[i].id);
  if (token_iter == token_count.end()) {
  continue;
  }
@@ -465,171 +1470,238 @@ void llama_sample_repetition_penalties_impl(

  // The academic publication that described this technique only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
  // This is a common fix for this problem: multiply by the penalty instead of dividing.
- if (candidates->data[i].logit <= 0) {
- candidates->data[i].logit *= penalty_repeat;
+ if (cur_p->data[i].logit <= 0) {
+ cur_p->data[i].logit *= ctx->penalty_repeat;
  } else {
- candidates->data[i].logit /= penalty_repeat;
+ cur_p->data[i].logit /= ctx->penalty_repeat;
  }

- candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
+ cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
  }

- candidates->sorted = false;
+ cur_p->sorted = false;

- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ if (!ctx->penalize_nl && nl_found) {
+ // restore the logit of the newline token if it was penalized
+ cur_p->data[nl_idx].logit = nl_logit;
  }
  }

- void llama_sample_apply_guidance_impl(
- struct llama_sampling * smpl,
- float * logits,
- float * logits_guidance,
- float scale) {
- GGML_ASSERT(smpl);
-
- const auto t_start_sample_us = ggml_time_us();
- const auto n_vocab = smpl->n_vocab;
-
- llama_log_softmax(logits, n_vocab);
- llama_log_softmax(logits_guidance, n_vocab);
+ static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+ ctx->prev.clear();
+ }

- for (int i = 0; i < n_vocab; ++i) {
- auto & l = logits[i];
- const auto & g = logits_guidance[i];
+ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
+ auto * result = llama_sampler_init_penalties(
+ ctx->n_vocab,
+ ctx->special_eos_id,
+ ctx->linefeed_id,
+ ctx->penalty_last_n,
+ ctx->penalty_repeat,
+ ctx->penalty_freq,
+ ctx->penalty_present,
+ ctx->penalize_nl,
+ ctx->ignore_eos);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_penalties *) result->ctx;

- l = scale * (l - g) + g;
+ result_ctx->prev = ctx->prev;
  }

- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ return result;
  }

- llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
- GGML_ASSERT(smpl);
-
- const int32_t n_vocab = float(smpl->n_vocab);
-
- int64_t t_start_sample_us = ggml_time_us();
+ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_penalties *) smpl->ctx;
+ }

- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+ static struct llama_sampler_i llama_sampler_penalties_i = {
+ /* .name = */ llama_sampler_penalties_name,
+ /* .accept = */ llama_sampler_penalties_accept,
+ /* .apply = */ llama_sampler_penalties_apply,
+ /* .reset = */ llama_sampler_penalties_reset,
+ /* .clone = */ llama_sampler_penalties_clone,
+ /* .free = */ llama_sampler_penalties_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_penalties(
+ int32_t n_vocab,
+ llama_token special_eos_id,
+ llama_token linefeed_id,
+ int32_t penalty_last_n,
+ float penalty_repeat,
+ float penalty_freq,
+ float penalty_present,
+ bool penalize_nl,
+ bool ignore_eos) {
+ if (linefeed_id == LLAMA_TOKEN_NULL) {
+ penalize_nl = true;
+ }

- // Estimate s_hat using the most probable m tokens
- float s_hat = 0.0;
- float sum_ti_bi = 0.0;
- float sum_ti_sq = 0.0;
- for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
- float t_i = logf(float(i + 2) / float(i + 1));
- float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
- sum_ti_bi += t_i * b_i;
- sum_ti_sq += t_i * t_i;
+ if (special_eos_id == LLAMA_TOKEN_NULL) {
+ ignore_eos = false;
  }
- s_hat = sum_ti_bi / sum_ti_sq;

- // Compute k from the estimated s_hat and target surprise value
- float epsilon_hat = s_hat - 1;
- float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
+ penalty_last_n = std::max(penalty_last_n, 0);
+
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_penalties_i,
+ /* .ctx = */ new llama_sampler_penalties {
+ /* .n_vocab = */ n_vocab,
+ /* .special_eos_id = */ special_eos_id,
+ /* .linefeed_id = */ linefeed_id,
+ /* .penalty_last_n = */ penalty_last_n,
+ /* .penalty_repeat = */ penalty_repeat,
+ /* .penalty_freq = */ penalty_freq,
+ /* .penalty_present = */ penalty_present,
+ /* .penalize_nl = */ penalize_nl,
+ /* .ignore_eos = */ ignore_eos,
+ /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
+ },
+ };
+ }
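
Unlike the removed llama_sample_repetition_penalties_impl, the new sampler owns its history: tokens are reported through accept() into the prev ring buffer, so callers no longer pass last_tokens on every call. A construction sketch in which every numeric value is an illustrative assumption:

#include "llama.h"

llama_sampler * make_penalties(const llama_model * model) {
    return llama_sampler_init_penalties(
        llama_n_vocab(model),
        llama_token_eos(model),
        llama_token_nl(model),
        /* penalty_last_n  = */ 64,
        /* penalty_repeat  = */ 1.1f,
        /* penalty_freq    = */ 0.0f,
        /* penalty_present = */ 0.0f,
        /* penalize_nl     = */ false,
        /* ignore_eos      = */ false);
}
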
 
- // Sample the next word X using top-k sampling
- llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- llama_token X = llama_sample_token_impl(smpl, candidates);
- t_start_sample_us = ggml_time_us();
+ // logit-bias

- // Compute error as the difference between observed surprise and target surprise value
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
- return candidate.id == X;
- }));
- float observed_surprise = -log2f(candidates->data[X_idx].p);
- float e = observed_surprise - tau;
+ struct llama_sampler_logit_bias {
+ const int32_t n_vocab;

- // Update mu using the learning rate and error
- *mu = *mu - eta * e;
+ const std::vector<llama_logit_bias> logit_bias;

- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- return X;
+ std::vector<llama_logit_bias> to_search;
+ };
+
+ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
+ return "logit-bias";
  }

- llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
- int64_t t_start_sample_us;
- t_start_sample_us = ggml_time_us();
+ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+
+ if (ctx->logit_bias.empty()) {
+ return;
+ }

- llama_sample_softmax_impl(smpl, candidates);
+ ctx->to_search.clear();

- // Truncate the words with surprise values greater than mu
- candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
- return -log2f(candidate.p) > *mu;
- }));
+ // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
+ for (const auto & lb : ctx->logit_bias) {
+ if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
+ cur_p->data[lb.token].logit += lb.bias;
+ } else {
+ ctx->to_search.push_back(lb);
+ }
+ }

- if (candidates->size == 0) {
- candidates->size = 1;
+ if (ctx->to_search.empty()) {
+ return;
  }

- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ // search for the remaining candidates that were not found in the previous step
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ for (const auto & lb : ctx->to_search) {
+ if (cur_p->data[i].id == lb.token) {
+ cur_p->data[i].logit += lb.bias;
+ break;
+ }
+ }
  }
+ }

- // Normalize the probabilities of the remaining words
- llama_sample_softmax_impl(smpl, candidates);
+ static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
+ return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
+ }

- // Sample the next word X from the remaining words
- llama_token X = llama_sample_token_impl(smpl, candidates);
- t_start_sample_us = ggml_time_us();
+ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_logit_bias *) smpl->ctx;
+ }

- // Compute error as the difference between observed surprise and target surprise value
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
- return candidate.id == X;
- }));
- float observed_surprise = -log2f(candidates->data[X_idx].p);
- float e = observed_surprise - tau;
+ static struct llama_sampler_i llama_sampler_logit_bias_i = {
+ /* .name = */ llama_sampler_logit_bias_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_logit_bias_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_logit_bias_clone,
+ /* .free = */ llama_sampler_logit_bias_free,
+ };
+
+ struct llama_sampler * llama_sampler_init_logit_bias(
+ int32_t n_vocab,
+ int32_t n_logit_bias,
+ const llama_logit_bias * logit_bias) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_logit_bias_i,
+ /* .ctx = */ new llama_sampler_logit_bias {
+ /* .n_vocab = */ n_vocab,
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+ /* .to_search = */ {},
+ },
+ };
+ }
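
The logit-bias sampler first tries the O(1) path (candidate index == token id, which holds until the array has been sorted or truncated) and only falls back to a linear scan for the leftovers. A sketch; the token id and bias are hypothetical:

#include "llama.h"
#include <cmath>
#include <vector>

llama_sampler * make_logit_bias(const llama_model * model) {
    std::vector<llama_logit_bias> biases = {
        { /* token */ 42, /* bias */ -INFINITY }, // ban a hypothetical token id
    };
    return llama_sampler_init_logit_bias(llama_n_vocab(model),
                                         (int32_t) biases.size(), biases.data());
}
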
 
- // Update mu using the learning rate and error
- *mu = *mu - eta * e;
+ // utils

- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
+ if (smpl->iface == &llama_sampler_dist_i) {
+ return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
  }
- return X;
- }

- llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
- const int64_t t_start_sample_us = ggml_time_us();
+ if (smpl->iface == &llama_sampler_mirostat_i) {
+ return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
+ }

- // Find max element
- auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
- return a.logit < b.logit;
- });
+ if (smpl->iface == &llama_sampler_mirostat_v2_i) {
+ return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
+ }

- llama_token result = max_iter->id;
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- smpl->n_sample++;
+ if (smpl->iface == &llama_sampler_chain_i) {
+ const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
+ for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
+ const uint32_t seed = llama_sampler_get_seed(*it);
+ if (seed != LLAMA_DEFAULT_SEED) {
+ return seed;
+ }
+ }
  }
- return result;
+
+ return LLAMA_DEFAULT_SEED;
  }
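
llama_sampler_get_seed walks a chain from back to front and returns the first non-default seed, so the final sampling stage (dist or mirostat) is the one reported. A sketch, assuming the public chain API from llama.h:

#include "llama.h"
#include <cstdio>

void show_chain_seed() {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(/* seed = */ 1234));
    printf("seed: %u\n", llama_sampler_get_seed(chain)); // 1234, from the dist stage
    llama_sampler_free(chain);
}
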
 
- llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
- GGML_ASSERT(smpl);
+ // perf

- const int64_t t_start_sample_us = ggml_time_us();
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
+ struct llama_perf_sampler_data data = {};

- std::vector<float> probs;
- probs.reserve(candidates->size);
- for (size_t i = 0; i < candidates->size; ++i) {
- probs.push_back(candidates->data[i].p);
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
+ GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
  }

- std::discrete_distribution<> dist(probs.begin(), probs.end());
- int idx = dist(rng);
+ const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;

- llama_token result = candidates->data[idx].id;
+ data.t_sample_ms = 1e-3 * ctx->t_sample_us;
+ data.n_sample = std::max(0, ctx->n_sample);

- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- smpl->n_sample++;
+ return data;
+ }

- return result;
+ void llama_perf_sampler_print(const struct llama_sampler * chain) {
+ const auto data = llama_perf_sampler(chain);
+
+ LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+ __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
  }

- llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
- return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+ void llama_perf_sampler_reset(struct llama_sampler * chain) {
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
+ GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
+ }
+
+ auto * ctx = (struct llama_sampler_chain *) chain->ctx;
+
+ ctx->t_sample_us = ctx->n_sample = 0;
  }
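
The perf helpers only accept a chain sampler because the timing counters accumulate on the chain context. An end-to-end sketch of the reworked sampling flow, assuming an already-initialized llama_context; all parameter values are illustrative:

#include "llama.h"

void sample_and_report(llama_context * lctx) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // samples from the last logits row; the helper also accepts the token
    // back into the chain, keeping stateful samplers in sync
    llama_token tok = llama_sampler_sample(chain, lctx, -1);
    (void) tok;

    llama_perf_sampler_print(chain);
    llama_sampler_free(chain);
}
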