cui-llama.rn 1.4.4 → 1.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. package/android/src/main/CMakeLists.txt +2 -2
  2. package/android/src/main/jni.cpp +12 -10
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/chat-template.hpp +529 -529
  12. package/cpp/chat.cpp +959 -265
  13. package/cpp/chat.h +135 -0
  14. package/cpp/common.cpp +2064 -1996
  15. package/cpp/common.h +700 -744
  16. package/cpp/ggml-alloc.c +1039 -1030
  17. package/cpp/ggml-alloc.h +1 -1
  18. package/cpp/ggml-backend-impl.h +255 -255
  19. package/cpp/ggml-backend-reg.cpp +586 -582
  20. package/cpp/ggml-backend.cpp +2004 -2002
  21. package/cpp/ggml-backend.h +354 -354
  22. package/cpp/ggml-common.h +1851 -1851
  23. package/cpp/ggml-cpp.h +39 -39
  24. package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
  25. package/cpp/ggml-cpu-aarch64.h +8 -8
  26. package/cpp/ggml-cpu-impl.h +531 -380
  27. package/cpp/ggml-cpu-quants.c +12527 -11517
  28. package/cpp/ggml-cpu-traits.cpp +36 -36
  29. package/cpp/ggml-cpu-traits.h +38 -38
  30. package/cpp/ggml-cpu.c +15766 -14485
  31. package/cpp/ggml-cpu.cpp +655 -633
  32. package/cpp/ggml-cpu.h +138 -135
  33. package/cpp/ggml-impl.h +567 -567
  34. package/cpp/ggml-metal-impl.h +235 -0
  35. package/cpp/ggml-metal.h +66 -66
  36. package/cpp/ggml-metal.m +5146 -5002
  37. package/cpp/ggml-opt.cpp +854 -854
  38. package/cpp/ggml-opt.h +216 -216
  39. package/cpp/ggml-quants.c +5238 -5238
  40. package/cpp/ggml-threading.h +14 -14
  41. package/cpp/ggml.c +6529 -6524
  42. package/cpp/ggml.h +2198 -2194
  43. package/cpp/gguf.cpp +1329 -1329
  44. package/cpp/gguf.h +202 -202
  45. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  46. package/cpp/json-schema-to-grammar.h +21 -22
  47. package/cpp/json.hpp +24766 -24766
  48. package/cpp/llama-adapter.cpp +347 -347
  49. package/cpp/llama-adapter.h +74 -74
  50. package/cpp/llama-arch.cpp +1513 -1492
  51. package/cpp/llama-arch.h +403 -402
  52. package/cpp/llama-batch.cpp +368 -368
  53. package/cpp/llama-batch.h +88 -88
  54. package/cpp/llama-chat.cpp +588 -587
  55. package/cpp/llama-chat.h +53 -53
  56. package/cpp/llama-context.cpp +1775 -1775
  57. package/cpp/llama-context.h +128 -128
  58. package/cpp/llama-cparams.cpp +1 -1
  59. package/cpp/llama-cparams.h +37 -37
  60. package/cpp/llama-cpp.h +30 -30
  61. package/cpp/llama-grammar.cpp +1219 -1219
  62. package/cpp/llama-grammar.h +173 -164
  63. package/cpp/llama-hparams.cpp +71 -71
  64. package/cpp/llama-hparams.h +139 -139
  65. package/cpp/llama-impl.cpp +167 -167
  66. package/cpp/llama-impl.h +61 -61
  67. package/cpp/llama-kv-cache.cpp +718 -718
  68. package/cpp/llama-kv-cache.h +219 -218
  69. package/cpp/llama-mmap.cpp +600 -590
  70. package/cpp/llama-mmap.h +68 -68
  71. package/cpp/llama-model-loader.cpp +1124 -1124
  72. package/cpp/llama-model-loader.h +167 -167
  73. package/cpp/llama-model.cpp +4087 -4023
  74. package/cpp/llama-model.h +370 -370
  75. package/cpp/llama-sampling.cpp +2558 -2525
  76. package/cpp/llama-sampling.h +32 -32
  77. package/cpp/llama-vocab.cpp +3264 -3252
  78. package/cpp/llama-vocab.h +125 -125
  79. package/cpp/llama.cpp +10284 -10137
  80. package/cpp/llama.h +1354 -1340
  81. package/cpp/log.cpp +393 -423
  82. package/cpp/log.h +132 -132
  83. package/cpp/minja/chat-template.hpp +529 -0
  84. package/cpp/minja/minja.hpp +2915 -0
  85. package/cpp/minja.hpp +2915 -2883
  86. package/cpp/rn-llama.cpp +20 -37
  87. package/cpp/rn-llama.h +12 -2
  88. package/cpp/sampling.cpp +570 -532
  89. package/cpp/sgemm.cpp +2598 -2598
  90. package/cpp/sgemm.h +14 -14
  91. package/cpp/speculative.cpp +278 -277
  92. package/cpp/speculative.h +28 -28
  93. package/package.json +1 -1
  94. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  95. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  96. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  97. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  98. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  99. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  100. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  101. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  102. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  103. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  104. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  105. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  106. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  107. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  108. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  109. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  110. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  111. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  112. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  113. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  114. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  115. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  116. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  117. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  118. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  119. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  120. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  122. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  124. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  125. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  126. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  127. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  128. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  129. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  130. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  131. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  132. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  133. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  134. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  135. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  136. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  194. package/android/src/main/build-arm64/Makefile +0 -1862
  195. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  196. package/cpp/chat.hpp +0 -55
  197. package/cpp/rn-llama.hpp +0 -913
@@ -1,2525 +1,2558 @@
1
- #include "llama-sampling.h"
2
-
3
- #include "llama-impl.h"
4
- #include "llama-vocab.h"
5
- #include "llama-grammar.h"
6
-
7
- #include <algorithm>
8
- #include <cassert>
9
- #include <cfloat>
10
- #include <chrono>
11
- #include <cmath>
12
- #include <cstdlib>
13
- #include <cstring>
14
- #include <ctime>
15
- #include <numeric>
16
- #include <random>
17
- #include <unordered_map>
18
- #include <stdexcept>
19
-
20
- // the ring buffer works similarly to std::deque, but with a fixed capacity
21
- template<typename T>
22
- struct ring_buffer {
23
- ring_buffer(size_t cap) : capacity(cap), data(cap) {}
24
-
25
- T & front() {
26
- if (sz == 0) {
27
- throw std::runtime_error("ring buffer is empty");
28
- }
29
- return data[first];
30
- }
31
-
32
- const T & front() const {
33
- if (sz == 0) {
34
- throw std::runtime_error("ring buffer is empty");
35
- }
36
- return data[first];
37
- }
38
-
39
- T & back() {
40
- if (sz == 0) {
41
- throw std::runtime_error("ring buffer is empty");
42
- }
43
- return data[pos];
44
- }
45
-
46
- const T & back() const {
47
- if (sz == 0) {
48
- throw std::runtime_error("ring buffer is empty");
49
- }
50
- return data[pos];
51
- }
52
-
53
- void push_back(const T & value) {
54
- if (capacity == 0) {
55
- throw std::runtime_error("ring buffer: capacity is zero");
56
- }
57
-
58
- if (sz == capacity) {
59
- // advance the start when buffer is full
60
- first = (first + 1) % capacity;
61
- } else {
62
- sz++;
63
- }
64
- data[pos] = value;
65
- pos = (pos + 1) % capacity;
66
- }
67
-
68
- T pop_front() {
69
- if (sz == 0) {
70
- throw std::runtime_error("ring buffer is empty");
71
- }
72
- T value = data[first];
73
- first = (first + 1) % capacity;
74
- sz--;
75
- return value;
76
- }
77
-
78
- //T & operator[](size_t i) {
79
- // if (i >= sz) {
80
- // throw std::runtime_error("ring buffer: index out of bounds");
81
- // }
82
- // return data[(first + i) % capacity];
83
- //}
84
-
85
- //const T & at(size_t i) const {
86
- // if (i >= sz) {
87
- // throw std::runtime_error("ring buffer: index out of bounds");
88
- // }
89
- // return data[(first + i) % capacity];
90
- //}
91
-
92
- const T & rat(size_t i) const {
93
- if (i >= sz) {
94
- throw std::runtime_error("ring buffer: index out of bounds");
95
- }
96
- return data[(first + sz - i - 1) % capacity];
97
- }
98
-
99
- std::vector<T> to_vector() const {
100
- std::vector<T> result;
101
- result.reserve(sz);
102
- for (size_t i = 0; i < sz; i++) {
103
- result.push_back(data[(first + i) % capacity]);
104
- }
105
- return result;
106
- }
107
-
108
- void clear() {
109
- // here only reset the status of the buffer
110
- sz = 0;
111
- first = 0;
112
- pos = 0;
113
- }
114
-
115
- bool empty() const {
116
- return sz == 0;
117
- }
118
-
119
- size_t size() const {
120
- return sz;
121
- }
122
-
123
- size_t capacity = 0;
124
- size_t sz = 0;
125
- size_t first = 0;
126
- size_t pos = 0;
127
-
128
- std::vector<T> data;
129
- };
130
-
131
- static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
132
- // iterator for the probabilities
133
- #ifdef __GNUC__
134
- #pragma GCC diagnostic push
135
- #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
136
- #endif
137
-
138
- struct probs_iterator {
139
- typedef std::input_iterator_tag iterator_category;
140
- typedef float value_type;
141
- typedef float * pointer;
142
- typedef float & reference;
143
- typedef ptrdiff_t difference_type;
144
-
145
- const llama_token_data * data;
146
-
147
- bool operator==(const probs_iterator & other) const { return data == other.data; }
148
- bool operator!=(const probs_iterator & other) const { return data != other.data; }
149
- const float & operator*() const { return data->p; }
150
- probs_iterator & operator++() { ++data; return *this; }
151
- probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
152
- };
153
-
154
- #ifdef __GNUC__
155
- #pragma GCC diagnostic pop
156
- #endif
157
-
158
- std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
159
-
160
- return dist(rng);
161
- }
162
-
163
- /*
164
- static void llama_log_softmax(float * array, size_t size) {
165
- float max_l = *std::max_element(array, array + size);
166
- float sum = 0.f;
167
- for (size_t i = 0; i < size; ++i) {
168
- float p = expf(array[i] - max_l);
169
- sum += p;
170
- array[i] = p;
171
- }
172
-
173
- for (size_t i = 0; i < size; ++i) {
174
- array[i] = logf(array[i] / sum);
175
- }
176
- }
177
- */
178
-
179
- static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
180
- if (temp <= 0.0f) {
181
- // find the token with the highest logit and set the rest to -inf
182
- size_t max_i = 0;
183
- float max_l = cur_p->data[0].logit;
184
-
185
- for (size_t i = 1; i < cur_p->size; ++i) {
186
- if (cur_p->data[i ].logit > max_l) {
187
- cur_p->data[max_i].logit = -INFINITY;
188
- max_i = i;
189
- max_l = cur_p->data[i].logit;
190
- } else {
191
- cur_p->data[i].logit = -INFINITY;
192
- }
193
- }
194
-
195
- return;
196
- }
197
-
198
- for (size_t i = 0; i < cur_p->size; ++i) {
199
- cur_p->data[i].logit /= temp;
200
- }
201
- }
202
-
203
- static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
204
- LM_GGML_ASSERT(cur_p->size > 0);
205
-
206
- // Sort the logits in descending order
207
- if (!cur_p->sorted) {
208
- std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
209
- return a.logit > b.logit;
210
- });
211
- cur_p->sorted = true;
212
- }
213
-
214
- float max_l = cur_p->data[0].logit;
215
- float cum_sum = 0.0f;
216
-
217
- for (size_t i = 0; i < cur_p->size; ++i) {
218
- float p = expf(cur_p->data[i].logit - max_l);
219
- cur_p->data[i].p = p;
220
- cum_sum += p;
221
- }
222
-
223
- for (size_t i = 0; i < cur_p->size; ++i) {
224
- cur_p->data[i].p /= cum_sum;
225
- }
226
- }
227
-
228
- static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
229
- // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
230
- // if (k >= (int32_t)cur_p->size) {
231
- // return;
232
- // }
233
-
234
- if (k <= 0) {
235
- k = cur_p->size;
236
- }
237
-
238
- k = std::min(k, (int) cur_p->size);
239
-
240
- // Sort scores in descending order
241
- if (!cur_p->sorted) {
242
- auto comp = [](const llama_token_data & a, const llama_token_data & b) {
243
- return a.logit > b.logit;
244
- };
245
- if (k <= 128) {
246
- std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
247
- } else {
248
- constexpr int nbuckets = 128;
249
- constexpr float bucket_low = -10.0f;
250
- constexpr float bucket_high = 10.0f;
251
- constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
252
- constexpr float bucket_inter = -bucket_low * bucket_scale;
253
-
254
- std::vector<int> bucket_idx(cur_p->size);
255
- std::vector<int> histo(nbuckets, 0);
256
-
257
- for (int i = 0; i < (int)cur_p->size; ++i) {
258
- const float val = cur_p->data[i].logit;
259
- int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
260
- ib = std::max(0, std::min(nbuckets - 1, ib));
261
- bucket_idx[i] = ib;
262
- ++histo[ib];
263
- }
264
- int nhave = 0;
265
- int ib = nbuckets - 1;
266
- for ( ; ib >= 0; --ib) {
267
- nhave += histo[ib];
268
- if (nhave >= k) {
269
- break;
270
- }
271
- }
272
- std::vector<llama_token_data> tmp_tokens(nhave);
273
- auto * ptr = tmp_tokens.data();
274
- std::vector<llama_token_data*> bucket_ptrs;
275
- bucket_ptrs.reserve(nbuckets - ib);
276
- for (int j = nbuckets - 1; j >= ib; --j) {
277
- bucket_ptrs.push_back(ptr);
278
- ptr += histo[j];
279
- }
280
- for (int i = 0; i < (int)cur_p->size; ++i) {
281
- int j = bucket_idx[i];
282
- if (j >= ib) {
283
- *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
284
- }
285
- }
286
-
287
- ptr = tmp_tokens.data();
288
- int ndone = 0;
289
- for (int j = nbuckets - 1; j > ib; --j) {
290
- std::sort(ptr, ptr + histo[j], comp);
291
- ptr += histo[j];
292
- ndone += histo[j];
293
- }
294
- std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
295
-
296
- std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
297
-
298
- }
299
- cur_p->sorted = true;
300
- }
301
- cur_p->size = k;
302
- }
303
-
304
- static uint32_t get_rng_seed(uint32_t seed) {
305
- if (seed == LLAMA_DEFAULT_SEED) {
306
- // use system clock if std::random_device is not a true RNG
307
- static bool is_rd_prng = std::random_device().entropy() == 0;
308
- if (is_rd_prng) {
309
- return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
310
- }
311
- std::random_device rd;
312
- return rd();
313
- }
314
- return seed;
315
- }
316
-
317
- // llama_sampler API
318
-
319
- struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
320
- return new llama_sampler {
321
- /* .iface = */ iface,
322
- /* .ctx = */ ctx,
323
- };
324
- }
325
-
326
- const char * llama_sampler_name(const struct llama_sampler * smpl) {
327
- if (!smpl->iface) {
328
- return "(null)";
329
- }
330
-
331
- return smpl->iface->name(smpl);
332
- }
333
-
334
- void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
335
- if (smpl->iface->accept) {
336
- smpl->iface->accept(smpl, token);
337
- }
338
- }
339
-
340
- void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
341
- LM_GGML_ASSERT(smpl->iface->apply);
342
- smpl->iface->apply(smpl, cur_p);
343
- }
344
-
345
- void llama_sampler_reset(struct llama_sampler * smpl) {
346
- if (smpl->iface->reset) {
347
- smpl->iface->reset(smpl);
348
- }
349
- }
350
-
351
- struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
352
- if (smpl->iface->clone) {
353
- return smpl->iface->clone(smpl);
354
- }
355
-
356
- if (smpl->ctx == nullptr) {
357
- return llama_sampler_init(
358
- /* .iface = */ smpl->iface,
359
- /* .ctx = */ nullptr
360
- );
361
- }
362
-
363
- LM_GGML_ABORT("the sampler does not support cloning");
364
- }
365
-
366
- void llama_sampler_free(struct llama_sampler * smpl) {
367
- if (smpl == nullptr) {
368
- return;
369
- }
370
-
371
- if (smpl->iface->free) {
372
- smpl->iface->free(smpl);
373
- }
374
-
375
- delete smpl;
376
- }
377
-
378
- llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
379
- const auto * logits = llama_get_logits_ith(ctx, idx);
380
-
381
- const llama_model * model = llama_get_model(ctx);
382
- const llama_vocab * vocab = llama_model_get_vocab(model);
383
-
384
- const int n_vocab = llama_vocab_n_tokens(vocab);
385
-
386
- // TODO: do not allocate each time
387
- std::vector<llama_token_data> cur;
388
- cur.reserve(n_vocab);
389
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
390
- cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
391
- }
392
-
393
- llama_token_data_array cur_p = {
394
- /* .data = */ cur.data(),
395
- /* .size = */ cur.size(),
396
- /* .selected = */ -1,
397
- /* .sorted = */ false,
398
- };
399
-
400
- llama_sampler_apply(smpl, &cur_p);
401
-
402
- LM_GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
403
-
404
- auto token = cur_p.data[cur_p.selected].id;
405
-
406
- llama_sampler_accept(smpl, token);
407
-
408
- return token;
409
- }
410
-
411
- // sampler chain
412
-
413
- static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
414
- return "chain";
415
- }
416
-
417
- static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
418
- auto * chain = (llama_sampler_chain *) smpl->ctx;
419
-
420
- time_meas tm(chain->t_sample_us, chain->params.no_perf);
421
-
422
- for (auto * smpl : chain->samplers) {
423
- llama_sampler_accept(smpl, token);
424
- }
425
-
426
- chain->n_sample++;
427
- }
428
-
429
- static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
430
- auto * chain = (llama_sampler_chain *) smpl->ctx;
431
-
432
- time_meas tm(chain->t_sample_us, chain->params.no_perf);
433
-
434
- for (auto * smpl : chain->samplers) {
435
- llama_sampler_apply(smpl, cur_p);
436
- }
437
- }
438
-
439
- static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
440
- auto * chain = (llama_sampler_chain *) smpl->ctx;
441
-
442
- for (auto * smpl : chain->samplers) {
443
- llama_sampler_reset(smpl);
444
- }
445
-
446
- chain->t_sample_us = 0;
447
- chain->n_sample = 0;
448
- }
449
-
450
- static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
451
- const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
452
-
453
- auto * result = llama_sampler_chain_init(chain_src->params);
454
-
455
- for (auto * smpl : chain_src->samplers) {
456
- llama_sampler_chain_add(result, llama_sampler_clone(smpl));
457
- }
458
-
459
- return result;
460
- }
461
-
462
- static void llama_sampler_chain_free(struct llama_sampler * smpl) {
463
- auto * chain = (llama_sampler_chain *) smpl->ctx;
464
-
465
- for (auto * smpl : chain->samplers) {
466
- llama_sampler_free(smpl);
467
- }
468
-
469
- delete chain;
470
- }
471
-
472
- static struct llama_sampler_i llama_sampler_chain_i = {
473
- /* .name = */ llama_sampler_chain_name,
474
- /* .accept = */ llama_sampler_chain_accept,
475
- /* .apply = */ llama_sampler_chain_apply,
476
- /* .reset = */ llama_sampler_chain_reset,
477
- /* .clone = */ llama_sampler_chain_clone,
478
- /* .free = */ llama_sampler_chain_free,
479
- };
480
-
481
- struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
482
- return llama_sampler_init(
483
- /* .iface = */ &llama_sampler_chain_i,
484
- /* .ctx = */ new llama_sampler_chain {
485
- /* .params = */ params,
486
- /* .samplers = */ {},
487
- /* .t_sample_us = */ 0,
488
- /* .n_sample = */ 0,
489
- }
490
- );
491
- }
492
-
493
- void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
494
- auto * p = (llama_sampler_chain *) chain->ctx;
495
- p->samplers.push_back(smpl);
496
- }
497
-
498
-
499
- struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
500
- const auto * p = (const llama_sampler_chain *) chain->ctx;
501
-
502
- if (i < 0 || (size_t) i >= p->samplers.size()) {
503
- return nullptr;
504
- }
505
-
506
- return p->samplers[i];
507
- }
508
-
509
- struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
510
- auto * p = (llama_sampler_chain *) chain->ctx;
511
-
512
- if (i < 0 || (size_t) i >= p->samplers.size()) {
513
- return nullptr;
514
- }
515
-
516
- auto * result = p->samplers[i];
517
- p->samplers.erase(p->samplers.begin() + i);
518
-
519
- return result;
520
- }
521
-
522
- int llama_sampler_chain_n(const struct llama_sampler * chain) {
523
- const auto * p = (const llama_sampler_chain *) chain->ctx;
524
-
525
- return p->samplers.size();
526
- }
527
-
528
- //
529
- // samplers
530
- //
531
-
532
- // greedy
533
-
534
- static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
535
- return "greedy";
536
- }
537
-
538
- static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
539
- cur_p->selected = 0;
540
- for (size_t i = 1; i < cur_p->size; ++i) {
541
- if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
542
- cur_p->selected = i;
543
- }
544
- }
545
- }
546
-
547
- static struct llama_sampler_i llama_sampler_greedy_i = {
548
- /* .name = */ llama_sampler_greedy_name,
549
- /* .accept = */ nullptr,
550
- /* .apply = */ llama_sampler_greedy_apply,
551
- /* .reset = */ nullptr,
552
- /* .clone = */ nullptr,
553
- /* .free = */ nullptr,
554
- };
555
-
556
- struct llama_sampler * llama_sampler_init_greedy() {
557
- return llama_sampler_init(
558
- /* .iface = */ &llama_sampler_greedy_i,
559
- /* .ctx = */ nullptr
560
- );
561
- }
562
-
563
// dist

// Terminal sampling step: draws a token index from the (softmaxed)
// candidate distribution with a seeded mersenne-twister RNG.
// The RNG state advances on every apply, so repeated applies with the
// same logits can select different tokens.
struct llama_sampler_dist {
    const uint32_t seed;     // seed as provided by the user
          uint32_t seed_cur; // seed actually used to (re)seed rng on reset

    std::mt19937 rng;
};

static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
    return "dist";
}

static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dist *) smpl->ctx;

    // make sure probabilities are computed and sorted before drawing
    llama_sampler_softmax_impl(cur_p);

    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
}

static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
    auto * result = llama_sampler_init_dist(ctx->seed);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_dist *) result->ctx;

        // the clone continues the exact same random sequence as the source
        result_ctx->rng = ctx->rng;
    }

    return result;
}

static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_dist *) smpl->ctx;
    // NOTE(review): get_rng_seed presumably resolves a "random seed" sentinel
    // to a fresh concrete seed - confirm; if so, reset with such a seed starts
    // a new stream rather than replaying the original one.
    ctx->seed_cur = get_rng_seed(ctx->seed);
    ctx->rng.seed(ctx->seed_cur);
}

static void llama_sampler_dist_free(struct llama_sampler * smpl) {
    delete (llama_sampler_dist *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_dist_i = {
    /* .name   = */ llama_sampler_dist_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_dist_apply,
    /* .reset  = */ llama_sampler_dist_reset,
    /* .clone  = */ llama_sampler_dist_clone,
    /* .free   = */ llama_sampler_dist_free,
};

struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
    auto seed_cur = get_rng_seed(seed);
    return llama_sampler_init(
        /* .iface = */ &llama_sampler_dist_i,
        /* .ctx   = */ new llama_sampler_dist {
            /* .seed     = */ seed,
            /* .seed_cur = */ seed_cur,
            /* .rng      = */ std::mt19937(seed_cur),
        }
    );
}
628
-
629
- // softmax
630
-
631
- static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
632
- return "softmax";
633
- }
634
-
635
- static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
636
- llama_sampler_softmax_impl(cur_p);
637
- }
638
-
639
- static struct llama_sampler_i llama_sampler_softmax_i = {
640
- /* .name = */ llama_sampler_softmax_name,
641
- /* .accept = */ nullptr,
642
- /* .apply = */ llama_sampler_softmax_apply,
643
- /* .reset = */ nullptr,
644
- /* .clone = */ nullptr,
645
- /* .free = */ nullptr,
646
- };
647
-
648
- struct llama_sampler * llama_sampler_init_softmax() {
649
- return llama_sampler_init(
650
- /* .iface = */ &llama_sampler_softmax_i,
651
- /* .ctx = */ nullptr
652
- );
653
- }
654
-
655
- // top-k
656
-
657
- struct llama_sampler_top_k {
658
- const int32_t k;
659
- };
660
-
661
- static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
662
- return "top-k";
663
- }
664
-
665
- static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
666
- const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
667
- llama_sampler_top_k_impl(cur_p, ctx->k);
668
- }
669
-
670
- static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
671
- const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
672
- return llama_sampler_init_top_k(ctx->k);
673
- }
674
-
675
- static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
676
- delete (llama_sampler_top_k *) smpl->ctx;
677
- }
678
-
679
- static struct llama_sampler_i llama_sampler_top_k_i = {
680
- /* .name = */ llama_sampler_top_k_name,
681
- /* .accept = */ nullptr,
682
- /* .apply = */ llama_sampler_top_k_apply,
683
- /* .reset = */ nullptr,
684
- /* .clone = */ llama_sampler_top_k_clone,
685
- /* .free = */ llama_sampler_top_k_free,
686
- };
687
-
688
- struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
689
- return llama_sampler_init(
690
- /* .iface = */ &llama_sampler_top_k_i,
691
- /* .ctx = */ new llama_sampler_top_k {
692
- /* .k = */ k,
693
- }
694
- );
695
- }
696
-
697
- // top-p
698
-
699
- struct llama_sampler_top_p {
700
- const float p;
701
- const size_t min_keep;
702
- };
703
-
704
- static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
705
- return "top-p";
706
- }
707
-
708
- static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
709
- const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
710
-
711
- if (ctx->p >= 1.0f) {
712
- return;
713
- }
714
-
715
- llama_sampler_softmax_impl(cur_p);
716
-
717
- // Compute the cumulative probabilities
718
- float cum_sum = 0.0f;
719
- size_t last_idx = cur_p->size;
720
-
721
- for (size_t i = 0; i < cur_p->size; ++i) {
722
- cum_sum += cur_p->data[i].p;
723
-
724
- // Check if the running sum is at least p or if we have kept at least min_keep tokens
725
- // we set the last index to i+1 to indicate that the current iterate should be included in the set
726
- if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
727
- last_idx = i + 1;
728
- break;
729
- }
730
- }
731
-
732
- // Resize the output vector to keep only the top-p tokens
733
- cur_p->size = last_idx;
734
- }
735
-
736
- static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
737
- const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
738
- return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
739
- }
740
-
741
- static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
742
- delete (llama_sampler_top_p *) smpl->ctx;
743
- }
744
-
745
- static struct llama_sampler_i llama_sampler_top_p_i = {
746
- /* .name = */ llama_sampler_top_p_name,
747
- /* .accept = */ nullptr,
748
- /* .apply = */ llama_sampler_top_p_apply,
749
- /* .reset = */ nullptr,
750
- /* .clone = */ llama_sampler_top_p_clone,
751
- /* .free = */ llama_sampler_top_p_free,
752
- };
753
-
754
- struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
755
- return llama_sampler_init(
756
- /* .iface = */ &llama_sampler_top_p_i,
757
- /* .ctx = */ new llama_sampler_top_p {
758
- /* .p = */ p,
759
- /* .min_keep = */ min_keep,
760
- }
761
- );
762
- }
763
-
764
// min-p

// Keeps only tokens whose probability is at least p * p_max (probability of
// the most likely token), but never fewer than min_keep tokens.
struct llama_sampler_min_p {
    const float  p;
    const size_t min_keep;
};

static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
    return "min-p";
}

static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * ctx = (llama_sampler_min_p *) smpl->ctx;

    if (ctx->p <= 0.0f || !cur_p->size) {
        return;
    }

    bool min_p_applied = false;

    // if the cur_p aren't sorted, try the unsorted implementation first
    // (filters on raw logits, avoiding a full O(n log n) sort)
    if (!cur_p->sorted) {
        std::vector<llama_token_data> filtered_tokens;

        float max_logit = -FLT_MAX;
        for (size_t i = 0; i < cur_p->size; ++i) {
            max_logit = std::max(max_logit, cur_p->data[i].logit);
        }
        // p_i >= p * p_max  <=>  logit_i >= logit_max + log(p)
        const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max

        for (size_t i = 0; i < cur_p->size; ++i) {
            if (cur_p->data[i].logit >= min_logit) {
                filtered_tokens.push_back(cur_p->data[i]);
            }
        }

        // if we have enough values the operation was a success
        if (filtered_tokens.size() >= ctx->min_keep) {
            memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
            cur_p->size = filtered_tokens.size();
            min_p_applied = true;
        }
    }

    // if the cur_p are sorted or the unsorted implementation failed
    // (kept fewer than min_keep), use this implementation
    if (!min_p_applied) {
        // Sort the logits in descending order
        if (!cur_p->sorted) {
            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
                return a.logit > b.logit;
            });
            cur_p->sorted = true;
        }

        const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
        size_t i = 1; // first token always matches

        for (; i < cur_p->size; ++i) {
            if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
                break; // prob too small
            }
        }

        // Resize the output vector to keep only the matching tokens
        cur_p->size = i;
    }
}

static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
    return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
}

static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
    delete (llama_sampler_min_p *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_min_p_i = {
    /* .name   = */ llama_sampler_min_p_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_min_p_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_min_p_clone,
    /* .free   = */ llama_sampler_min_p_free,
};

struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
    return llama_sampler_init(
        /* .iface = */ &llama_sampler_min_p_i,
        /* .ctx   = */ new llama_sampler_min_p {
            /* .p        = */ p,
            /* .min_keep = */ min_keep,
        }
    );
}
859
-
860
// typical

// Locally typical sampling: keep the tokens whose surprisal (-log p) is
// closest to the distribution's entropy, up to cumulative mass p, but never
// fewer than min_keep tokens.
struct llama_sampler_typical {
    const float  p;
    const size_t min_keep;
};

static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
    return "typical";
}

static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * ctx = (llama_sampler_typical *) smpl->ctx;

    // Reference implementation:
    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
    if (ctx->p >= 1.0f) {
        return;
    }

    // Compute the softmax of logits and calculate entropy
    llama_sampler_softmax_impl(cur_p);

    float entropy = 0.0f;
    for (size_t i = 0; i < cur_p->size; ++i) {
        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
    }

    // Compute the absolute difference between negative log probability and entropy for each candidate
    // (smaller difference = more "typical" token)
    std::vector<float> shifted_scores;
    for (size_t i = 0; i < cur_p->size; ++i) {
        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
        shifted_scores.push_back(shifted_score);
    }

    // Sort tokens based on the shifted_scores and their corresponding indices
    std::vector<size_t> indices(cur_p->size);
    std::iota(indices.begin(), indices.end(), 0);

    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
        return shifted_scores[a] < shifted_scores[b];
    });

    // Compute the cumulative probabilities (in typicality order)
    float cum_sum = 0.0f;
    size_t last_idx = indices.size();

    for (size_t i = 0; i < indices.size(); ++i) {
        size_t idx = indices[i];
        cum_sum += cur_p->data[idx].p;

        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
        if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
            last_idx = i + 1;
            break;
        }
    }

    // Resize the output vector to keep only the locally typical tokens
    std::vector<llama_token_data> cur_p_new;
    for (size_t i = 0; i < last_idx; ++i) {
        size_t idx = indices[i];
        cur_p_new.push_back(cur_p->data[idx]);
    }

    // Replace the data in cur_p with the cur_p_new data
    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
    cur_p->size = cur_p_new.size();
    // tokens are now ordered by typicality, not by logit
    cur_p->sorted = false;
}

static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
    return llama_sampler_init_typical(ctx->p, ctx->min_keep);
}

static void llama_sampler_typical_free(struct llama_sampler * smpl) {
    delete (llama_sampler_typical *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_typical_i = {
    /* .name   = */ llama_sampler_typical_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_typical_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_typical_clone,
    /* .free   = */ llama_sampler_typical_free,
};

struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
    return llama_sampler_init(
        /* .iface = */ &llama_sampler_typical_i,
        /* .ctx   = */ new llama_sampler_typical {
            /* .p        = */ p,
            /* .min_keep = */ min_keep,
        }
    );
}
958
-
959
- // temp
960
-
961
- struct llama_sampler_temp {
962
- const float temp;
963
- };
964
-
965
- static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
966
- return "temp";
967
- }
968
-
969
- static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
970
- const auto * ctx = (llama_sampler_temp *) smpl->ctx;
971
-
972
- llama_sampler_temp_impl(cur_p, ctx->temp);
973
- }
974
-
975
- static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
976
- const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
977
- return llama_sampler_init_temp(ctx->temp);
978
- }
979
-
980
- static void llama_sampler_temp_free(struct llama_sampler * smpl) {
981
- delete (llama_sampler_temp *) smpl->ctx;
982
- }
983
-
984
- static struct llama_sampler_i llama_sampler_temp_i = {
985
- /* .name = */ llama_sampler_temp_name,
986
- /* .accept = */ nullptr,
987
- /* .apply = */ llama_sampler_temp_apply,
988
- /* .reset = */ nullptr,
989
- /* .clone = */ llama_sampler_temp_clone,
990
- /* .free = */ llama_sampler_temp_free,
991
- };
992
-
993
- struct llama_sampler * llama_sampler_init_temp(float temp) {
994
- return llama_sampler_init(
995
- /* .iface = */ &llama_sampler_temp_i,
996
- /* .ctx = */ new llama_sampler_temp {
997
- /*.temp = */ temp,
998
- }
999
- );
1000
- }
1001
-
1002
// temp-ext

// Extended ("dynamic") temperature: when delta > 0 the effective temperature
// is picked from [temp - delta, temp + delta] based on the normalized entropy
// of the candidate distribution, shaped by `exponent`; when delta <= 0 it
// falls back to plain temperature scaling.
struct llama_sampler_temp_ext {
    const float temp;
    const float delta;
    const float exponent;
};

static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
    return "temp-ext";
}

static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
    if (ctx->delta > 0) {
        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
        const float max_temp = ctx->temp + ctx->delta;

        float exponent_val = ctx->exponent;

        // no need to do anything if there is only one (or zero) candidates
        if (cur_p->size <= 1) {
            return;
        }

        // Calculate maximum possible entropy (that of a uniform distribution)
        float max_entropy = -logf(1.0f / cur_p->size);

        llama_sampler_softmax_impl(cur_p);

        // Calculate entropy of the softmax probabilities
        float entropy = 0.0f;
        for (size_t i = 0; i < cur_p->size; ++i) {
            float prob = cur_p->data[i].p;
            if (prob > 0.0f) { // Ensure no log(0)
                entropy -= prob * logf(prob);
            }
        }

        // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
        float normalized_entropy = entropy / max_entropy;

        // Map the normalized entropy to the desired temperature range using the power function
        float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);

#ifdef DEBUG
        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
#endif

        // Apply the dynamically calculated temperature scaling
        llama_sampler_temp_impl(cur_p, dyn_temp);

        // Re-compute softmax probabilities after scaling logits with dynamic temperature
        // (double precision used for the accumulation to reduce rounding error)
        const double max_l_double = cur_p->data[0].logit;

        double cum_sum_double = 0.0;
        for (size_t i = 0; i < cur_p->size; ++i) {
            double p = exp(cur_p->data[i].logit - max_l_double);
            cur_p->data[i].p = p; // Store the scaled probability
            cum_sum_double += p;
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
        }

#ifdef DEBUG
        // Print the updated top 25 probabilities after temperature scaling
        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
        }
#endif
    } else {
        llama_sampler_temp_impl(cur_p, ctx->temp);
    }
}

static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
    return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
}

static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
    delete (llama_sampler_temp_ext *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_temp_ext_i = {
    /* .name   = */ llama_sampler_temp_ext_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_temp_ext_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_temp_ext_clone,
    /* .free   = */ llama_sampler_temp_ext_free,
};

struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
    return llama_sampler_init(
        /* .iface = */ &llama_sampler_temp_ext_i,
        /* .ctx   = */ new llama_sampler_temp_ext {
            /* .temp     = */ temp,
            /* .delta    = */ delta,
            /* .exponent = */ exponent,
        }
    );
}
1113
-
1114
// xtc

// "Exclude Top Choices": with probability `probability`, drop every token
// whose p >= threshold except the last such token, shifting mass toward
// less likely continuations.
struct llama_sampler_xtc {
    const float    probability; // chance the truncation is applied at all
    const float    threshold;   // tokens with p >= threshold are removal candidates
    const size_t   min_keep;

    const uint32_t seed;
          uint32_t seed_cur;

    std::mt19937 rng;
};

static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
    return "xtc";
}

static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_xtc *) smpl->ctx;

    // threshold > 0.5 can match at most one token, so there is nothing to cut
    if (ctx->probability <= 0.0f
            || ctx->threshold > 0.5f
            || cur_p->size < 2) {
        return;
    }

    // roll the dice first - note this advances the RNG even when nothing is cut
    std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
    float chance = distribution(ctx->rng);
    if (chance > ctx->probability) return;

    // in case it's not sorted/recalculated yet
    llama_sampler_softmax_impl(cur_p);

    // index of the last token with p >= threshold (probs are sorted descending)
    int pos_last = 0;

    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].p >= ctx->threshold) {
            pos_last = i;
        } else break;
    }

    // drop the tokens before pos_last, unless that would leave < min_keep tokens
    if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
        cur_p->data += pos_last;
        cur_p->size -= pos_last;
    }
}

static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
    auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_xtc *) result->ctx;

        result_ctx->rng = ctx->rng;
    }

    return result;
}

static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
    delete (llama_sampler_xtc *) smpl->ctx;
}

static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
    ctx->seed_cur = get_rng_seed(ctx->seed);
    ctx->rng.seed(ctx->seed_cur);
}

static struct llama_sampler_i llama_sampler_xtc_i = {
    /* .name   = */ llama_sampler_xtc_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sample_xtc_apply,
    /* .reset  = */ llama_sampler_xtc_reset,
    /* .clone  = */ llama_sampler_xtc_clone,
    /* .free   = */ llama_sampler_xtc_free,
};

struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
    auto seed_cur = get_rng_seed(seed);
    return llama_sampler_init(
        /* .iface = */ &llama_sampler_xtc_i,
        /* .ctx   = */ new llama_sampler_xtc {
            /* .probability = */ p,
            /* .threshold   = */ t,
            /* .min_keep    = */ min_keep,
            /* .seed        = */ seed,
            /* .seed_cur    = */ seed_cur,
            /* .rng         = */ std::mt19937(seed_cur),
        }
    );
}
1208
-
1209
- // mirostat
1210
-
1211
- struct llama_sampler_mirostat {
1212
- const int32_t n_vocab;
1213
-
1214
- const uint32_t seed;
1215
- uint32_t seed_cur;
1216
-
1217
- const float tau;
1218
- const float eta;
1219
-
1220
- const int32_t m;
1221
-
1222
- float mu;
1223
-
1224
- std::mt19937 rng;
1225
- };
1226
-
1227
- static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
1228
- return "mirostat";
1229
- }
1230
-
1231
- static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1232
- auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1233
-
1234
- llama_sampler_softmax_impl(cur_p);
1235
-
1236
- // Estimate s_hat using the most probable m tokens
1237
- float s_hat = 0.0;
1238
- float sum_ti_bi = 0.0;
1239
- float sum_ti_sq = 0.0;
1240
- for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
1241
- float t_i = logf(float(i + 2) / float(i + 1));
1242
- float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
1243
- sum_ti_bi += t_i * b_i;
1244
- sum_ti_sq += t_i * t_i;
1245
- }
1246
- s_hat = sum_ti_bi / sum_ti_sq;
1247
-
1248
- // Compute k from the estimated s_hat and target surprise value
1249
- float epsilon_hat = s_hat - 1;
1250
- float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
1251
-
1252
- llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
1253
- llama_sampler_softmax_impl(cur_p);
1254
-
1255
- const int idx = llama_sample_dist(cur_p, ctx->rng);
1256
-
1257
- cur_p->selected = idx;
1258
-
1259
- float observed_surprise = -log2f(cur_p->data[idx].p);
1260
- float e = observed_surprise - ctx->tau;
1261
-
1262
- // Update mu using the learning rate and error
1263
- ctx->mu = ctx->mu - ctx->eta * e;
1264
- }
1265
-
1266
- static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
1267
- const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
1268
- auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
1269
-
1270
- // copy the state
1271
- {
1272
- auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
1273
-
1274
- result_ctx->mu = ctx->mu;
1275
- result_ctx->rng = ctx->rng;
1276
- }
1277
-
1278
- return result;
1279
- }
1280
-
1281
- static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
1282
- auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1283
- ctx->mu = 2.0f*ctx->tau;
1284
- ctx->seed_cur = get_rng_seed(ctx->seed);
1285
- ctx->rng.seed(ctx->seed_cur);
1286
- }
1287
-
1288
- static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1289
- delete (llama_sampler_mirostat *) smpl->ctx;
1290
- }
1291
-
1292
- static struct llama_sampler_i llama_sampler_mirostat_i = {
1293
- /* .name = */ llama_sampler_mirostat_name,
1294
- /* .accept = */ nullptr,
1295
- /* .apply = */ llama_sampler_mirostat_apply,
1296
- /* .reset = */ llama_sampler_mirostat_reset,
1297
- /* .clone = */ llama_sampler_mirostat_clone,
1298
- /* .free = */ llama_sampler_mirostat_free,
1299
- };
1300
-
1301
- struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1302
- auto seed_cur = get_rng_seed(seed);
1303
- return llama_sampler_init(
1304
- /* .iface = */ &llama_sampler_mirostat_i,
1305
- /* .ctx = */ new llama_sampler_mirostat {
1306
- /* .n_vocab = */ n_vocab,
1307
- /* .seed = */ seed,
1308
- /* .seed_cur = */ seed_cur,
1309
- /* .tau = */ tau,
1310
- /* .eta = */ eta,
1311
- /* .m = */ m,
1312
- /* .mu = */ 2.0f*tau,
1313
- /* .rng = */ std::mt19937(seed_cur),
1314
- }
1315
- );
1316
- }
1317
-
1318
// mirostat v2

// Mirostat v2 sampling: simpler variant that truncates directly on the
// surprise bound mu instead of estimating a Zipf exponent.
struct llama_sampler_mirostat_v2 {
    const uint32_t seed;
          uint32_t seed_cur;

    const float tau; // target surprise
    const float eta; // learning rate for the mu update

    float mu; // adaptive state, initialized to 2*tau

    std::mt19937 rng;
};

static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat-v2";
}

static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p);

    // Truncate the words with surprise values greater than mu
    // (probs are sorted descending, so surprise is ascending - find_if works)
    cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
        return -log2f(candidate.p) > ctx->mu;
    }));

    // always keep at least the most probable token
    if (cur_p->size == 0) {
        cur_p->size = 1;
    }

    // Normalize the probabilities of the remaining words
    llama_sampler_softmax_impl(cur_p);

    const int idx = llama_sample_dist(cur_p, ctx->rng);

    cur_p->selected = idx;

    float observed_surprise = -log2f(cur_p->data[idx].p);
    float e = observed_surprise - ctx->tau;

    // Update mu using the learning rate and error
    ctx->mu = ctx->mu - ctx->eta * e;
}

static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
    ctx->mu = 2.0f*ctx->tau;
    ctx->seed_cur = get_rng_seed(ctx->seed);
    ctx->rng.seed(ctx->seed_cur);
}

static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;

    auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);

    // copy the state (into the clone's context)
    {
        auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;

        result_ctx->mu  = ctx->mu;
        result_ctx->rng = ctx->rng;
    }

    return result;
}

static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
    delete (llama_sampler_mirostat_v2 *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
    /* .name   = */ llama_sampler_mirostat_v2_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_mirostat_v2_apply,
    /* .reset  = */ llama_sampler_mirostat_v2_reset,
    /* .clone  = */ llama_sampler_mirostat_v2_clone,
    /* .free   = */ llama_sampler_mirostat_v2_free,
};

struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
    auto seed_cur = get_rng_seed(seed);
    return llama_sampler_init(
        /* .iface = */ &llama_sampler_mirostat_v2_i,
        /* .ctx   = */ new llama_sampler_mirostat_v2 {
            /* .seed     = */ seed,
            /* .seed_cur = */ seed_cur,
            /* .tau      = */ tau,
            /* .eta      = */ eta,
            /* .mu       = */ 2.0f*tau,
            /* .rng      = */ std::mt19937(seed_cur),
        }
    );
}
1414
-
1415
// grammar

// Constrains sampling to a grammar. A null `grammar` pointer means the
// sampler was constructed with an empty grammar string and acts as a no-op.
struct llama_sampler_grammar {
    const struct llama_vocab * vocab;

    std::string grammar_str;  // kept so the grammar can be rebuilt on reset
    std::string grammar_root;

    struct llama_grammar * grammar; // owned; freed in llama_sampler_grammar_free
};

static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
    return "grammar";
}

// Advance the grammar state with an accepted token.
static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
    if (ctx->grammar) {
        llama_grammar_accept_impl(*ctx->grammar, token);
    }
}

// Mask out candidates the grammar cannot continue with.
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
    if (ctx->grammar) {
        llama_grammar_apply_impl(*ctx->grammar, cur_p);
    }
}

// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
static struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                              bool lazy,
                     const char ** trigger_words,
                            size_t num_trigger_words,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens);

// Rebuild the grammar from its source string, discarding all parse state.
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
    if (!ctx->grammar) {
        return;
    }

    // preserve the lazy-trigger configuration across the rebuild
    std::vector<const char *> trigger_words;
    for (auto & word : ctx->grammar->trigger_words) {
        trigger_words.push_back(word.c_str());
    }
    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
                                                 ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
                                                 ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());

    llama_grammar_free_impl(ctx->grammar);
    ctx->grammar = grammar_new;
}

static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;

    // start from an empty (no-op) sampler, then deep-copy the grammar state
    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_grammar *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_str  = ctx->grammar_str;
            result_ctx->grammar_root = ctx->grammar_root;

            result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
        }
    }

    return result;
}

static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_grammar *) smpl->ctx;

    if (ctx->grammar) {
        llama_grammar_free_impl(ctx->grammar);
    }

    delete ctx;
}

static struct llama_sampler_i llama_sampler_grammar_i = {
    /* .name   = */ llama_sampler_grammar_name,
    /* .accept = */ llama_sampler_grammar_accept_impl,
    /* .apply  = */ llama_sampler_grammar_apply,
    /* .reset  = */ llama_sampler_grammar_reset,
    /* .clone  = */ llama_sampler_grammar_clone,
    /* .free   = */ llama_sampler_grammar_free,
};

static struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                              bool lazy,
                     const char ** trigger_words,
                            size_t num_trigger_words,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens) {
    auto * ctx = new llama_sampler_grammar;

    if (grammar_str != nullptr && grammar_str[0] != '\0') {
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_str  = */ grammar_str,
            /* .grammar_root = */ grammar_root,
            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
        };
    } else {
        // empty grammar string -> inert sampler
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_str  = */ {},
            /* .grammar_root = */ {},
            /* .grammar      = */ nullptr,
        };
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_grammar_i,
        /* .ctx   = */ ctx
    );
}

struct llama_sampler * llama_sampler_init_grammar(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root) {
    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
}

struct llama_sampler * llama_sampler_init_grammar_lazy(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                     const char ** trigger_words,
                            size_t num_trigger_words,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens) {
    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
}
1562
-
1563
- // penalties
1564
-
1565
- struct llama_sampler_penalties {
1566
- const int32_t penalty_last_n;
1567
- const float penalty_repeat;
1568
- const float penalty_freq;
1569
- const float penalty_present;
1570
-
1571
- ring_buffer<llama_token> prev;
1572
-
1573
- // a frequency map to count token occurrences
1574
- std::unordered_map<llama_token, int> token_count;
1575
- };
1576
-
1577
- static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
1578
- return "penalties";
1579
- }
1580
-
1581
- static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
1582
- auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1583
- if (ctx->penalty_last_n == 0) {
1584
- return;
1585
- }
1586
-
1587
- ctx->token_count[token]++;
1588
-
1589
- // if the ring buffer is full, remove the oldest token
1590
- if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
1591
- const auto old = ctx->prev.front();
1592
-
1593
- ctx->token_count[old]--;
1594
- if (ctx->token_count[old] == 0) {
1595
- ctx->token_count.erase(old);
1596
- }
1597
- }
1598
-
1599
- ctx->prev.push_back(token);
1600
-
1601
- #if 0
1602
- // sanity check
1603
- std::unordered_map<llama_token, int> tmp;
1604
- for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1605
- tmp[ctx->prev.rat(i)]++;
1606
- }
1607
-
1608
- assert(ctx->token_count == tmp);
1609
- #endif
1610
- }
1611
-
1612
- static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1613
- auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1614
-
1615
- if ((ctx->penalty_last_n == 0) ||
1616
- (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1617
- return;
1618
- }
1619
-
1620
- // Apply frequency and presence penalties to the cur_p
1621
- for (size_t i = 0; i < cur_p->size; ++i) {
1622
- const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
1623
- if (token_iter == ctx->token_count.end()) {
1624
- continue;
1625
- }
1626
-
1627
- const int count = token_iter->second;
1628
-
1629
- assert(count > 0 && count <= ctx->penalty_last_n);
1630
-
1631
- // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
1632
- // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
1633
- if (cur_p->data[i].logit <= 0) {
1634
- cur_p->data[i].logit *= ctx->penalty_repeat;
1635
- } else {
1636
- cur_p->data[i].logit /= ctx->penalty_repeat;
1637
- }
1638
-
1639
- cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
1640
- }
1641
-
1642
- cur_p->sorted = false;
1643
- }
1644
-
1645
- static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
1646
- auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1647
- ctx->prev.clear();
1648
- ctx->token_count.clear();
1649
- }
1650
-
1651
- static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
1652
- const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
1653
- auto * result = llama_sampler_init_penalties(
1654
- ctx->penalty_last_n,
1655
- ctx->penalty_repeat,
1656
- ctx->penalty_freq,
1657
- ctx->penalty_present);
1658
-
1659
- // copy the state
1660
- {
1661
- auto * result_ctx = (llama_sampler_penalties *) result->ctx;
1662
-
1663
- result_ctx->prev = ctx->prev;
1664
- }
1665
-
1666
- return result;
1667
- }
1668
-
1669
- static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1670
- delete (llama_sampler_penalties *) smpl->ctx;
1671
- }
1672
-
1673
- static struct llama_sampler_i llama_sampler_penalties_i = {
1674
- /* .name = */ llama_sampler_penalties_name,
1675
- /* .accept = */ llama_sampler_penalties_accept,
1676
- /* .apply = */ llama_sampler_penalties_apply,
1677
- /* .reset = */ llama_sampler_penalties_reset,
1678
- /* .clone = */ llama_sampler_penalties_clone,
1679
- /* .free = */ llama_sampler_penalties_free,
1680
- };
1681
-
1682
- struct llama_sampler * llama_sampler_init_penalties(
1683
- int32_t penalty_last_n,
1684
- float penalty_repeat,
1685
- float penalty_freq,
1686
- float penalty_present) {
1687
- penalty_last_n = std::max(penalty_last_n, 0);
1688
-
1689
- return llama_sampler_init(
1690
- /* .iface = */ &llama_sampler_penalties_i,
1691
- /* .ctx = */ new llama_sampler_penalties {
1692
- /* .penalty_last_n = */ penalty_last_n,
1693
- /* .penalty_repeat = */ penalty_repeat,
1694
- /* .penalty_freq = */ penalty_freq,
1695
- /* .penalty_present = */ penalty_present,
1696
- /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1697
- /* .token_count = */ {},
1698
- }
1699
- );
1700
- }
1701
-
1702
- // top-n-sigma
1703
-
1704
- struct llama_sampler_top_n_sigma {
1705
- const float n;
1706
- };
1707
-
1708
- static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
1709
- return "top-n-sigma";
1710
- }
1711
-
1712
- static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1713
- const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1714
-
1715
- // find max logit and calculate mean
1716
- float max = cur_p->data[0].logit;
1717
- float logits_sum = 0;
1718
- for (size_t i = 0; i < cur_p->size; ++i) {
1719
- if (cur_p->data[i].logit > max) {
1720
- max = cur_p->data[i].logit;
1721
- }
1722
- logits_sum += cur_p->data[i].logit;
1723
- }
1724
- float mean = logits_sum/cur_p->size;
1725
-
1726
- // calculate standard deviation
1727
- float acc = 0;
1728
- for (size_t i = 0; i < cur_p->size; ++i) {
1729
- acc += pow(cur_p->data[i].logit - mean, 2);
1730
- }
1731
- float std = sqrt(acc/cur_p->size);
1732
-
1733
- //apply mask
1734
- for (size_t i = 0; i < cur_p->size; ++i) {
1735
- if (cur_p->data[i].logit < max - (ctx->n * std)) {
1736
- cur_p->data[i].logit = -INFINITY;
1737
- }
1738
- }
1739
- llama_sampler_softmax_impl(cur_p);
1740
- }
1741
-
1742
- static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
1743
- const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
1744
- return llama_sampler_init_top_n_sigma(ctx->n);
1745
- }
1746
-
1747
- static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1748
- delete (llama_sampler_top_n_sigma *) smpl->ctx;
1749
- }
1750
-
1751
- static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
1752
- /* .name = */ llama_sampler_top_n_sigma_name,
1753
- /* .accept = */ nullptr,
1754
- /* .apply = */ llama_sampler_top_n_sigma_apply,
1755
- /* .reset = */ nullptr,
1756
- /* .clone = */ llama_sampler_top_n_sigma_clone,
1757
- /* .free = */ llama_sampler_top_n_sigma_free,
1758
- };
1759
-
1760
- struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
1761
- return llama_sampler_init(
1762
- /* .iface = */ &llama_sampler_top_n_sigma_i,
1763
- /* .ctx = */ new llama_sampler_top_n_sigma {
1764
- /* .n = */ n,
1765
- }
1766
- );
1767
- }
1768
-
1769
- // DRY
1770
-
1771
- struct llama_sampler_dry {
1772
- int32_t total_context_size;
1773
-
1774
- const float dry_multiplier;
1775
- const float dry_base;
1776
- const int32_t dry_allowed_length;
1777
- const int32_t dry_penalty_last_n;
1778
-
1779
- std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
1780
- std::vector<int> dry_repeat_count;
1781
- std::unordered_map<llama_token, int> dry_max_token_repeat;
1782
- ring_buffer<llama_token> last_tokens;
1783
- };
1784
-
1785
- // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1786
- static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
1787
- for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
1788
- std::string word = vocab.detokenize({token_id}, true);
1789
- if (word.find(str) != std::string::npos) {
1790
- token_sequences.emplace(token_id, std::vector<llama_token>());
1791
- } else {
1792
- size_t word_len = word.size();
1793
- size_t str_len = str.size();
1794
- size_t pos = -1;
1795
- while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
1796
- bool match = true;
1797
- size_t i;
1798
- for (i = 1; i < str_len && i + pos < word_len; ++i) {
1799
- if (word[pos + i] != str[i]) {
1800
- match = false;
1801
- break;
1802
- }
1803
- }
1804
- if (match) {
1805
- std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
1806
- if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
1807
- tokenization.resize(max_tail_len);
1808
- }
1809
-
1810
- // Ensure we don't already have a duplicate matching tokenization
1811
- auto its = token_sequences.equal_range(token_id);
1812
- bool found = false;
1813
- for (auto it = its.first; it != its.second; ++it) {
1814
- if (tokenization == it->second) {
1815
- found = true;
1816
- break;
1817
- }
1818
- }
1819
- if (!found) {
1820
- token_sequences.emplace(token_id, tokenization);
1821
- }
1822
- }
1823
- }
1824
- }
1825
- }
1826
- }
1827
-
1828
- static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
1829
- return "dry";
1830
- }
1831
-
1832
- static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
1833
- auto * ctx = (llama_sampler_dry *) smpl->ctx;
1834
- if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1835
- return;
1836
- }
1837
-
1838
- ctx->last_tokens.push_back(token);
1839
- }
1840
-
1841
- // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1842
- static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1843
- auto * ctx = (llama_sampler_dry *) smpl->ctx;
1844
-
1845
- if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1846
- return;
1847
- }
1848
-
1849
- int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
1850
- int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
1851
-
1852
- if (last_n_repeat <= ctx->dry_allowed_length) {
1853
- return;
1854
- }
1855
-
1856
- ctx->dry_repeat_count.assign(last_n_repeat, 0);
1857
- ctx->dry_max_token_repeat.clear();
1858
-
1859
- // Step 1: Look for restart sequences to limit the maximum repetition length.
1860
- // Work backwards through the context looking for any token that begins a restart sequence.
1861
- //
1862
- // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
1863
- // sequences that together comprise a restart sequence. This allows us to quickly check
1864
- // whether each token is the head of a complete sequence. Most restart sequences are actually
1865
- // a single token, and for these the "tail" is an empty vector.
1866
- //
1867
- // If the token is a "head", test all restart sequences that begin with this token
1868
- // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
1869
- // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
1870
- // longest matching sequence (if any) is used to limit the maximum repetition length.
1871
- //
1872
- // Note that in the case case of a short sequence contained in a longer one, this might fail to
1873
- // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
1874
- // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
1875
- // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
1876
- //
1877
- // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
1878
- // have already clamped the maximum tail sequence length when generating `restart_sequences`.
1879
- // With clamping, this scan is O(N) in the context length.
1880
-
1881
- int rep_limit = last_n_repeat;
1882
- for (int i = 0; i < last_n_repeat; ++i) {
1883
- llama_token token = ctx->last_tokens.rat(i);
1884
- auto its = ctx->dry_processed_breakers.equal_range(token);
1885
- if (its.first == ctx->dry_processed_breakers.end()) {
1886
- continue;
1887
- }
1888
- int longest_match = -1;
1889
- for (auto it = its.first; it != its.second; ++it) {
1890
- // Note that (*it) does not contain the head character, so seq_len will be
1891
- // the restart sequence length minus 1.
1892
- // In the common case of a single-token restart sequence, (*it) will be empty
1893
- // and we will trivially match.
1894
- int seq_len = (int)it->second.size();
1895
- if (seq_len > longest_match && seq_len <= (int)i) {
1896
- bool match = true;
1897
- for (int offset = 0; offset < seq_len; ++offset) {
1898
- // The -1 when indexing `last_tokens` is because we already matched the head.
1899
- if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
1900
- match = false;
1901
- break;
1902
- }
1903
- }
1904
- if (match) {
1905
- longest_match = seq_len;
1906
- }
1907
- }
1908
- }
1909
- if (longest_match >= 0) {
1910
- // We found a restart sequence starting `i` tokens from the end and continuing for
1911
- // `longest_match` tokens.
1912
- rep_limit = i - longest_match;
1913
- break;
1914
- }
1915
- }
1916
- if (rep_limit < ctx->dry_allowed_length) {
1917
- return;
1918
- }
1919
-
1920
- // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
1921
- // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
1922
- // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
1923
- //
1924
- // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
1925
- // https://ivanyu.me/blog/2014/10/15/z-algorithm/
1926
- //
1927
- // The code below is adapted from the public domain implementation by the same author here:
1928
- // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
1929
- //
1930
- // Example:
1931
- // Last N tokens: a b c c b c y a b c
1932
- // Repeat counts: 0 0 3 1 0 2 0 0 0 0
1933
- // ^
1934
- // This `3` means that the last three tokens of the context (a b c) also appear here.
1935
- //
1936
- // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
1937
- // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
1938
- // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
1939
- // ensure that the inner while loops only examine each token in the context once as the outer
1940
- // for loop iterates over the context.
1941
-
1942
- {
1943
- const int last = last_n_repeat - 1;
1944
- int rt = 0, lt = 0;
1945
-
1946
- for (int k = 1; k < last_n_repeat; ++k) {
1947
- if (k > rt) {
1948
- // If k is outside the current Z-box, do naive computation.
1949
- int n = 0;
1950
- while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
1951
- ++n;
1952
- }
1953
- ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
1954
- if (n > 0) {
1955
- lt = k;
1956
- rt = k + n - 1;
1957
- }
1958
- } else {
1959
- // If k is inside the current Z-box, consider two cases.
1960
-
1961
- int p = k - lt; // Pair index.
1962
- int right_part_len = rt - k + 1;
1963
-
1964
- if (ctx->dry_repeat_count[last - p] < right_part_len) {
1965
- int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
1966
- ctx->dry_repeat_count[last - k] = n;
1967
- } else {
1968
- int i = rt + 1;
1969
- while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
1970
- i += 1;
1971
- }
1972
-
1973
- int n = std::min(i - k, rep_limit);
1974
- ctx->dry_repeat_count[last - k] = n;
1975
- lt = k;
1976
- rt = i - 1;
1977
- }
1978
- }
1979
- }
1980
- }
1981
-
1982
- // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
1983
- // that would be generated by emitting each new token that would extend a sequence.
1984
- //
1985
- // Following the same example as above:
1986
- // Last N tokens: a b c c b c y a b c
1987
- // Repeat counts: 0 0 3 1 0 2 0 0 0 0
1988
- //
1989
- // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
1990
- // c: 3 -> 4 (from `a b c` to `a b c c`)
1991
- // b: 1 -> 2 (from `c` to `c b`)
1992
- // y: 2 -> 3 (from `b c` to `b c y`)
1993
-
1994
- for (int i = 0; i < last_n_repeat - 1; ++i) {
1995
- int repeat_len = ctx->dry_repeat_count[i];
1996
- if (repeat_len >= ctx->dry_allowed_length) {
1997
- // This token ends a repeat, so the next token would continue one.
1998
- // By convention, the value of `repeat_len` only includes the tokens currently
1999
- // in the context, not the new token that would be added.
2000
- llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
2001
- // Track the maximum sequence ending in this token.
2002
- const auto& it = ctx->dry_max_token_repeat.find(token);
2003
- if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
2004
- ctx->dry_max_token_repeat[token] = repeat_len;
2005
- }
2006
- }
2007
- }
2008
-
2009
- // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
2010
-
2011
- // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
2012
- // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
2013
- const float FLOAT_MAX_LOG = 88.7228391f;
2014
- int max_exponent = 0;
2015
- if (ctx->dry_base > 1.000001f) {
2016
- max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
2017
- }
2018
-
2019
- for (size_t i = 0; i < cur_p->size; ++i) {
2020
- const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
2021
- if (af_kvp != ctx->dry_max_token_repeat.end()) {
2022
- // Check all sequence breakers starting with this token
2023
- auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
2024
- bool is_single_token_breaker = false;
2025
-
2026
- for (auto it = range.first; it != range.second; ++it) {
2027
- if (it->second.empty()) {
2028
- is_single_token_breaker = true;
2029
- break;
2030
- }
2031
- }
2032
-
2033
- // Apply penalty only if it's not a single-token sequence breaker
2034
- if (!is_single_token_breaker) {
2035
- int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
2036
- if (max_exponent > 0 && repeat_exp > max_exponent) {
2037
- repeat_exp = max_exponent;
2038
- }
2039
- float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
2040
- cur_p->data[i].logit -= penalty;
2041
- }
2042
- }
2043
- }
2044
-
2045
- cur_p->sorted = false;
2046
- }
2047
-
2048
- static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
2049
- auto * ctx = (llama_sampler_dry *) smpl->ctx;
2050
- ctx->last_tokens.clear();
2051
- ctx->dry_repeat_count.clear();
2052
- ctx->dry_max_token_repeat.clear();
2053
- }
2054
-
2055
- static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
2056
- const auto * ctx = (llama_sampler_dry *) smpl->ctx;
2057
-
2058
- llama_vocab dummy_vocab;
2059
-
2060
- // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
2061
- auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
2062
-
2063
- // Copy the state, including the processed breakers
2064
- {
2065
- auto * result_ctx = (llama_sampler_dry *) result->ctx;
2066
- result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
2067
- result_ctx->dry_repeat_count = ctx->dry_repeat_count;
2068
- result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
2069
- result_ctx->last_tokens = ctx->last_tokens;
2070
- }
2071
-
2072
- return result;
2073
- }
2074
-
2075
- static void llama_sampler_dry_free(struct llama_sampler * smpl) {
2076
- delete (llama_sampler_dry *) smpl->ctx;
2077
- }
2078
-
2079
- static struct llama_sampler_i llama_sampler_dry_i = {
2080
- /* .name = */ llama_sampler_dry_name,
2081
- /* .accept = */ llama_sampler_dry_accept,
2082
- /* .apply = */ llama_sampler_dry_apply,
2083
- /* .reset = */ llama_sampler_dry_reset,
2084
- /* .clone = */ llama_sampler_dry_clone,
2085
- /* .free = */ llama_sampler_dry_free,
2086
- };
2087
-
2088
- struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
2089
- int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
2090
- std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
2091
- const int MAX_CHAR_LEN = 40;
2092
- const int MAX_SEQ_LEN = 20;
2093
-
2094
- const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
2095
-
2096
- if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
2097
- // Process sequence breakers
2098
- for (size_t i = 0; i < num_breakers; ++i) {
2099
- if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
2100
- LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
2101
- continue;
2102
- }
2103
-
2104
- std::string sequence_break(seq_breakers[i]);
2105
- if (sequence_break.empty()) {
2106
- LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
2107
- continue;
2108
- }
2109
-
2110
- if (sequence_break.size() > MAX_CHAR_LEN) {
2111
- LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
2112
- sequence_break.resize(MAX_CHAR_LEN);
2113
- }
2114
-
2115
- get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
2116
- }
2117
- }
2118
-
2119
- return llama_sampler_init(
2120
- /* .iface = */ &llama_sampler_dry_i,
2121
- /* .ctx = */ new llama_sampler_dry {
2122
- /* .total_context_size = */ context_size,
2123
- /* .dry_multiplier = */ dry_multiplier,
2124
- /* .dry_base = */ dry_base,
2125
- /* .dry_allowed_length = */ dry_allowed_length,
2126
- /* .dry_penalty_last_n = */ dry_penalty_last_n,
2127
- /* .dry_processed_breakers = */ std::move(processed_breakers),
2128
- /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
2129
- /* .dry_max_token_repeat = */ {},
2130
- /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
2131
- }
2132
- );
2133
- }
2134
-
2135
- // wrapper for test-sampling.cpp
2136
- struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
2137
- llama_vocab dummy_vocab;
2138
- auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
2139
- auto * ctx = (llama_sampler_dry *) result->ctx;
2140
-
2141
- // Process the token-based sequence breakers
2142
- ctx->dry_processed_breakers.clear();
2143
- if (seq_breakers.empty()) {
2144
- LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
2145
- } else {
2146
- for (const auto& breaker : seq_breakers) {
2147
- if (breaker.empty()) {
2148
- LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
2149
- continue;
2150
- }
2151
- llama_token head_token = breaker[0];
2152
- std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
2153
- ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
2154
- }
2155
-
2156
- if (ctx->dry_processed_breakers.empty()) {
2157
- LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
2158
- }
2159
- }
2160
-
2161
- return result;
2162
- }
2163
-
2164
- // logit-bias
2165
-
2166
- struct llama_sampler_logit_bias {
2167
- const int32_t n_vocab;
2168
-
2169
- const std::vector<llama_logit_bias> logit_bias;
2170
-
2171
- std::vector<llama_logit_bias> to_search;
2172
- };
2173
-
2174
- static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
2175
- return "logit-bias";
2176
- }
2177
-
2178
- static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2179
- auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
2180
-
2181
- if (ctx->logit_bias.empty()) {
2182
- return;
2183
- }
2184
-
2185
- ctx->to_search.clear();
2186
-
2187
- // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
2188
- for (const auto & lb : ctx->logit_bias) {
2189
- if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
2190
- cur_p->data[lb.token].logit += lb.bias;
2191
- } else {
2192
- ctx->to_search.push_back(lb);
2193
- }
2194
- }
2195
-
2196
- if (ctx->to_search.empty()) {
2197
- return;
2198
- }
2199
-
2200
- // search for the remaining candidates that were not found in the previous step
2201
- for (size_t i = 0; i < cur_p->size; ++i) {
2202
- for (const auto & lb : ctx->to_search) {
2203
- if (cur_p->data[i].id == lb.token) {
2204
- cur_p->data[i].logit += lb.bias;
2205
- break;
2206
- }
2207
- }
2208
- }
2209
- }
2210
-
2211
- static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
2212
- const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
2213
- return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
2214
- }
2215
-
2216
- static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
2217
- delete (llama_sampler_logit_bias *) smpl->ctx;
2218
- }
2219
-
2220
- static struct llama_sampler_i llama_sampler_logit_bias_i = {
2221
- /* .name = */ llama_sampler_logit_bias_name,
2222
- /* .accept = */ nullptr,
2223
- /* .apply = */ llama_sampler_logit_bias_apply,
2224
- /* .reset = */ nullptr,
2225
- /* .clone = */ llama_sampler_logit_bias_clone,
2226
- /* .free = */ llama_sampler_logit_bias_free,
2227
- };
2228
-
2229
- struct llama_sampler * llama_sampler_init_logit_bias(
2230
- int32_t n_vocab,
2231
- int32_t n_logit_bias,
2232
- const llama_logit_bias * logit_bias) {
2233
- return llama_sampler_init(
2234
- /* .iface = */ &llama_sampler_logit_bias_i,
2235
- /* .ctx = */ new llama_sampler_logit_bias {
2236
- /* .n_vocab = */ n_vocab,
2237
- /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2238
- /* .to_search = */ {},
2239
- }
2240
- );
2241
- }
2242
-
2243
- // infill
2244
-
2245
- //#define LM_GGML_DEBUG_SAMPLER_INFILL
2246
-
2247
- struct llama_sampler_infill {
2248
- const struct llama_vocab * vocab;
2249
-
2250
- std::vector<char> buf0;
2251
- std::vector<char> buf1;
2252
- };
2253
-
2254
- static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
2255
- return "infill";
2256
- }
2257
-
2258
- static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2259
- auto * ctx = (llama_sampler_infill *) smpl->ctx;
2260
-
2261
- llama_sampler_softmax_impl(cur_p);
2262
-
2263
- #if defined(LM_GGML_DEBUG_SAMPLER_INFILL)
2264
- #define LOG_DBG_CUR LLAMA_LOG_DEBUG
2265
- #else
2266
- #define LOG_DBG_CUR(...)
2267
- #endif
2268
-
2269
- for (size_t i = 0; i < cur_p->size; ++i) {
2270
- LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2271
- }
2272
-
2273
- float p_txt_sum = 0.0f;
2274
- float p_eog_sum = 0.0f;
2275
-
2276
- for (size_t i = 0; i < cur_p->size; ++i) {
2277
- if (ctx->vocab->is_eog(cur_p->data[i].id)) {
2278
- p_eog_sum += cur_p->data[i].p;
2279
- } else {
2280
- p_txt_sum += cur_p->data[i].p;
2281
- }
2282
- }
2283
-
2284
- const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; LM_GGML_UNUSED(rat);
2285
-
2286
- LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
2287
-
2288
- if (3*p_eog_sum*cur_p->size > p_txt_sum) {
2289
- LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
2290
-
2291
- // keep just the EOG tokens
2292
- const auto size_org = cur_p->size;
2293
-
2294
- cur_p->size = 0;
2295
-
2296
- float p_sum = 0.0f;
2297
-
2298
- for (size_t i = 0; i < size_org; ++i) {
2299
- if (ctx->vocab->is_eog(cur_p->data[i].id)) {
2300
- p_sum += cur_p->data[i].p;
2301
-
2302
- cur_p->data[cur_p->size++] = cur_p->data[i];
2303
- }
2304
- }
2305
-
2306
- // normalize probs
2307
- for (size_t i = 0; i < cur_p->size; ++i) {
2308
- cur_p->data[i].p /= p_sum;
2309
- }
2310
-
2311
- return;
2312
- }
2313
-
2314
- size_t n_combined = 0; LM_GGML_UNUSED(n_combined);
2315
-
2316
- // combine tokens with common prefix
2317
- for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
2318
- for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
2319
- if (cur_p->data[i0].logit == -INFINITY) {
2320
- break;
2321
- }
2322
-
2323
- if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
2324
- continue;
2325
- }
2326
-
2327
- int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2328
- if (len0 < 0) {
2329
- ctx->buf0.resize(len0);
2330
- len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2331
- assert(len0 > 0);
2332
- }
2333
-
2334
- int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2335
- if (len1 < 0) {
2336
- ctx->buf1.resize(len1);
2337
- len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2338
- assert(len1 > 0);
2339
- }
2340
-
2341
- // token i0 is a prefix of token i1
2342
- if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
2343
- int dst = i0;
2344
- int src = i1;
2345
-
2346
- // merge into the token with higher probability
2347
- if (cur_p->data[i1].p > cur_p->data[i0].p) {
2348
- std::swap(dst, src);
2349
- }
2350
-
2351
- cur_p->data[dst].p += cur_p->data[src].p;
2352
- cur_p->data[src].logit = -INFINITY;
2353
- cur_p->data[src].p = 0.0f;
2354
-
2355
- n_combined++;
2356
- }
2357
- }
2358
- }
2359
-
2360
- size_t n_non_eog = 0;
2361
-
2362
- size_t size_org = cur_p->size;
2363
-
2364
- float p_sum = 0.0f;
2365
- float thold = 0.2f;
2366
-
2367
- cur_p->size = 0;
2368
-
2369
- LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
2370
-
2371
- for (size_t i = 0; i < size_org; ++i) {
2372
- const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
2373
-
2374
- if (cur_p->data[i].p < thold && !is_eog) {
2375
- continue;
2376
- }
2377
-
2378
- if (!is_eog) {
2379
- ++n_non_eog;
2380
- }
2381
-
2382
- p_sum += cur_p->data[i].p;
2383
-
2384
- // keep this token
2385
- cur_p->data[cur_p->size++] = cur_p->data[i];
2386
- }
2387
-
2388
- LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
2389
-
2390
- // if no non-EOG tokens are left -> reduce cur_p to single EOT token
2391
- if (n_non_eog == 0) {
2392
- cur_p->size = 1;
2393
- cur_p->data[0].id = ctx->vocab->token_eot();
2394
- cur_p->data[0].logit = 1.0f;
2395
-
2396
- return;
2397
- }
2398
-
2399
- // normalize probs
2400
- for (size_t i = 0; i < cur_p->size; ++i) {
2401
- cur_p->data[i].p /= p_sum;
2402
-
2403
- LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2404
- }
2405
-
2406
- size_org = cur_p->size;
2407
- p_sum = 0.0f;
2408
- thold = 1.0/(n_non_eog + 1);
2409
-
2410
- cur_p->size = 0;
2411
-
2412
- LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
2413
-
2414
- for (size_t i = 0; i < size_org; ++i) {
2415
- const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
2416
-
2417
- if (cur_p->data[i].p < thold && !is_eog) {
2418
- continue;
2419
- }
2420
-
2421
- p_sum += cur_p->data[i].p;
2422
-
2423
- cur_p->data[cur_p->size++] = cur_p->data[i];
2424
- }
2425
-
2426
- // normalize probs
2427
- for (size_t i = 0; i < cur_p->size; ++i) {
2428
- cur_p->data[i].p /= p_sum;
2429
-
2430
- LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2431
- }
2432
-
2433
- #undef LOG_DBG_CUR
2434
- }
2435
-
2436
- static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
2437
- const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
2438
- return llama_sampler_init_infill(ctx->vocab);
2439
- }
2440
-
2441
- static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2442
- delete (llama_sampler_infill *) smpl->ctx;
2443
- }
2444
-
2445
- static struct llama_sampler_i llama_sampler_infill_i = {
2446
- /* .name = */ llama_sampler_infill_name,
2447
- /* .accept = */ nullptr,
2448
- /* .apply = */ llama_sampler_infill_apply,
2449
- /* .reset = */ nullptr,
2450
- /* .clone = */ llama_sampler_infill_clone,
2451
- /* .free = */ llama_sampler_infill_free,
2452
- };
2453
-
2454
- struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
2455
- return llama_sampler_init(
2456
- /* .iface = */ &llama_sampler_infill_i,
2457
- /* .ctx = */ new llama_sampler_infill {
2458
- /* .vocab = */ vocab,
2459
- /* .buf0 = */ std::vector<char>(512),
2460
- /* .buf1 = */ std::vector<char>(512),
2461
- }
2462
- );
2463
- }
2464
-
2465
- // utils
2466
-
2467
- uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
2468
- if (smpl->iface == &llama_sampler_dist_i) {
2469
- return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
2470
- }
2471
-
2472
- if (smpl->iface == &llama_sampler_mirostat_i) {
2473
- return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
2474
- }
2475
-
2476
- if (smpl->iface == &llama_sampler_mirostat_v2_i) {
2477
- return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
2478
- }
2479
-
2480
- if (smpl->iface == &llama_sampler_chain_i) {
2481
- const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
2482
- for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
2483
- const uint32_t seed = llama_sampler_get_seed(*it);
2484
- if (seed != LLAMA_DEFAULT_SEED) {
2485
- return seed;
2486
- }
2487
- }
2488
- }
2489
-
2490
- return LLAMA_DEFAULT_SEED;
2491
- }
2492
-
2493
- // perf
2494
-
2495
- struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
2496
- struct llama_perf_sampler_data data = {};
2497
-
2498
- if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
2499
- LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
2500
- }
2501
-
2502
- const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
2503
-
2504
- data.t_sample_ms = 1e-3 * ctx->t_sample_us;
2505
- data.n_sample = std::max(0, ctx->n_sample);
2506
-
2507
- return data;
2508
- }
2509
-
2510
- void llama_perf_sampler_print(const struct llama_sampler * chain) {
2511
- const auto data = llama_perf_sampler(chain);
2512
-
2513
- LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2514
- __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
2515
- }
2516
-
2517
- void llama_perf_sampler_reset(struct llama_sampler * chain) {
2518
- if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
2519
- LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
2520
- }
2521
-
2522
- auto * ctx = (struct llama_sampler_chain *) chain->ctx;
2523
-
2524
- ctx->t_sample_us = ctx->n_sample = 0;
2525
- }
1
+ #include "llama-sampling.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-vocab.h"
5
+ #include "llama-grammar.h"
6
+
7
+ #include <algorithm>
8
+ #include <cassert>
9
+ #include <cfloat>
10
+ #include <chrono>
11
+ #include <cmath>
12
+ #include <cstdlib>
13
+ #include <cstring>
14
+ #include <ctime>
15
+ #include <numeric>
16
+ #include <random>
17
+ #include <unordered_map>
18
+ #include <stdexcept>
19
+
20
+ // the ring buffer works similarly to std::deque, but with a fixed capacity
21
+ template<typename T>
22
+ struct ring_buffer {
23
+ ring_buffer(size_t cap) : capacity(cap), data(cap) {}
24
+
25
+ T & front() {
26
+ if (sz == 0) {
27
+ throw std::runtime_error("ring buffer is empty");
28
+ }
29
+ return data[first];
30
+ }
31
+
32
+ const T & front() const {
33
+ if (sz == 0) {
34
+ throw std::runtime_error("ring buffer is empty");
35
+ }
36
+ return data[first];
37
+ }
38
+
39
+ T & back() {
40
+ if (sz == 0) {
41
+ throw std::runtime_error("ring buffer is empty");
42
+ }
43
+ return data[pos];
44
+ }
45
+
46
+ const T & back() const {
47
+ if (sz == 0) {
48
+ throw std::runtime_error("ring buffer is empty");
49
+ }
50
+ return data[pos];
51
+ }
52
+
53
+ void push_back(const T & value) {
54
+ if (capacity == 0) {
55
+ throw std::runtime_error("ring buffer: capacity is zero");
56
+ }
57
+
58
+ if (sz == capacity) {
59
+ // advance the start when buffer is full
60
+ first = (first + 1) % capacity;
61
+ } else {
62
+ sz++;
63
+ }
64
+ data[pos] = value;
65
+ pos = (pos + 1) % capacity;
66
+ }
67
+
68
+ T pop_front() {
69
+ if (sz == 0) {
70
+ throw std::runtime_error("ring buffer is empty");
71
+ }
72
+ T value = data[first];
73
+ first = (first + 1) % capacity;
74
+ sz--;
75
+ return value;
76
+ }
77
+
78
+ //T & operator[](size_t i) {
79
+ // if (i >= sz) {
80
+ // throw std::runtime_error("ring buffer: index out of bounds");
81
+ // }
82
+ // return data[(first + i) % capacity];
83
+ //}
84
+
85
+ //const T & at(size_t i) const {
86
+ // if (i >= sz) {
87
+ // throw std::runtime_error("ring buffer: index out of bounds");
88
+ // }
89
+ // return data[(first + i) % capacity];
90
+ //}
91
+
92
+ const T & rat(size_t i) const {
93
+ if (i >= sz) {
94
+ throw std::runtime_error("ring buffer: index out of bounds");
95
+ }
96
+ return data[(first + sz - i - 1) % capacity];
97
+ }
98
+
99
+ std::vector<T> to_vector() const {
100
+ std::vector<T> result;
101
+ result.reserve(sz);
102
+ for (size_t i = 0; i < sz; i++) {
103
+ result.push_back(data[(first + i) % capacity]);
104
+ }
105
+ return result;
106
+ }
107
+
108
+ void clear() {
109
+ // here only reset the status of the buffer
110
+ sz = 0;
111
+ first = 0;
112
+ pos = 0;
113
+ }
114
+
115
+ bool empty() const {
116
+ return sz == 0;
117
+ }
118
+
119
+ size_t size() const {
120
+ return sz;
121
+ }
122
+
123
+ size_t capacity = 0;
124
+ size_t sz = 0;
125
+ size_t first = 0;
126
+ size_t pos = 0;
127
+
128
+ std::vector<T> data;
129
+ };
130
+
131
+ static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
132
+ // iterator for the probabilities
133
+ #ifdef __GNUC__
134
+ #pragma GCC diagnostic push
135
+ #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
136
+ #endif
137
+
138
+ struct probs_iterator {
139
+ typedef std::input_iterator_tag iterator_category;
140
+ typedef float value_type;
141
+ typedef float * pointer;
142
+ typedef float & reference;
143
+ typedef ptrdiff_t difference_type;
144
+
145
+ const llama_token_data * data;
146
+
147
+ bool operator==(const probs_iterator & other) const { return data == other.data; }
148
+ bool operator!=(const probs_iterator & other) const { return data != other.data; }
149
+ const float & operator*() const { return data->p; }
150
+ probs_iterator & operator++() { ++data; return *this; }
151
+ probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
152
+ };
153
+
154
+ #ifdef __GNUC__
155
+ #pragma GCC diagnostic pop
156
+ #endif
157
+
158
+ std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
159
+
160
+ return dist(rng);
161
+ }
162
+
163
+ /*
164
+ static void llama_log_softmax(float * array, size_t size) {
165
+ float max_l = *std::max_element(array, array + size);
166
+ float sum = 0.f;
167
+ for (size_t i = 0; i < size; ++i) {
168
+ float p = expf(array[i] - max_l);
169
+ sum += p;
170
+ array[i] = p;
171
+ }
172
+
173
+ for (size_t i = 0; i < size; ++i) {
174
+ array[i] = logf(array[i] / sum);
175
+ }
176
+ }
177
+ */
178
+
179
+ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
180
+ if (temp <= 0.0f) {
181
+ // find the token with the highest logit and set the rest to -inf
182
+ size_t max_i = 0;
183
+ float max_l = cur_p->data[0].logit;
184
+
185
+ for (size_t i = 1; i < cur_p->size; ++i) {
186
+ if (cur_p->data[i ].logit > max_l) {
187
+ cur_p->data[max_i].logit = -INFINITY;
188
+ max_i = i;
189
+ max_l = cur_p->data[i].logit;
190
+ } else {
191
+ cur_p->data[i].logit = -INFINITY;
192
+ }
193
+ }
194
+
195
+ return;
196
+ }
197
+
198
+ for (size_t i = 0; i < cur_p->size; ++i) {
199
+ cur_p->data[i].logit /= temp;
200
+ }
201
+ }
202
+
203
+ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
204
+ LM_GGML_ASSERT(cur_p->size > 0);
205
+
206
+ // Sort the logits in descending order
207
+ if (!cur_p->sorted) {
208
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
209
+ return a.logit > b.logit;
210
+ });
211
+ cur_p->sorted = true;
212
+ }
213
+
214
+ float max_l = cur_p->data[0].logit;
215
+ float cum_sum = 0.0f;
216
+
217
+ for (size_t i = 0; i < cur_p->size; ++i) {
218
+ float p = expf(cur_p->data[i].logit - max_l);
219
+ cur_p->data[i].p = p;
220
+ cum_sum += p;
221
+ }
222
+
223
+ for (size_t i = 0; i < cur_p->size; ++i) {
224
+ cur_p->data[i].p /= cum_sum;
225
+ }
226
+ }
227
+
228
+ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
229
+ // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
230
+ // if (k >= (int32_t)cur_p->size) {
231
+ // return;
232
+ // }
233
+
234
+ if (k <= 0) {
235
+ k = cur_p->size;
236
+ }
237
+
238
+ k = std::min(k, (int) cur_p->size);
239
+
240
+ // Sort scores in descending order
241
+ if (!cur_p->sorted) {
242
+ auto comp = [](const llama_token_data & a, const llama_token_data & b) {
243
+ return a.logit > b.logit;
244
+ };
245
+ if (k <= 128) {
246
+ std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
247
+ } else {
248
+ constexpr int nbuckets = 128;
249
+ constexpr float bucket_low = -10.0f;
250
+ constexpr float bucket_high = 10.0f;
251
+ constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
252
+ constexpr float bucket_inter = -bucket_low * bucket_scale;
253
+
254
+ std::vector<int> bucket_idx(cur_p->size);
255
+ std::vector<int> histo(nbuckets, 0);
256
+
257
+ for (int i = 0; i < (int)cur_p->size; ++i) {
258
+ const float val = cur_p->data[i].logit;
259
+ int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
260
+ ib = std::max(0, std::min(nbuckets - 1, ib));
261
+ bucket_idx[i] = ib;
262
+ ++histo[ib];
263
+ }
264
+ int nhave = 0;
265
+ int ib = nbuckets - 1;
266
+ for ( ; ib >= 0; --ib) {
267
+ nhave += histo[ib];
268
+ if (nhave >= k) {
269
+ break;
270
+ }
271
+ }
272
+ std::vector<llama_token_data> tmp_tokens(nhave);
273
+ auto * ptr = tmp_tokens.data();
274
+ std::vector<llama_token_data*> bucket_ptrs;
275
+ bucket_ptrs.reserve(nbuckets - ib);
276
+ for (int j = nbuckets - 1; j >= ib; --j) {
277
+ bucket_ptrs.push_back(ptr);
278
+ ptr += histo[j];
279
+ }
280
+ for (int i = 0; i < (int)cur_p->size; ++i) {
281
+ int j = bucket_idx[i];
282
+ if (j >= ib) {
283
+ *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
284
+ }
285
+ }
286
+
287
+ ptr = tmp_tokens.data();
288
+ int ndone = 0;
289
+ for (int j = nbuckets - 1; j > ib; --j) {
290
+ std::sort(ptr, ptr + histo[j], comp);
291
+ ptr += histo[j];
292
+ ndone += histo[j];
293
+ }
294
+ std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
295
+
296
+ std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
297
+
298
+ }
299
+ cur_p->sorted = true;
300
+ }
301
+ cur_p->size = k;
302
+ }
303
+
304
+ static uint32_t get_rng_seed(uint32_t seed) {
305
+ if (seed == LLAMA_DEFAULT_SEED) {
306
+ // use system clock if std::random_device is not a true RNG
307
+ static bool is_rd_prng = std::random_device().entropy() == 0;
308
+ if (is_rd_prng) {
309
+ return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
310
+ }
311
+ std::random_device rd;
312
+ return rd();
313
+ }
314
+ return seed;
315
+ }
316
+
317
+ // llama_sampler API
318
+
319
+ struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
320
+ return new llama_sampler {
321
+ /* .iface = */ iface,
322
+ /* .ctx = */ ctx,
323
+ };
324
+ }
325
+
326
+ const char * llama_sampler_name(const struct llama_sampler * smpl) {
327
+ if (!smpl->iface) {
328
+ return "(null)";
329
+ }
330
+
331
+ return smpl->iface->name(smpl);
332
+ }
333
+
334
+ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
335
+ if (smpl->iface->accept) {
336
+ smpl->iface->accept(smpl, token);
337
+ }
338
+ }
339
+
340
+ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
341
+ LM_GGML_ASSERT(smpl->iface->apply);
342
+ smpl->iface->apply(smpl, cur_p);
343
+ }
344
+
345
+ void llama_sampler_reset(struct llama_sampler * smpl) {
346
+ if (smpl->iface->reset) {
347
+ smpl->iface->reset(smpl);
348
+ }
349
+ }
350
+
351
+ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
352
+ if (smpl->iface->clone) {
353
+ return smpl->iface->clone(smpl);
354
+ }
355
+
356
+ if (smpl->ctx == nullptr) {
357
+ return llama_sampler_init(
358
+ /* .iface = */ smpl->iface,
359
+ /* .ctx = */ nullptr
360
+ );
361
+ }
362
+
363
+ LM_GGML_ABORT("the sampler does not support cloning");
364
+ }
365
+
366
+ void llama_sampler_free(struct llama_sampler * smpl) {
367
+ if (smpl == nullptr) {
368
+ return;
369
+ }
370
+
371
+ if (smpl->iface->free) {
372
+ smpl->iface->free(smpl);
373
+ }
374
+
375
+ delete smpl;
376
+ }
377
+
378
+ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
379
+ const auto * logits = llama_get_logits_ith(ctx, idx);
380
+
381
+ const llama_model * model = llama_get_model(ctx);
382
+ const llama_vocab * vocab = llama_model_get_vocab(model);
383
+
384
+ const int n_vocab = llama_vocab_n_tokens(vocab);
385
+
386
+ // TODO: do not allocate each time
387
+ std::vector<llama_token_data> cur;
388
+ cur.reserve(n_vocab);
389
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
390
+ cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
391
+ }
392
+
393
+ llama_token_data_array cur_p = {
394
+ /* .data = */ cur.data(),
395
+ /* .size = */ cur.size(),
396
+ /* .selected = */ -1,
397
+ /* .sorted = */ false,
398
+ };
399
+
400
+ llama_sampler_apply(smpl, &cur_p);
401
+
402
+ LM_GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
403
+
404
+ auto token = cur_p.data[cur_p.selected].id;
405
+
406
+ llama_sampler_accept(smpl, token);
407
+
408
+ return token;
409
+ }
410
+
411
+ // sampler chain
412
+
413
+ static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
414
+ return "chain";
415
+ }
416
+
417
+ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
418
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
419
+
420
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
421
+
422
+ for (auto * smpl : chain->samplers) {
423
+ llama_sampler_accept(smpl, token);
424
+ }
425
+
426
+ chain->n_sample++;
427
+ }
428
+
429
+ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
430
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
431
+
432
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
433
+
434
+ for (auto * smpl : chain->samplers) {
435
+ llama_sampler_apply(smpl, cur_p);
436
+ }
437
+ }
438
+
439
+ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
440
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
441
+
442
+ for (auto * smpl : chain->samplers) {
443
+ llama_sampler_reset(smpl);
444
+ }
445
+
446
+ chain->t_sample_us = 0;
447
+ chain->n_sample = 0;
448
+ }
449
+
450
+ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
451
+ const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
452
+
453
+ auto * result = llama_sampler_chain_init(chain_src->params);
454
+
455
+ for (auto * smpl : chain_src->samplers) {
456
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl));
457
+ }
458
+
459
+ return result;
460
+ }
461
+
462
+ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
463
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
464
+
465
+ for (auto * smpl : chain->samplers) {
466
+ llama_sampler_free(smpl);
467
+ }
468
+
469
+ delete chain;
470
+ }
471
+
472
+ static struct llama_sampler_i llama_sampler_chain_i = {
473
+ /* .name = */ llama_sampler_chain_name,
474
+ /* .accept = */ llama_sampler_chain_accept,
475
+ /* .apply = */ llama_sampler_chain_apply,
476
+ /* .reset = */ llama_sampler_chain_reset,
477
+ /* .clone = */ llama_sampler_chain_clone,
478
+ /* .free = */ llama_sampler_chain_free,
479
+ };
480
+
481
+ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
482
+ return llama_sampler_init(
483
+ /* .iface = */ &llama_sampler_chain_i,
484
+ /* .ctx = */ new llama_sampler_chain {
485
+ /* .params = */ params,
486
+ /* .samplers = */ {},
487
+ /* .t_sample_us = */ 0,
488
+ /* .n_sample = */ 0,
489
+ }
490
+ );
491
+ }
492
+
493
+ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
494
+ auto * p = (llama_sampler_chain *) chain->ctx;
495
+ p->samplers.push_back(smpl);
496
+ }
497
+
498
+
499
+ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
500
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
501
+
502
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
503
+ return nullptr;
504
+ }
505
+
506
+ return p->samplers[i];
507
+ }
508
+
509
+ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
510
+ auto * p = (llama_sampler_chain *) chain->ctx;
511
+
512
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
513
+ return nullptr;
514
+ }
515
+
516
+ auto * result = p->samplers[i];
517
+ p->samplers.erase(p->samplers.begin() + i);
518
+
519
+ return result;
520
+ }
521
+
522
+ int llama_sampler_chain_n(const struct llama_sampler * chain) {
523
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
524
+
525
+ return p->samplers.size();
526
+ }
527
+
528
+ //
529
+ // samplers
530
+ //
531
+
532
+ // greedy
533
+
534
+ static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
535
+ return "greedy";
536
+ }
537
+
538
+ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
539
+ cur_p->selected = 0;
540
+ for (size_t i = 1; i < cur_p->size; ++i) {
541
+ if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
542
+ cur_p->selected = i;
543
+ }
544
+ }
545
+ }
546
+
547
+ static struct llama_sampler_i llama_sampler_greedy_i = {
548
+ /* .name = */ llama_sampler_greedy_name,
549
+ /* .accept = */ nullptr,
550
+ /* .apply = */ llama_sampler_greedy_apply,
551
+ /* .reset = */ nullptr,
552
+ /* .clone = */ nullptr,
553
+ /* .free = */ nullptr,
554
+ };
555
+
556
+ struct llama_sampler * llama_sampler_init_greedy() {
557
+ return llama_sampler_init(
558
+ /* .iface = */ &llama_sampler_greedy_i,
559
+ /* .ctx = */ nullptr
560
+ );
561
+ }
562
+
563
+ // dist
564
+
565
+ struct llama_sampler_dist {
566
+ const uint32_t seed;
567
+ uint32_t seed_cur;
568
+
569
+ std::mt19937 rng;
570
+ };
571
+
572
+ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
573
+ return "dist";
574
+ }
575
+
576
+ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
577
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
578
+
579
+ llama_sampler_softmax_impl(cur_p);
580
+
581
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
582
+ }
583
+
584
+ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
585
+ const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
586
+ auto * result = llama_sampler_init_dist(ctx->seed);
587
+
588
+ // copy the state
589
+ {
590
+ auto * result_ctx = (llama_sampler_dist *) result->ctx;
591
+
592
+ result_ctx->rng = ctx->rng;
593
+ }
594
+
595
+ return result;
596
+ }
597
+
598
+ static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
599
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
600
+ ctx->seed_cur = get_rng_seed(ctx->seed);
601
+ ctx->rng.seed(ctx->seed_cur);
602
+ }
603
+
604
+ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
605
+ delete (llama_sampler_dist *) smpl->ctx;
606
+ }
607
+
608
+ static struct llama_sampler_i llama_sampler_dist_i = {
609
+ /* .name = */ llama_sampler_dist_name,
610
+ /* .accept = */ nullptr,
611
+ /* .apply = */ llama_sampler_dist_apply,
612
+ /* .reset = */ llama_sampler_dist_reset,
613
+ /* .clone = */ llama_sampler_dist_clone,
614
+ /* .free = */ llama_sampler_dist_free,
615
+ };
616
+
617
+ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
618
+ auto seed_cur = get_rng_seed(seed);
619
+ return llama_sampler_init(
620
+ /* .iface = */ &llama_sampler_dist_i,
621
+ /* .ctx = */ new llama_sampler_dist {
622
+ /* .seed = */ seed,
623
+ /* .seed_cur = */ seed_cur,
624
+ /* .rng = */ std::mt19937(seed_cur),
625
+ }
626
+ );
627
+ }
628
+
629
+ // softmax
630
+
631
+ static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
632
+ return "softmax";
633
+ }
634
+
635
+ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
636
+ llama_sampler_softmax_impl(cur_p);
637
+ }
638
+
639
+ static struct llama_sampler_i llama_sampler_softmax_i = {
640
+ /* .name = */ llama_sampler_softmax_name,
641
+ /* .accept = */ nullptr,
642
+ /* .apply = */ llama_sampler_softmax_apply,
643
+ /* .reset = */ nullptr,
644
+ /* .clone = */ nullptr,
645
+ /* .free = */ nullptr,
646
+ };
647
+
648
+ struct llama_sampler * llama_sampler_init_softmax() {
649
+ return llama_sampler_init(
650
+ /* .iface = */ &llama_sampler_softmax_i,
651
+ /* .ctx = */ nullptr
652
+ );
653
+ }
654
+
655
+ // top-k
656
+
657
+ struct llama_sampler_top_k {
658
+ const int32_t k;
659
+ };
660
+
661
+ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
662
+ return "top-k";
663
+ }
664
+
665
+ static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
666
+ const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
667
+ llama_sampler_top_k_impl(cur_p, ctx->k);
668
+ }
669
+
670
+ static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
671
+ const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
672
+ return llama_sampler_init_top_k(ctx->k);
673
+ }
674
+
675
+ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
676
+ delete (llama_sampler_top_k *) smpl->ctx;
677
+ }
678
+
679
+ static struct llama_sampler_i llama_sampler_top_k_i = {
680
+ /* .name = */ llama_sampler_top_k_name,
681
+ /* .accept = */ nullptr,
682
+ /* .apply = */ llama_sampler_top_k_apply,
683
+ /* .reset = */ nullptr,
684
+ /* .clone = */ llama_sampler_top_k_clone,
685
+ /* .free = */ llama_sampler_top_k_free,
686
+ };
687
+
688
+ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
689
+ return llama_sampler_init(
690
+ /* .iface = */ &llama_sampler_top_k_i,
691
+ /* .ctx = */ new llama_sampler_top_k {
692
+ /* .k = */ k,
693
+ }
694
+ );
695
+ }
696
+
697
+ // top-p
698
+
699
+ struct llama_sampler_top_p {
700
+ const float p;
701
+ const size_t min_keep;
702
+ };
703
+
704
+ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
705
+ return "top-p";
706
+ }
707
+
708
+ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
709
+ const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
710
+
711
+ if (ctx->p >= 1.0f) {
712
+ return;
713
+ }
714
+
715
+ llama_sampler_softmax_impl(cur_p);
716
+
717
+ // Compute the cumulative probabilities
718
+ float cum_sum = 0.0f;
719
+ size_t last_idx = cur_p->size;
720
+
721
+ for (size_t i = 0; i < cur_p->size; ++i) {
722
+ cum_sum += cur_p->data[i].p;
723
+
724
+ // Check if the running sum is at least p or if we have kept at least min_keep tokens
725
+ // we set the last index to i+1 to indicate that the current iterate should be included in the set
726
+ if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
727
+ last_idx = i + 1;
728
+ break;
729
+ }
730
+ }
731
+
732
+ // Resize the output vector to keep only the top-p tokens
733
+ cur_p->size = last_idx;
734
+ }
735
+
736
+ static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
737
+ const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
738
+ return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
739
+ }
740
+
741
+ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
742
+ delete (llama_sampler_top_p *) smpl->ctx;
743
+ }
744
+
745
+ static struct llama_sampler_i llama_sampler_top_p_i = {
746
+ /* .name = */ llama_sampler_top_p_name,
747
+ /* .accept = */ nullptr,
748
+ /* .apply = */ llama_sampler_top_p_apply,
749
+ /* .reset = */ nullptr,
750
+ /* .clone = */ llama_sampler_top_p_clone,
751
+ /* .free = */ llama_sampler_top_p_free,
752
+ };
753
+
754
+ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
755
+ return llama_sampler_init(
756
+ /* .iface = */ &llama_sampler_top_p_i,
757
+ /* .ctx = */ new llama_sampler_top_p {
758
+ /* .p = */ p,
759
+ /* .min_keep = */ min_keep,
760
+ }
761
+ );
762
+ }
763
+
764
+ // min-p
765
+
766
+ struct llama_sampler_min_p {
767
+ const float p;
768
+ const size_t min_keep;
769
+ };
770
+
771
+ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
772
+ return "min-p";
773
+ }
774
+
775
+ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
776
+ const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
777
+
778
+ if (ctx->p <= 0.0f || !cur_p->size) {
779
+ return;
780
+ }
781
+
782
+ bool min_p_applied = false;
783
+
784
+ // if the cur_p aren't sorted, try the unsorted implementation first
785
+ if (!cur_p->sorted) {
786
+ std::vector<llama_token_data> filtered_tokens;
787
+
788
+ float max_logit = -FLT_MAX;
789
+ for (size_t i = 0; i < cur_p->size; ++i) {
790
+ max_logit = std::max(max_logit, cur_p->data[i].logit);
791
+ }
792
+ const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
793
+
794
+ for (size_t i = 0; i < cur_p->size; ++i) {
795
+ if (cur_p->data[i].logit >= min_logit) {
796
+ filtered_tokens.push_back(cur_p->data[i]);
797
+ }
798
+ }
799
+
800
+ // if we have enough values the operation was a success
801
+ if (filtered_tokens.size() >= ctx->min_keep) {
802
+ memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
803
+ cur_p->size = filtered_tokens.size();
804
+ min_p_applied = true;
805
+ }
806
+ }
807
+
808
+ // if the cur_p are sorted or the unsorted implementation failed, use this implementation
809
+ if (!min_p_applied) {
810
+ // Sort the logits in descending order
811
+ if (!cur_p->sorted) {
812
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
813
+ return a.logit > b.logit;
814
+ });
815
+ cur_p->sorted = true;
816
+ }
817
+
818
+ const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
819
+ size_t i = 1; // first token always matches
820
+
821
+ for (; i < cur_p->size; ++i) {
822
+ if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
823
+ break; // prob too small
824
+ }
825
+ }
826
+
827
+ // Resize the output vector to keep only the matching tokens
828
+ cur_p->size = i;
829
+ }
830
+ }
831
+
832
+ static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
833
+ const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
834
+ return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
835
+ }
836
+
837
+ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
838
+ delete (llama_sampler_min_p *) smpl->ctx;
839
+ }
840
+
841
+ static struct llama_sampler_i llama_sampler_min_p_i = {
842
+ /* .name = */ llama_sampler_min_p_name,
843
+ /* .accept = */ nullptr,
844
+ /* .apply = */ llama_sampler_min_p_apply,
845
+ /* .reset = */ nullptr,
846
+ /* .clone = */ llama_sampler_min_p_clone,
847
+ /* .free = */ llama_sampler_min_p_free,
848
+ };
849
+
850
+ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
851
+ return llama_sampler_init(
852
+ /* .iface = */ &llama_sampler_min_p_i,
853
+ /* .ctx = */ new llama_sampler_min_p {
854
+ /* .p = */ p,
855
+ /* .min_keep = */ min_keep,
856
+ }
857
+ );
858
+ }
859
+
860
+ // typical
861
+
862
+ struct llama_sampler_typical {
863
+ const float p;
864
+ const size_t min_keep;
865
+ };
866
+
867
+ static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
868
+ return "typical";
869
+ }
870
+
871
+ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
872
+ const auto * ctx = (llama_sampler_typical *) smpl->ctx;
873
+
874
+ // Reference implementation:
875
+ // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
876
+ if (ctx->p >= 1.0f) {
877
+ return;
878
+ }
879
+
880
+ // Compute the softmax of logits and calculate entropy
881
+ llama_sampler_softmax_impl(cur_p);
882
+
883
+ float entropy = 0.0f;
884
+ for (size_t i = 0; i < cur_p->size; ++i) {
885
+ entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
886
+ }
887
+
888
+ // Compute the absolute difference between negative log probability and entropy for each candidate
889
+ std::vector<float> shifted_scores;
890
+ for (size_t i = 0; i < cur_p->size; ++i) {
891
+ float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
892
+ shifted_scores.push_back(shifted_score);
893
+ }
894
+
895
+ // Sort tokens based on the shifted_scores and their corresponding indices
896
+ std::vector<size_t> indices(cur_p->size);
897
+ std::iota(indices.begin(), indices.end(), 0);
898
+
899
+ std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
900
+ return shifted_scores[a] < shifted_scores[b];
901
+ });
902
+
903
+ // Compute the cumulative probabilities
904
+ float cum_sum = 0.0f;
905
+ size_t last_idx = indices.size();
906
+
907
+ for (size_t i = 0; i < indices.size(); ++i) {
908
+ size_t idx = indices[i];
909
+ cum_sum += cur_p->data[idx].p;
910
+
911
+ // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
912
+ if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
913
+ last_idx = i + 1;
914
+ break;
915
+ }
916
+ }
917
+
918
+ // Resize the output vector to keep only the locally typical tokens
919
+ std::vector<llama_token_data> cur_p_new;
920
+ for (size_t i = 0; i < last_idx; ++i) {
921
+ size_t idx = indices[i];
922
+ cur_p_new.push_back(cur_p->data[idx]);
923
+ }
924
+
925
+ // Replace the data in cur_p with the cur_p_new data
926
+ std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
927
+ cur_p->size = cur_p_new.size();
928
+ cur_p->sorted = false;
929
+ }
930
+
931
+ static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
932
+ const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
933
+ return llama_sampler_init_typical(ctx->p, ctx->min_keep);
934
+ }
935
+
936
+ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
937
+ delete (llama_sampler_typical *) smpl->ctx;
938
+ }
939
+
940
+ static struct llama_sampler_i llama_sampler_typical_i = {
941
+ /* .name = */ llama_sampler_typical_name,
942
+ /* .accept = */ nullptr,
943
+ /* .apply = */ llama_sampler_typical_apply,
944
+ /* .reset = */ nullptr,
945
+ /* .clone = */ llama_sampler_typical_clone,
946
+ /* .free = */ llama_sampler_typical_free,
947
+ };
948
+
949
+ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
950
+ return llama_sampler_init(
951
+ /* .iface = */ &llama_sampler_typical_i,
952
+ /* .ctx = */ new llama_sampler_typical {
953
+ /* .p = */ p,
954
+ /* .min_keep = */ min_keep,
955
+ }
956
+ );
957
+ }
958
+
959
+ // temp
960
+
961
+ struct llama_sampler_temp {
962
+ const float temp;
963
+ };
964
+
965
+ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
966
+ return "temp";
967
+ }
968
+
969
+ static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
970
+ const auto * ctx = (llama_sampler_temp *) smpl->ctx;
971
+
972
+ llama_sampler_temp_impl(cur_p, ctx->temp);
973
+ }
974
+
975
+ static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
976
+ const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
977
+ return llama_sampler_init_temp(ctx->temp);
978
+ }
979
+
980
+ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
981
+ delete (llama_sampler_temp *) smpl->ctx;
982
+ }
983
+
984
+ static struct llama_sampler_i llama_sampler_temp_i = {
985
+ /* .name = */ llama_sampler_temp_name,
986
+ /* .accept = */ nullptr,
987
+ /* .apply = */ llama_sampler_temp_apply,
988
+ /* .reset = */ nullptr,
989
+ /* .clone = */ llama_sampler_temp_clone,
990
+ /* .free = */ llama_sampler_temp_free,
991
+ };
992
+
993
+ struct llama_sampler * llama_sampler_init_temp(float temp) {
994
+ return llama_sampler_init(
995
+ /* .iface = */ &llama_sampler_temp_i,
996
+ /* .ctx = */ new llama_sampler_temp {
997
+ /*.temp = */ temp,
998
+ }
999
+ );
1000
+ }
1001
+
1002
+ // temp-ext
1003
+
1004
+ struct llama_sampler_temp_ext {
1005
+ const float temp;
1006
+ const float delta;
1007
+ const float exponent;
1008
+ };
1009
+
1010
+ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
1011
+ return "temp-ext";
1012
+ }
1013
+
1014
+ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1015
+ const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
1016
+ if (ctx->delta > 0) {
1017
+ const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
1018
+ const float max_temp = ctx->temp + ctx->delta;
1019
+
1020
+ float exponent_val = ctx->exponent;
1021
+
1022
+ // no need to do anything if there is only one (or zero) candidates
1023
+ if (cur_p->size <= 1) {
1024
+ return;
1025
+ }
1026
+
1027
+ // Calculate maximum possible entropy
1028
+ float max_entropy = -logf(1.0f / cur_p->size);
1029
+
1030
+ llama_sampler_softmax_impl(cur_p);
1031
+
1032
+ // Calculate entropy of the softmax probabilities
1033
+ float entropy = 0.0f;
1034
+ for (size_t i = 0; i < cur_p->size; ++i) {
1035
+ float prob = cur_p->data[i].p;
1036
+ if (prob > 0.0f) { // Ensure no log(0)
1037
+ entropy -= prob * logf(prob);
1038
+ }
1039
+ }
1040
+
1041
+ // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
1042
+ float normalized_entropy = entropy / max_entropy;
1043
+
1044
+ // Map the normalized entropy to the desired temperature range using the power function
1045
+ float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
1046
+
1047
+ #ifdef DEBUG
1048
+ LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
1049
+ LLAMA_LOG_INFO("Entropy: %f\n", entropy);
1050
+ LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
1051
+ LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
1052
+ LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
1053
+ LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
1054
+ #endif
1055
+
1056
+ // Apply the dynamically calculated temperature scaling
1057
+ llama_sampler_temp_impl(cur_p, dyn_temp);
1058
+
1059
+ // Re-compute softmax probabilities after scaling logits with dynamic temperature
1060
+ const double max_l_double = cur_p->data[0].logit;
1061
+
1062
+ double cum_sum_double = 0.0;
1063
+ for (size_t i = 0; i < cur_p->size; ++i) {
1064
+ double p = exp(cur_p->data[i].logit - max_l_double);
1065
+ cur_p->data[i].p = p; // Store the scaled probability
1066
+ cum_sum_double += p;
1067
+ }
1068
+
1069
+ for (size_t i = 0; i < cur_p->size; ++i) {
1070
+ cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
1071
+ }
1072
+
1073
+ #ifdef DEBUG
1074
+ // Print the updated top 25 probabilities after temperature scaling
1075
+ LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
1076
+ for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
1077
+ LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
1078
+ }
1079
+ #endif
1080
+ } else {
1081
+ llama_sampler_temp_impl(cur_p, ctx->temp);
1082
+ }
1083
+ }
1084
+
1085
+ static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
1086
+ const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
1087
+ return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
1088
+ }
1089
+
1090
+ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1091
+ delete (llama_sampler_temp_ext *) smpl->ctx;
1092
+ }
1093
+
1094
+ static struct llama_sampler_i llama_sampler_temp_ext_i = {
1095
+ /* .name = */ llama_sampler_temp_ext_name,
1096
+ /* .accept = */ nullptr,
1097
+ /* .apply = */ llama_sampler_temp_ext_apply,
1098
+ /* .reset = */ nullptr,
1099
+ /* .clone = */ llama_sampler_temp_ext_clone,
1100
+ /* .free = */ llama_sampler_temp_ext_free,
1101
+ };
1102
+
1103
+ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1104
+ return llama_sampler_init(
1105
+ /* .iface = */ &llama_sampler_temp_ext_i,
1106
+ /* .ctx = */ new llama_sampler_temp_ext {
1107
+ /* .temp = */ temp,
1108
+ /* .delta = */ delta,
1109
+ /* .exponent = */ exponent,
1110
+ }
1111
+ );
1112
+ }
1113
+
1114
+ // xtc
1115
+
1116
+ struct llama_sampler_xtc {
1117
+ const float probability;
1118
+ const float threshold;
1119
+ const size_t min_keep;
1120
+
1121
+ const uint32_t seed;
1122
+ uint32_t seed_cur;
1123
+
1124
+ std::mt19937 rng;
1125
+ };
1126
+
1127
+ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
1128
+ return "xtc";
1129
+ }
1130
+
1131
+ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1132
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1133
+
1134
+ if (ctx->probability <= 0.0f
1135
+ || ctx->threshold > 0.5f
1136
+ || cur_p->size < 2) {
1137
+ return;
1138
+ }
1139
+
1140
+ std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
1141
+ float chance = distribution(ctx->rng);
1142
+ if (chance > ctx->probability) return;
1143
+
1144
+ // in case it's not sorted/recalculated yet
1145
+ llama_sampler_softmax_impl(cur_p);
1146
+
1147
+ int pos_last = 0;
1148
+
1149
+ for (size_t i = 0; i < cur_p->size; ++i) {
1150
+ if (cur_p->data[i].p >= ctx->threshold) {
1151
+ pos_last = i;
1152
+ } else break;
1153
+ }
1154
+
1155
+ if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
1156
+ cur_p->data += pos_last;
1157
+ cur_p->size -= pos_last;
1158
+ }
1159
+ }
1160
+
1161
+ static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
1162
+ const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1163
+ auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
1164
+
1165
+ // copy the state
1166
+ {
1167
+ auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1168
+
1169
+ result_ctx->rng = ctx->rng;
1170
+ }
1171
+
1172
+ return result;
1173
+ }
1174
+
1175
+ static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
1176
+ delete (llama_sampler_xtc *) smpl->ctx;
1177
+ }
1178
+
1179
+ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1180
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1181
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1182
+ ctx->rng.seed(ctx->seed_cur);
1183
+ }
1184
+
1185
+ static struct llama_sampler_i llama_sampler_xtc_i = {
1186
+ /* .name = */ llama_sampler_xtc_name,
1187
+ /* .accept = */ nullptr,
1188
+ /* .apply = */ llama_sample_xtc_apply,
1189
+ /* .reset = */ llama_sampler_xtc_reset,
1190
+ /* .clone = */ llama_sampler_xtc_clone,
1191
+ /* .free = */ llama_sampler_xtc_free,
1192
+ };
1193
+
1194
+ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1195
+ auto seed_cur = get_rng_seed(seed);
1196
+ return llama_sampler_init(
1197
+ /* .iface = */ &llama_sampler_xtc_i,
1198
+ /* .ctx = */ new llama_sampler_xtc {
1199
+ /* .probability = */ p,
1200
+ /* .threshold = */ t,
1201
+ /* .min_keep = */ min_keep,
1202
+ /* .seed = */ seed,
1203
+ /* .seed_cur = */ seed_cur,
1204
+ /* .rng = */ std::mt19937(seed_cur),
1205
+ }
1206
+ );
1207
+ }
1208
+
1209
+ // mirostat
1210
+
1211
+ struct llama_sampler_mirostat {
1212
+ const int32_t n_vocab;
1213
+
1214
+ const uint32_t seed;
1215
+ uint32_t seed_cur;
1216
+
1217
+ const float tau;
1218
+ const float eta;
1219
+
1220
+ const int32_t m;
1221
+
1222
+ float mu;
1223
+
1224
+ std::mt19937 rng;
1225
+ };
1226
+
1227
+ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
1228
+ return "mirostat";
1229
+ }
1230
+
1231
+ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1232
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1233
+
1234
+ llama_sampler_softmax_impl(cur_p);
1235
+
1236
+ // Estimate s_hat using the most probable m tokens
1237
+ float s_hat = 0.0;
1238
+ float sum_ti_bi = 0.0;
1239
+ float sum_ti_sq = 0.0;
1240
+ for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
1241
+ float t_i = logf(float(i + 2) / float(i + 1));
1242
+ float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
1243
+ sum_ti_bi += t_i * b_i;
1244
+ sum_ti_sq += t_i * t_i;
1245
+ }
1246
+ s_hat = sum_ti_bi / sum_ti_sq;
1247
+
1248
+ // Compute k from the estimated s_hat and target surprise value
1249
+ float epsilon_hat = s_hat - 1;
1250
+ float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
1251
+
1252
+ llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
1253
+ llama_sampler_softmax_impl(cur_p);
1254
+
1255
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1256
+
1257
+ cur_p->selected = idx;
1258
+
1259
+ float observed_surprise = -log2f(cur_p->data[idx].p);
1260
+ float e = observed_surprise - ctx->tau;
1261
+
1262
+ // Update mu using the learning rate and error
1263
+ ctx->mu = ctx->mu - ctx->eta * e;
1264
+ }
1265
+
1266
+ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
1267
+ const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
1268
+ auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
1269
+
1270
+ // copy the state
1271
+ {
1272
+ auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
1273
+
1274
+ result_ctx->mu = ctx->mu;
1275
+ result_ctx->rng = ctx->rng;
1276
+ }
1277
+
1278
+ return result;
1279
+ }
1280
+
1281
+ static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
1282
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1283
+ ctx->mu = 2.0f*ctx->tau;
1284
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1285
+ ctx->rng.seed(ctx->seed_cur);
1286
+ }
1287
+
1288
+ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1289
+ delete (llama_sampler_mirostat *) smpl->ctx;
1290
+ }
1291
+
1292
+ static struct llama_sampler_i llama_sampler_mirostat_i = {
1293
+ /* .name = */ llama_sampler_mirostat_name,
1294
+ /* .accept = */ nullptr,
1295
+ /* .apply = */ llama_sampler_mirostat_apply,
1296
+ /* .reset = */ llama_sampler_mirostat_reset,
1297
+ /* .clone = */ llama_sampler_mirostat_clone,
1298
+ /* .free = */ llama_sampler_mirostat_free,
1299
+ };
1300
+
1301
+ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1302
+ auto seed_cur = get_rng_seed(seed);
1303
+ return llama_sampler_init(
1304
+ /* .iface = */ &llama_sampler_mirostat_i,
1305
+ /* .ctx = */ new llama_sampler_mirostat {
1306
+ /* .n_vocab = */ n_vocab,
1307
+ /* .seed = */ seed,
1308
+ /* .seed_cur = */ seed_cur,
1309
+ /* .tau = */ tau,
1310
+ /* .eta = */ eta,
1311
+ /* .m = */ m,
1312
+ /* .mu = */ 2.0f*tau,
1313
+ /* .rng = */ std::mt19937(seed_cur),
1314
+ }
1315
+ );
1316
+ }
1317
+
1318
+ // mirostat v2
1319
+
1320
+ struct llama_sampler_mirostat_v2 {
1321
+ const uint32_t seed;
1322
+ uint32_t seed_cur;
1323
+
1324
+ const float tau;
1325
+ const float eta;
1326
+
1327
+ float mu;
1328
+
1329
+ std::mt19937 rng;
1330
+ };
1331
+
1332
+ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
1333
+ return "mirostat-v2";
1334
+ }
1335
+
1336
+ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1337
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1338
+
1339
+ llama_sampler_softmax_impl(cur_p);
1340
+
1341
+ // Truncate the words with surprise values greater than mu
1342
+ cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
1343
+ return -log2f(candidate.p) > ctx->mu;
1344
+ }));
1345
+
1346
+ if (cur_p->size == 0) {
1347
+ cur_p->size = 1;
1348
+ }
1349
+
1350
+ // Normalize the probabilities of the remaining words
1351
+ llama_sampler_softmax_impl(cur_p);
1352
+
1353
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1354
+
1355
+ cur_p->selected = idx;
1356
+
1357
+ float observed_surprise = -log2f(cur_p->data[idx].p);
1358
+ float e = observed_surprise - ctx->tau;
1359
+
1360
+ // Update mu using the learning rate and error
1361
+ ctx->mu = ctx->mu - ctx->eta * e;
1362
+ }
1363
+
1364
+ static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
1365
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1366
+ ctx->mu = 2.0f*ctx->tau;
1367
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1368
+ ctx->rng.seed(ctx->seed_cur);
1369
+ }
1370
+
1371
+ static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
1372
+ const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
1373
+
1374
+ auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
1375
+
1376
+ // copy the state
1377
+ {
1378
+ auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
1379
+
1380
+ result_ctx->mu = ctx->mu;
1381
+ result_ctx->rng = ctx->rng;
1382
+ }
1383
+
1384
+ return result;
1385
+ }
1386
+
1387
+ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1388
+ delete (llama_sampler_mirostat_v2 *) smpl->ctx;
1389
+ }
1390
+
1391
+ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1392
+ /* .name = */ llama_sampler_mirostat_v2_name,
1393
+ /* .accept = */ nullptr,
1394
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
1395
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
1396
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
1397
+ /* .free = */ llama_sampler_mirostat_v2_free,
1398
+ };
1399
+
1400
+ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1401
+ auto seed_cur = get_rng_seed(seed);
1402
+ return llama_sampler_init(
1403
+ /* .iface = */ &llama_sampler_mirostat_v2_i,
1404
+ /* .ctx = */ new llama_sampler_mirostat_v2 {
1405
+ /* .seed = */ seed,
1406
+ /* .seed_cur = */ seed_cur,
1407
+ /* .tau = */ tau,
1408
+ /* .eta = */ eta,
1409
+ /* .mu = */ 2.0f*tau,
1410
+ /* .rng = */ std::mt19937(seed_cur),
1411
+ }
1412
+ );
1413
+ }
1414
+
1415
+ // grammar
1416
+
1417
+ struct llama_sampler_grammar {
1418
+ const struct llama_vocab * vocab;
1419
+
1420
+ std::string grammar_str;
1421
+ std::string grammar_root;
1422
+
1423
+ struct llama_grammar * grammar;
1424
+ };
1425
+
1426
+ static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
1427
+ return "grammar";
1428
+ }
1429
+
1430
+ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
1431
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1432
+ if (ctx->grammar) {
1433
+ llama_grammar_accept_impl(*ctx->grammar, token);
1434
+ }
1435
+ }
1436
+
1437
+ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1438
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1439
+ if (ctx->grammar) {
1440
+ llama_grammar_apply_impl(*ctx->grammar, cur_p);
1441
+ }
1442
+ }
1443
+
1444
+ // Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
1445
+ static struct llama_sampler * llama_sampler_init_grammar_impl(
1446
+ const struct llama_vocab * vocab,
1447
+ const char * grammar_str,
1448
+ const char * grammar_root,
1449
+ bool lazy,
1450
+ const char ** trigger_words,
1451
+ size_t num_trigger_words,
1452
+ const llama_token * trigger_tokens,
1453
+ size_t num_trigger_tokens,
1454
+ const char ** trigger_patterns,
1455
+ size_t num_trigger_patterns);
1456
+
1457
+ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
1458
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1459
+ if (!ctx->grammar) {
1460
+ return;
1461
+ }
1462
+
1463
+ std::vector<const char *> trigger_patterns_c;
1464
+ trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
1465
+ for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
1466
+ trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
1467
+ }
1468
+
1469
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
1470
+ ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
1471
+ ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
1472
+
1473
+ llama_grammar_free_impl(ctx->grammar);
1474
+ ctx->grammar = grammar_new;
1475
+ }
1476
+
1477
+ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
1478
+ const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
1479
+
1480
+ auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
1481
+
1482
+ // copy the state
1483
+ {
1484
+ auto * result_ctx = (llama_sampler_grammar *) result->ctx;
1485
+
1486
+ if (ctx->grammar) {
1487
+ result_ctx->grammar_str = ctx->grammar_str;
1488
+ result_ctx->grammar_root = ctx->grammar_root;
1489
+
1490
+ result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
1491
+ }
1492
+ }
1493
+
1494
+ return result;
1495
+ }
1496
+
1497
+ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1498
+ const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1499
+
1500
+ if (ctx->grammar) {
1501
+ llama_grammar_free_impl(ctx->grammar);
1502
+ }
1503
+
1504
+ delete ctx;
1505
+ }
1506
+
1507
+ static struct llama_sampler_i llama_sampler_grammar_i = {
1508
+ /* .name = */ llama_sampler_grammar_name,
1509
+ /* .accept = */ llama_sampler_grammar_accept_impl,
1510
+ /* .apply = */ llama_sampler_grammar_apply,
1511
+ /* .reset = */ llama_sampler_grammar_reset,
1512
+ /* .clone = */ llama_sampler_grammar_clone,
1513
+ /* .free = */ llama_sampler_grammar_free,
1514
+ };
1515
+
1516
+ static struct llama_sampler * llama_sampler_init_grammar_impl(
1517
+ const struct llama_vocab * vocab,
1518
+ const char * grammar_str,
1519
+ const char * grammar_root,
1520
+ bool lazy,
1521
+ const char ** trigger_words,
1522
+ size_t num_trigger_words,
1523
+ const llama_token * trigger_tokens,
1524
+ size_t num_trigger_tokens,
1525
+ const char ** trigger_patterns,
1526
+ size_t num_trigger_patterns) {
1527
+ auto * ctx = new llama_sampler_grammar;
1528
+
1529
+ if (grammar_str != nullptr && grammar_str[0] != '\0') {
1530
+ // TODO: remove trigger_words support.
1531
+ if (trigger_words != nullptr && num_trigger_words > 0) {
1532
+ LM_GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
1533
+ std::string trigger_pattern("[\\s\\S]*?(");
1534
+ for (size_t i = 0; i < num_trigger_words; ++i) {
1535
+ static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
1536
+ if (i > 0) {
1537
+ trigger_pattern += "|";
1538
+ }
1539
+ trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
1540
+ }
1541
+ trigger_pattern += ")[\\s\\S]*";
1542
+ auto trigger_pattern_c = trigger_pattern.c_str();
1543
+ trigger_patterns = &trigger_pattern_c;
1544
+ num_trigger_patterns = 1;
1545
+ }
1546
+ *ctx = {
1547
+ /* .vocab = */ vocab,
1548
+ /* .grammar_str = */ grammar_str,
1549
+ /* .grammar_root = */ grammar_root,
1550
+ /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
1551
+ };
1552
+ } else {
1553
+ *ctx = {
1554
+ /* .vocab = */ vocab,
1555
+ /* .grammar_str = */ {},
1556
+ /* .grammar_root = */ {},
1557
+ /* .grammar = */ nullptr,
1558
+ };
1559
+ }
1560
+
1561
+ return llama_sampler_init(
1562
+ /* .iface = */ &llama_sampler_grammar_i,
1563
+ /* .ctx = */ ctx
1564
+ );
1565
+ }
1566
+
1567
+ struct llama_sampler * llama_sampler_init_grammar(
1568
+ const struct llama_vocab * vocab,
1569
+ const char * grammar_str,
1570
+ const char * grammar_root) {
1571
+ return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
1572
+ }
1573
+
1574
+ struct llama_sampler * llama_sampler_init_grammar_lazy(
1575
+ const struct llama_vocab * vocab,
1576
+ const char * grammar_str,
1577
+ const char * grammar_root,
1578
+ const char ** trigger_words,
1579
+ size_t num_trigger_words,
1580
+ const llama_token * trigger_tokens,
1581
+ size_t num_trigger_tokens) {
1582
+ return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
1583
+ }
1584
+
1585
+ struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
1586
+ const struct llama_vocab * vocab,
1587
+ const char * grammar_str,
1588
+ const char * grammar_root,
1589
+ const char ** trigger_patterns,
1590
+ size_t num_trigger_patterns,
1591
+ const llama_token * trigger_tokens,
1592
+ size_t num_trigger_tokens) {
1593
+ return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
1594
+ }
1595
+
1596
+ // penalties
1597
+
1598
+ struct llama_sampler_penalties {
1599
+ const int32_t penalty_last_n;
1600
+ const float penalty_repeat;
1601
+ const float penalty_freq;
1602
+ const float penalty_present;
1603
+
1604
+ ring_buffer<llama_token> prev;
1605
+
1606
+ // a frequency map to count token occurrences
1607
+ std::unordered_map<llama_token, int> token_count;
1608
+ };
1609
+
1610
+ static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
1611
+ return "penalties";
1612
+ }
1613
+
1614
+ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
1615
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1616
+ if (ctx->penalty_last_n == 0) {
1617
+ return;
1618
+ }
1619
+
1620
+ ctx->token_count[token]++;
1621
+
1622
+ // if the ring buffer is full, remove the oldest token
1623
+ if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
1624
+ const auto old = ctx->prev.front();
1625
+
1626
+ ctx->token_count[old]--;
1627
+ if (ctx->token_count[old] == 0) {
1628
+ ctx->token_count.erase(old);
1629
+ }
1630
+ }
1631
+
1632
+ ctx->prev.push_back(token);
1633
+
1634
+ #if 0
1635
+ // sanity check
1636
+ std::unordered_map<llama_token, int> tmp;
1637
+ for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1638
+ tmp[ctx->prev.rat(i)]++;
1639
+ }
1640
+
1641
+ assert(ctx->token_count == tmp);
1642
+ #endif
1643
+ }
1644
+
1645
+ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1646
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1647
+
1648
+ if ((ctx->penalty_last_n == 0) ||
1649
+ (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1650
+ return;
1651
+ }
1652
+
1653
+ // Apply frequency and presence penalties to the cur_p
1654
+ for (size_t i = 0; i < cur_p->size; ++i) {
1655
+ const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
1656
+ if (token_iter == ctx->token_count.end()) {
1657
+ continue;
1658
+ }
1659
+
1660
+ const int count = token_iter->second;
1661
+
1662
+ assert(count > 0 && count <= ctx->penalty_last_n);
1663
+
1664
+ // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
1665
+ // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
1666
+ if (cur_p->data[i].logit <= 0) {
1667
+ cur_p->data[i].logit *= ctx->penalty_repeat;
1668
+ } else {
1669
+ cur_p->data[i].logit /= ctx->penalty_repeat;
1670
+ }
1671
+
1672
+ cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
1673
+ }
1674
+
1675
+ cur_p->sorted = false;
1676
+ }
1677
+
1678
+ static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
1679
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1680
+ ctx->prev.clear();
1681
+ ctx->token_count.clear();
1682
+ }
1683
+
1684
+ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
1685
+ const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
1686
+ auto * result = llama_sampler_init_penalties(
1687
+ ctx->penalty_last_n,
1688
+ ctx->penalty_repeat,
1689
+ ctx->penalty_freq,
1690
+ ctx->penalty_present);
1691
+
1692
+ // copy the state
1693
+ {
1694
+ auto * result_ctx = (llama_sampler_penalties *) result->ctx;
1695
+
1696
+ result_ctx->prev = ctx->prev;
1697
+ }
1698
+
1699
+ return result;
1700
+ }
1701
+
1702
+ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1703
+ delete (llama_sampler_penalties *) smpl->ctx;
1704
+ }
1705
+
1706
+ static struct llama_sampler_i llama_sampler_penalties_i = {
1707
+ /* .name = */ llama_sampler_penalties_name,
1708
+ /* .accept = */ llama_sampler_penalties_accept,
1709
+ /* .apply = */ llama_sampler_penalties_apply,
1710
+ /* .reset = */ llama_sampler_penalties_reset,
1711
+ /* .clone = */ llama_sampler_penalties_clone,
1712
+ /* .free = */ llama_sampler_penalties_free,
1713
+ };
1714
+
1715
+ struct llama_sampler * llama_sampler_init_penalties(
1716
+ int32_t penalty_last_n,
1717
+ float penalty_repeat,
1718
+ float penalty_freq,
1719
+ float penalty_present) {
1720
+ penalty_last_n = std::max(penalty_last_n, 0);
1721
+
1722
+ return llama_sampler_init(
1723
+ /* .iface = */ &llama_sampler_penalties_i,
1724
+ /* .ctx = */ new llama_sampler_penalties {
1725
+ /* .penalty_last_n = */ penalty_last_n,
1726
+ /* .penalty_repeat = */ penalty_repeat,
1727
+ /* .penalty_freq = */ penalty_freq,
1728
+ /* .penalty_present = */ penalty_present,
1729
+ /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1730
+ /* .token_count = */ {},
1731
+ }
1732
+ );
1733
+ }
1734
+
1735
+ // top-n-sigma
1736
+
1737
+ struct llama_sampler_top_n_sigma {
1738
+ const float n;
1739
+ };
1740
+
1741
+ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
1742
+ return "top-n-sigma";
1743
+ }
1744
+
1745
+ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1746
+ const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1747
+
1748
+ // find max logit and calculate mean
1749
+ float max = cur_p->data[0].logit;
1750
+ float logits_sum = 0;
1751
+ for (size_t i = 0; i < cur_p->size; ++i) {
1752
+ if (cur_p->data[i].logit > max) {
1753
+ max = cur_p->data[i].logit;
1754
+ }
1755
+ logits_sum += cur_p->data[i].logit;
1756
+ }
1757
+ float mean = logits_sum/cur_p->size;
1758
+
1759
+ // calculate standard deviation
1760
+ float acc = 0;
1761
+ for (size_t i = 0; i < cur_p->size; ++i) {
1762
+ acc += pow(cur_p->data[i].logit - mean, 2);
1763
+ }
1764
+ float std = sqrt(acc/cur_p->size);
1765
+
1766
+ //apply mask
1767
+ for (size_t i = 0; i < cur_p->size; ++i) {
1768
+ if (cur_p->data[i].logit < max - (ctx->n * std)) {
1769
+ cur_p->data[i].logit = -INFINITY;
1770
+ }
1771
+ }
1772
+ llama_sampler_softmax_impl(cur_p);
1773
+ }
1774
+
1775
+ static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
1776
+ const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
1777
+ return llama_sampler_init_top_n_sigma(ctx->n);
1778
+ }
1779
+
1780
+ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1781
+ delete (llama_sampler_top_n_sigma *) smpl->ctx;
1782
+ }
1783
+
1784
// Interface table for the top-n-sigma sampler (no accept/reset state).
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
    /* .name   = */ llama_sampler_top_n_sigma_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_top_n_sigma_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_top_n_sigma_clone,
    /* .free   = */ llama_sampler_top_n_sigma_free,
};
1792
+
1793
+ struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
1794
+ return llama_sampler_init(
1795
+ /* .iface = */ &llama_sampler_top_n_sigma_i,
1796
+ /* .ctx = */ new llama_sampler_top_n_sigma {
1797
+ /* .n = */ n,
1798
+ }
1799
+ );
1800
+ }
1801
+
1802
+ // DRY
1803
+
1804
// State for the DRY ("Don't Repeat Yourself") repetition penalty sampler.
struct llama_sampler_dry {
    int32_t total_context_size; // context length used to clamp the penalty window

    const float dry_multiplier;       // penalty scale; 0.0f disables DRY
    const float dry_base;             // exponential base; values < 1.0f disable DRY
    const int32_t dry_allowed_length; // repeats up to this length are not penalized
    const int32_t dry_penalty_last_n; // window of recent tokens to scan (-1 = whole context)

    // head token -> tail sequences of the processed restart ("breaker") strings
    std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
    std::vector<int> dry_repeat_count;                         // Z-algorithm scratch, one slot per window token
    std::unordered_map<llama_token, int> dry_max_token_repeat; // longest repeat each candidate token would extend
    ring_buffer<llama_token> last_tokens;                      // recently accepted tokens
};
1817
+
1818
+ // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1819
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
//
// For every vocab token whose text overlaps the breaker string `str`, record a
// (head token -> tail token sequence) entry in `token_sequences`. A token that
// contains `str` outright gets an empty tail; a token whose suffix matches a
// prefix of `str` gets the tokenization of the remaining part of `str` as its
// tail (truncated to `max_tail_len` tokens when >= 0).
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
        std::string word = vocab.detokenize({token_id}, true);
        if (word.find(str) != std::string::npos) {
            // whole breaker inside this token -> empty tail sequence
            token_sequences.emplace(token_id, std::vector<llama_token>());
        } else {
            size_t word_len = word.size();
            size_t str_len = str.size();
            // NOTE: pos starts at SIZE_MAX on purpose so that `pos + 1` wraps to 0
            // on the first find() call below.
            size_t pos = -1;
            while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                // check whether word[pos..] matches a prefix of str
                bool match = true;
                size_t i;
                for (i = 1; i < str_len && i + pos < word_len; ++i) {
                    if (word[pos + i] != str[i]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    // tokenize the part of the breaker that extends past this token
                    std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
                    if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                        tokenization.resize(max_tail_len);
                    }

                    // Ensure we don't already have a duplicate matching tokenization
                    auto its = token_sequences.equal_range(token_id);
                    bool found = false;
                    for (auto it = its.first; it != its.second; ++it) {
                        if (tokenization == it->second) {
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        token_sequences.emplace(token_id, tokenization);
                    }
                }
            }
        }
    }
}
1860
+
1861
// Human-readable identifier reported through the sampler interface.
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
    static const char * const k_name = "dry";
    return k_name;
}
1864
+
1865
+ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
1866
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1867
+ if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1868
+ return;
1869
+ }
1870
+
1871
+ ctx->last_tokens.push_back(token);
1872
+ }
1873
+
1874
+ // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1875
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
//
// Apply the DRY penalty: find the longest repeated suffix of the recent-token
// window that each candidate token would extend, and subtract
// multiplier * base^(repeat_len - allowed_length) from that candidate's logit.
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dry *) smpl->ctx;

    // disabled configuration -> no-op (same check as in accept)
    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
        return;
    }

    // size of the scan window, clamped by history length and context size
    int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
    int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);

    if (last_n_repeat <= ctx->dry_allowed_length) {
        return;
    }

    ctx->dry_repeat_count.assign(last_n_repeat, 0);
    ctx->dry_max_token_repeat.clear();

    // Step 1: Look for restart sequences to limit the maximum repetition length.
    // Work backwards through the context looking for any token that begins a restart sequence.
    //
    // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
    // sequences that together comprise a restart sequence. This allows us to quickly check
    // whether each token is the head of a complete sequence. Most restart sequences are actually
    // a single token, and for these the "tail" is an empty vector.
    //
    // If the token is a "head", test all restart sequences that begin with this token
    // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
    // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
    // longest matching sequence (if any) is used to limit the maximum repetition length.
    //
    // Note that in the case case of a short sequence contained in a longer one, this might fail to
    // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
    // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
    // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
    //
    // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
    // have already clamped the maximum tail sequence length when generating `restart_sequences`.
    // With clamping, this scan is O(N) in the context length.

    int rep_limit = last_n_repeat;
    for (int i = 0; i < last_n_repeat; ++i) {
        // rat(i) indexes from the most recent token backwards
        llama_token token = ctx->last_tokens.rat(i);
        auto its = ctx->dry_processed_breakers.equal_range(token);
        if (its.first == ctx->dry_processed_breakers.end()) {
            continue;
        }
        int longest_match = -1;
        for (auto it = its.first; it != its.second; ++it) {
            // Note that (*it) does not contain the head character, so seq_len will be
            // the restart sequence length minus 1.
            // In the common case of a single-token restart sequence, (*it) will be empty
            // and we will trivially match.
            int seq_len = (int)it->second.size();
            if (seq_len > longest_match && seq_len <= (int)i) {
                bool match = true;
                for (int offset = 0; offset < seq_len; ++offset) {
                    // The -1 when indexing `last_tokens` is because we already matched the head.
                    if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    longest_match = seq_len;
                }
            }
        }
        if (longest_match >= 0) {
            // We found a restart sequence starting `i` tokens from the end and continuing for
            // `longest_match` tokens.
            rep_limit = i - longest_match;
            break;
        }
    }
    if (rep_limit < ctx->dry_allowed_length) {
        return;
    }

    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
    //
    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
    // https://ivanyu.me/blog/2014/10/15/z-algorithm/
    //
    // The code below is adapted from the public domain implementation by the same author here:
    // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
    //
    // Example:
    // Last N tokens: a b c c b c y a b c
    // Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //                    ^
    //   This `3` means that the last three tokens of the context (a b c) also appear here.
    //
    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
    // ensure that the inner while loops only examine each token in the context once as the outer
    // for loop iterates over the context.

    {
        const int last = last_n_repeat - 1;
        int rt = 0, lt = 0;

        for (int k = 1; k < last_n_repeat; ++k) {
            if (k > rt) {
                // If k is outside the current Z-box, do naive computation.
                int n = 0;
                while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
                    ++n;
                }
                ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                if (n > 0) {
                    lt = k;
                    rt = k + n - 1;
                }
            } else {
                // If k is inside the current Z-box, consider two cases.

                int p = k - lt; // Pair index.
                int right_part_len = rt - k + 1;

                if (ctx->dry_repeat_count[last - p] < right_part_len) {
                    int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                } else {
                    int i = rt + 1;
                    while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
                        i += 1;
                    }

                    int n = std::min(i - k, rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                    lt = k;
                    rt = i - 1;
                }
            }
        }
    }

    // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
    // that would be generated by emitting each new token that would extend a sequence.
    //
    // Following the same example as above:
    // Last N tokens: a b c c b c y a b c
    // Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //
    // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
    // c: 3 -> 4 (from `a b c` to `a b c c`)
    // b: 1 -> 2 (from `c` to `c b`)
    // y: 2 -> 3 (from `b c` to `b c y`)

    for (int i = 0; i < last_n_repeat - 1; ++i) {
        int repeat_len = ctx->dry_repeat_count[i];
        if (repeat_len >= ctx->dry_allowed_length) {
            // This token ends a repeat, so the next token would continue one.
            // By convention, the value of `repeat_len` only includes the tokens currently
            // in the context, not the new token that would be added.
            llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
            // Track the maximum sequence ending in this token.
            const auto& it = ctx->dry_max_token_repeat.find(token);
            if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
                ctx->dry_max_token_repeat[token] = repeat_len;
            }
        }
    }

    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.

    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
    // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
    const float FLOAT_MAX_LOG = 88.7228391f;
    int max_exponent = 0;
    if (ctx->dry_base > 1.000001f) {
        max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
        if (af_kvp != ctx->dry_max_token_repeat.end()) {
            // Check all sequence breakers starting with this token
            auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
            bool is_single_token_breaker = false;

            for (auto it = range.first; it != range.second; ++it) {
                if (it->second.empty()) {
                    is_single_token_breaker = true;
                    break;
                }
            }

            // Apply penalty only if it's not a single-token sequence breaker
            if (!is_single_token_breaker) {
                int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
                if (max_exponent > 0 && repeat_exp > max_exponent) {
                    repeat_exp = max_exponent;
                }
                float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
                cur_p->data[i].logit -= penalty;
            }
        }
    }

    // logits were modified in place, so any prior sort order is invalid
    cur_p->sorted = false;
}
2080
+
2081
// Clear all accumulated repetition state; configuration fields are untouched.
static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_dry *) smpl->ctx;
    ctx->last_tokens.clear();
    ctx->dry_repeat_count.clear();
    ctx->dry_max_token_repeat.clear();
}
2087
+
2088
// Deep-copy the DRY sampler: build a fresh instance with the same configuration,
// then copy over the processed breakers and accumulated history.
static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_dry *) smpl->ctx;

    llama_vocab dummy_vocab;

    // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
    auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);

    // Copy the state, including the processed breakers
    {
        auto * result_ctx = (llama_sampler_dry *) result->ctx;
        result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
        result_ctx->dry_repeat_count = ctx->dry_repeat_count;
        result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
        result_ctx->last_tokens = ctx->last_tokens;
    }

    return result;
}
2107
+
2108
+ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
2109
+ delete (llama_sampler_dry *) smpl->ctx;
2110
+ }
2111
+
2112
// Interface table for the DRY sampler (stateful: has accept/reset).
static struct llama_sampler_i llama_sampler_dry_i = {
    /* .name   = */ llama_sampler_dry_name,
    /* .accept = */ llama_sampler_dry_accept,
    /* .apply  = */ llama_sampler_dry_apply,
    /* .reset  = */ llama_sampler_dry_reset,
    /* .clone  = */ llama_sampler_dry_clone,
    /* .free   = */ llama_sampler_dry_free,
};
2120
+
2121
// Create a DRY sampler. Breaker strings are validated, truncated to
// MAX_CHAR_LEN characters, and expanded (via get_overlapping_token_sequences)
// into head-token -> tail-sequence entries with tails clamped to MAX_SEQ_LEN.
// When the configuration disables DRY, breaker processing and buffer
// allocation are skipped entirely.
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
    std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
    const int MAX_CHAR_LEN = 40; // max breaker length in characters
    const int MAX_SEQ_LEN = 20;  // max breaker tail length in tokens

    // same disabled-check used by accept/apply
    const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);

    if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
        // Process sequence breakers
        for (size_t i = 0; i < num_breakers; ++i) {
            if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
                LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
                continue;
            }

            std::string sequence_break(seq_breakers[i]);
            if (sequence_break.empty()) {
                LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
                continue;
            }

            if (sequence_break.size() > MAX_CHAR_LEN) {
                LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
                sequence_break.resize(MAX_CHAR_LEN);
            }

            get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
        }
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_dry_i,
        /* .ctx   = */ new llama_sampler_dry {
            /* .total_context_size     = */ context_size,
            /* .dry_multiplier         = */ dry_multiplier,
            /* .dry_base               = */ dry_base,
            /* .dry_allowed_length     = */ dry_allowed_length,
            /* .dry_penalty_last_n     = */ dry_penalty_last_n,
            /* .dry_processed_breakers = */ std::move(processed_breakers),
            /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
            /* .dry_max_token_repeat   = */ {},
            /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
        }
    );
}
2167
+
2168
+ // wrapper for test-sampling.cpp
2169
// wrapper for test-sampling.cpp
//
// Build a DRY sampler from pre-tokenized sequence breakers, bypassing the
// string/vocab processing path used by llama_sampler_init_dry.
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
    llama_vocab dummy_vocab;
    auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
    auto * ctx = (llama_sampler_dry *) result->ctx;

    // Process the token-based sequence breakers
    ctx->dry_processed_breakers.clear();
    if (seq_breakers.empty()) {
        LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
    } else {
        for (const auto& breaker : seq_breakers) {
            if (breaker.empty()) {
                LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
                continue;
            }
            // first token is the head, the rest form the tail sequence
            llama_token head_token = breaker[0];
            std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
            ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
        }

        if (ctx->dry_processed_breakers.empty()) {
            LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
        }
    }

    return result;
}
2196
+
2197
+ // logit-bias
2198
+
2199
// State for the logit-bias sampler: adds fixed offsets to selected token logits.
struct llama_sampler_logit_bias {
    const int32_t n_vocab; // vocabulary size (stored for cloning)

    const std::vector<llama_logit_bias> logit_bias; // configured (token, bias) pairs

    std::vector<llama_logit_bias> to_search; // scratch: biases not applied by direct indexing
};
2206
+
2207
// Human-readable identifier reported through the sampler interface.
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
    static const char * const k_name = "logit-bias";
    return k_name;
}
2210
+
2211
+ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2212
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
2213
+
2214
+ if (ctx->logit_bias.empty()) {
2215
+ return;
2216
+ }
2217
+
2218
+ ctx->to_search.clear();
2219
+
2220
+ // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
2221
+ for (const auto & lb : ctx->logit_bias) {
2222
+ if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
2223
+ cur_p->data[lb.token].logit += lb.bias;
2224
+ } else {
2225
+ ctx->to_search.push_back(lb);
2226
+ }
2227
+ }
2228
+
2229
+ if (ctx->to_search.empty()) {
2230
+ return;
2231
+ }
2232
+
2233
+ // search for the remaining candidates that were not found in the previous step
2234
+ for (size_t i = 0; i < cur_p->size; ++i) {
2235
+ for (const auto & lb : ctx->to_search) {
2236
+ if (cur_p->data[i].id == lb.token) {
2237
+ cur_p->data[i].logit += lb.bias;
2238
+ break;
2239
+ }
2240
+ }
2241
+ }
2242
+ }
2243
+
2244
+ static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
2245
+ const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
2246
+ return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
2247
+ }
2248
+
2249
+ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
2250
+ delete (llama_sampler_logit_bias *) smpl->ctx;
2251
+ }
2252
+
2253
// Interface table for the logit-bias sampler (no accept/reset state).
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    /* .name   = */ llama_sampler_logit_bias_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_logit_bias_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_logit_bias_clone,
    /* .free   = */ llama_sampler_logit_bias_free,
};
2261
+
2262
+ struct llama_sampler * llama_sampler_init_logit_bias(
2263
+ int32_t n_vocab,
2264
+ int32_t n_logit_bias,
2265
+ const llama_logit_bias * logit_bias) {
2266
+ return llama_sampler_init(
2267
+ /* .iface = */ &llama_sampler_logit_bias_i,
2268
+ /* .ctx = */ new llama_sampler_logit_bias {
2269
+ /* .n_vocab = */ n_vocab,
2270
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2271
+ /* .to_search = */ {},
2272
+ }
2273
+ );
2274
+ }
2275
+
2276
+ // infill
2277
+
2278
+ //#define LM_GGML_DEBUG_SAMPLER_INFILL
2279
+
2280
// State for the infill (fill-in-the-middle) sampler.
struct llama_sampler_infill {
    const struct llama_vocab * vocab; // not owned; used for EOG checks and detokenization

    // scratch buffers for token_to_piece when comparing token text prefixes
    std::vector<char> buf0;
    std::vector<char> buf1;
};
2286
+
2287
// Human-readable identifier reported through the sampler interface.
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    static const char * const k_name = "infill";
    return k_name;
}
2290
+
2291
+ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2292
+ auto * ctx = (llama_sampler_infill *) smpl->ctx;
2293
+
2294
+ llama_sampler_softmax_impl(cur_p);
2295
+
2296
+ #if defined(LM_GGML_DEBUG_SAMPLER_INFILL)
2297
+ #define LOG_DBG_CUR LLAMA_LOG_DEBUG
2298
+ #else
2299
+ #define LOG_DBG_CUR(...)
2300
+ #endif
2301
+
2302
+ for (size_t i = 0; i < cur_p->size; ++i) {
2303
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2304
+ }
2305
+
2306
+ float p_txt_sum = 0.0f;
2307
+ float p_eog_sum = 0.0f;
2308
+
2309
+ for (size_t i = 0; i < cur_p->size; ++i) {
2310
+ if (ctx->vocab->is_eog(cur_p->data[i].id)) {
2311
+ p_eog_sum += cur_p->data[i].p;
2312
+ } else {
2313
+ p_txt_sum += cur_p->data[i].p;
2314
+ }
2315
+ }
2316
+
2317
+ const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; LM_GGML_UNUSED(rat);
2318
+
2319
+ LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
2320
+
2321
+ if (3*p_eog_sum*cur_p->size > p_txt_sum) {
2322
+ LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
2323
+
2324
+ // keep just the EOG tokens
2325
+ const auto size_org = cur_p->size;
2326
+
2327
+ cur_p->size = 0;
2328
+
2329
+ float p_sum = 0.0f;
2330
+
2331
+ for (size_t i = 0; i < size_org; ++i) {
2332
+ if (ctx->vocab->is_eog(cur_p->data[i].id)) {
2333
+ p_sum += cur_p->data[i].p;
2334
+
2335
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2336
+ }
2337
+ }
2338
+
2339
+ // normalize probs
2340
+ for (size_t i = 0; i < cur_p->size; ++i) {
2341
+ cur_p->data[i].p /= p_sum;
2342
+ }
2343
+
2344
+ return;
2345
+ }
2346
+
2347
+ size_t n_combined = 0; LM_GGML_UNUSED(n_combined);
2348
+
2349
+ // combine tokens with common prefix
2350
+ for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
2351
+ for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
2352
+ if (cur_p->data[i0].logit == -INFINITY) {
2353
+ break;
2354
+ }
2355
+
2356
+ if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
2357
+ continue;
2358
+ }
2359
+
2360
+ int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2361
+ if (len0 < 0) {
2362
+ ctx->buf0.resize(len0);
2363
+ len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2364
+ assert(len0 > 0);
2365
+ }
2366
+
2367
+ int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2368
+ if (len1 < 0) {
2369
+ ctx->buf1.resize(len1);
2370
+ len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2371
+ assert(len1 > 0);
2372
+ }
2373
+
2374
+ // token i0 is a prefix of token i1
2375
+ if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
2376
+ int dst = i0;
2377
+ int src = i1;
2378
+
2379
+ // merge into the token with higher probability
2380
+ if (cur_p->data[i1].p > cur_p->data[i0].p) {
2381
+ std::swap(dst, src);
2382
+ }
2383
+
2384
+ cur_p->data[dst].p += cur_p->data[src].p;
2385
+ cur_p->data[src].logit = -INFINITY;
2386
+ cur_p->data[src].p = 0.0f;
2387
+
2388
+ n_combined++;
2389
+ }
2390
+ }
2391
+ }
2392
+
2393
+ size_t n_non_eog = 0;
2394
+
2395
+ size_t size_org = cur_p->size;
2396
+
2397
+ float p_sum = 0.0f;
2398
+ float thold = 0.2f;
2399
+
2400
+ cur_p->size = 0;
2401
+
2402
+ LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
2403
+
2404
+ for (size_t i = 0; i < size_org; ++i) {
2405
+ const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
2406
+
2407
+ if (cur_p->data[i].p < thold && !is_eog) {
2408
+ continue;
2409
+ }
2410
+
2411
+ if (!is_eog) {
2412
+ ++n_non_eog;
2413
+ }
2414
+
2415
+ p_sum += cur_p->data[i].p;
2416
+
2417
+ // keep this token
2418
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2419
+ }
2420
+
2421
+ LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
2422
+
2423
+ // if no non-EOG tokens are left -> reduce cur_p to single EOT token
2424
+ if (n_non_eog == 0) {
2425
+ cur_p->size = 1;
2426
+ cur_p->data[0].id = ctx->vocab->token_eot();
2427
+ cur_p->data[0].logit = 1.0f;
2428
+
2429
+ return;
2430
+ }
2431
+
2432
+ // normalize probs
2433
+ for (size_t i = 0; i < cur_p->size; ++i) {
2434
+ cur_p->data[i].p /= p_sum;
2435
+
2436
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2437
+ }
2438
+
2439
+ size_org = cur_p->size;
2440
+ p_sum = 0.0f;
2441
+ thold = 1.0/(n_non_eog + 1);
2442
+
2443
+ cur_p->size = 0;
2444
+
2445
+ LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
2446
+
2447
+ for (size_t i = 0; i < size_org; ++i) {
2448
+ const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
2449
+
2450
+ if (cur_p->data[i].p < thold && !is_eog) {
2451
+ continue;
2452
+ }
2453
+
2454
+ p_sum += cur_p->data[i].p;
2455
+
2456
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2457
+ }
2458
+
2459
+ // normalize probs
2460
+ for (size_t i = 0; i < cur_p->size; ++i) {
2461
+ cur_p->data[i].p /= p_sum;
2462
+
2463
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2464
+ }
2465
+
2466
+ #undef LOG_DBG_CUR
2467
+ }
2468
+
2469
+ static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
2470
+ const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
2471
+ return llama_sampler_init_infill(ctx->vocab);
2472
+ }
2473
+
2474
+ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2475
+ delete (llama_sampler_infill *) smpl->ctx;
2476
+ }
2477
+
2478
// Interface table for the infill sampler (no accept/reset state).
static struct llama_sampler_i llama_sampler_infill_i = {
    /* .name   = */ llama_sampler_infill_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_infill_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_infill_clone,
    /* .free   = */ llama_sampler_infill_free,
};
2486
+
2487
+ struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
2488
+ return llama_sampler_init(
2489
+ /* .iface = */ &llama_sampler_infill_i,
2490
+ /* .ctx = */ new llama_sampler_infill {
2491
+ /* .vocab = */ vocab,
2492
+ /* .buf0 = */ std::vector<char>(512),
2493
+ /* .buf1 = */ std::vector<char>(512),
2494
+ }
2495
+ );
2496
+ }
2497
+
2498
+ // utils
2499
+
2500
// Return the RNG seed of a seeded sampler, identified by comparing its iface
// pointer against the known seeded sampler interfaces. For a chain, search the
// samplers back-to-front and return the first non-default seed found.
// Returns LLAMA_DEFAULT_SEED when the sampler is not seeded.
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
    if (smpl->iface == &llama_sampler_dist_i) {
        return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
    }

    if (smpl->iface == &llama_sampler_mirostat_i) {
        return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
    }

    if (smpl->iface == &llama_sampler_mirostat_v2_i) {
        return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
    }

    if (smpl->iface == &llama_sampler_chain_i) {
        const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
        // last sampler in the chain wins, hence the reverse iteration
        for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
            const uint32_t seed = llama_sampler_get_seed(*it);
            if (seed != LLAMA_DEFAULT_SEED) {
                return seed;
            }
        }
    }

    return LLAMA_DEFAULT_SEED;
}
2525
+
2526
+ // perf
2527
+
2528
// Collect timing counters from a sampler chain. Aborts when given anything
// other than a chain created with llama_sampler_chain_init().
struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
    struct llama_perf_sampler_data data = {};

    if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
        LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
    }

    const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;

    // microseconds -> milliseconds
    data.t_sample_ms = 1e-3 * ctx->t_sample_us;
    data.n_sample    = std::max(0, ctx->n_sample);

    return data;
}
2542
+
2543
// Log a one-line timing summary for a sampler chain.
// NOTE(review): when data.n_sample is 0 the per-token divisions produce inf/nan
// in the log output — harmless, but worth confirming upstream intent.
void llama_perf_sampler_print(const struct llama_sampler * chain) {
    const auto data = llama_perf_sampler(chain);

    LLAMA_LOG_INFO("%s:    sampling time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
            __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
}
2549
+
2550
// Zero the timing counters of a sampler chain. Aborts when given anything
// other than a chain created with llama_sampler_chain_init().
void llama_perf_sampler_reset(struct llama_sampler * chain) {
    if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
        LM_GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
    }

    auto * ctx = (struct llama_sampler_chain *) chain->ctx;

    ctx->t_sample_us = ctx->n_sample = 0;
}