@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -1,12 +1,53 @@
1
1
  #include "llama-sampling.h"
2
2
 
3
+ #include "llama-vocab.h"
4
+ #include "llama-grammar.h"
5
+
3
6
  #include <algorithm>
7
+ #include <cassert>
8
+ #include <cfloat>
9
+ #include <chrono>
10
+ #include <cmath>
11
+ #include <cstdlib>
4
12
  #include <cstring>
5
13
  #include <ctime>
6
- #include <cfloat>
7
14
  #include <numeric>
15
+ #include <random>
8
16
  #include <unordered_map>
9
17
 
18
+ static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
19
+ // iterator for the probabilities
20
+ #ifdef __GNUC__
21
+ #pragma GCC diagnostic push
22
+ #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
23
+ #endif
24
+
25
+ struct probs_iterator {
26
+ typedef std::input_iterator_tag iterator_category;
27
+ typedef float value_type;
28
+ typedef float * pointer;
29
+ typedef float & reference;
30
+ typedef ptrdiff_t difference_type;
31
+
32
+ const llama_token_data * data;
33
+
34
+ bool operator==(const probs_iterator & other) const { return data == other.data; }
35
+ bool operator!=(const probs_iterator & other) const { return data != other.data; }
36
+ const float & operator*() const { return data->p; }
37
+ probs_iterator & operator++() { ++data; return *this; }
38
+ probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
39
+ };
40
+
41
+ #ifdef __GNUC__
42
+ #pragma GCC diagnostic pop
43
+ #endif
44
+
45
+ std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
46
+
47
+ return dist(rng);
48
+ }
49
+
50
+ /*
10
51
  static void llama_log_softmax(float * array, size_t size) {
11
52
  float max_l = *std::max_element(array, array + size);
12
53
  float sum = 0.f;
@@ -20,79 +61,89 @@ static void llama_log_softmax(float * array, size_t size) {
20
61
  array[i] = logf(array[i] / sum);
21
62
  }
22
63
  }
64
+ */
65
+
66
+ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
67
+ if (temp <= 0.0f) {
68
+ // find the token with the highest logit and set the rest to -inf
69
+ size_t max_i = 0;
70
+ float max_l = cur_p->data[0].logit;
71
+
72
+ for (size_t i = 1; i < cur_p->size; ++i) {
73
+ if (cur_p->data[i ].logit > max_l) {
74
+ cur_p->data[max_i].logit = -INFINITY;
75
+ max_i = i;
76
+ max_l = cur_p->data[i].logit;
77
+ } else {
78
+ cur_p->data[i].logit = -INFINITY;
79
+ }
80
+ }
23
81
 
24
- void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
25
- if (seed == LLAMA_DEFAULT_SEED) {
26
- seed = time(NULL);
82
+ return;
27
83
  }
28
84
 
29
- smpl->rng.seed(seed);
85
+ for (size_t i = 0; i < cur_p->size; ++i) {
86
+ cur_p->data[i].logit /= temp;
87
+ }
30
88
  }
31
89
 
32
- void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
33
- GGML_ASSERT(candidates->size > 0);
34
-
35
- const int64_t t_start_sample_us = ggml_time_us();
90
+ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
91
+ GGML_ASSERT(cur_p->size > 0);
36
92
 
37
93
  // Sort the logits in descending order
38
- if (!candidates->sorted) {
39
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
94
+ if (!cur_p->sorted) {
95
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
40
96
  return a.logit > b.logit;
41
97
  });
42
- candidates->sorted = true;
98
+ cur_p->sorted = true;
43
99
  }
44
100
 
45
- float max_l = candidates->data[0].logit;
101
+ float max_l = cur_p->data[0].logit;
46
102
  float cum_sum = 0.0f;
47
- for (size_t i = 0; i < candidates->size; ++i) {
48
- float p = expf(candidates->data[i].logit - max_l);
49
- candidates->data[i].p = p;
103
+
104
+ for (size_t i = 0; i < cur_p->size; ++i) {
105
+ float p = expf(cur_p->data[i].logit - max_l);
106
+ cur_p->data[i].p = p;
50
107
  cum_sum += p;
51
108
  }
52
- for (size_t i = 0; i < candidates->size; ++i) {
53
- candidates->data[i].p /= cum_sum;
54
- }
55
109
 
56
- if (smpl) {
57
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
110
+ for (size_t i = 0; i < cur_p->size; ++i) {
111
+ cur_p->data[i].p /= cum_sum;
58
112
  }
59
113
  }
60
114
 
61
- void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
62
- // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
63
- // if (k >= (int32_t)candidates->size) {
115
+ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
116
+ // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
117
+ // if (k >= (int32_t)cur_p->size) {
64
118
  // return;
65
119
  // }
66
120
 
67
- const int64_t t_start_sample_us = ggml_time_us();
68
-
69
121
  if (k <= 0) {
70
- k = candidates->size;
122
+ k = cur_p->size;
71
123
  }
72
124
 
73
- k = std::max(k, (int) min_keep);
74
- k = std::min(k, (int) candidates->size);
125
+ k = std::min(k, (int) cur_p->size);
75
126
 
76
127
  // Sort scores in descending order
77
- if (!candidates->sorted) {
128
+ if (!cur_p->sorted) {
78
129
  auto comp = [](const llama_token_data & a, const llama_token_data & b) {
79
130
  return a.logit > b.logit;
80
131
  };
81
132
  if (k <= 128) {
82
- std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
133
+ std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
83
134
  } else {
84
135
  constexpr int nbuckets = 128;
85
136
  constexpr float bucket_low = -10.0f;
86
137
  constexpr float bucket_high = 10.0f;
87
138
  constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
88
- constexpr float bucker_inter = -bucket_low * bucket_scale;
139
+ constexpr float bucket_inter = -bucket_low * bucket_scale;
89
140
 
90
- std::vector<int> bucket_idx(candidates->size);
141
+ std::vector<int> bucket_idx(cur_p->size);
91
142
  std::vector<int> histo(nbuckets, 0);
92
143
 
93
- for (int i = 0; i < (int)candidates->size; ++i) {
94
- const float val = candidates->data[i].logit;
95
- int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
144
+ for (int i = 0; i < (int)cur_p->size; ++i) {
145
+ const float val = cur_p->data[i].logit;
146
+ int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
96
147
  ib = std::max(0, std::min(nbuckets-1, ib));
97
148
  bucket_idx[i] = ib;
98
149
  ++histo[ib];
@@ -101,20 +152,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
101
152
  int ib = nbuckets - 1;
102
153
  for ( ; ib >= 0; --ib) {
103
154
  nhave += histo[ib];
104
- if (nhave >= k) break;
155
+ if (nhave >= k) {
156
+ break;
157
+ }
105
158
  }
106
159
  std::vector<llama_token_data> tmp_tokens(nhave);
107
- auto ptr = tmp_tokens.data();
160
+ auto * ptr = tmp_tokens.data();
108
161
  std::vector<llama_token_data*> bucket_ptrs;
109
162
  bucket_ptrs.reserve(nbuckets - ib);
110
163
  for (int j = nbuckets - 1; j >= ib; --j) {
111
164
  bucket_ptrs.push_back(ptr);
112
165
  ptr += histo[j];
113
166
  }
114
- for (int i = 0; i < (int)candidates->size; ++i) {
167
+ for (int i = 0; i < (int)cur_p->size; ++i) {
115
168
  int j = bucket_idx[i];
116
169
  if (j >= ib) {
117
- *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
170
+ *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
118
171
  }
119
172
  }
120
173
 
@@ -127,196 +180,596 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
127
180
  }
128
181
  std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
129
182
 
130
- std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
183
+ std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
184
+
185
+ }
186
+ cur_p->sorted = true;
187
+ }
188
+ cur_p->size = k;
189
+ }
131
190
 
191
+ static uint32_t get_rng_seed(uint32_t seed) {
192
+ if (seed == LLAMA_DEFAULT_SEED) {
193
+ // use system clock if std::random_device is not a true RNG
194
+ static bool is_rd_prng = std::random_device().entropy() == 0;
195
+ if (is_rd_prng) {
196
+ return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
132
197
  }
133
- candidates->sorted = true;
198
+ std::random_device rd;
199
+ return rd();
200
+ }
201
+ return seed;
202
+ }
203
+
204
+ // llama_sampler API
205
+
206
+ const char * llama_sampler_name(const struct llama_sampler * smpl) {
207
+ if (!smpl->iface) {
208
+ return "(null)";
134
209
  }
135
- candidates->size = k;
136
210
 
137
- if (smpl) {
138
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
211
+ return smpl->iface->name(smpl);
212
+ }
213
+
214
+ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
215
+ if (smpl->iface->accept) {
216
+ smpl->iface->accept(smpl, token);
217
+ }
218
+ }
219
+
220
+ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
221
+ GGML_ASSERT(smpl->iface->apply);
222
+ smpl->iface->apply(smpl, cur_p);
223
+ }
224
+
225
+ void llama_sampler_reset(struct llama_sampler * smpl) {
226
+ if (smpl->iface->reset) {
227
+ smpl->iface->reset(smpl);
228
+ }
229
+ }
230
+
231
+ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
232
+ if (smpl->iface->clone) {
233
+ return smpl->iface->clone(smpl);
234
+ }
235
+
236
+ if (smpl->ctx == nullptr) {
237
+ return new llama_sampler {
238
+ /* .iface = */ smpl->iface,
239
+ /* .ctx = */ nullptr,
240
+ };
139
241
  }
242
+
243
+ GGML_ABORT("the sampler does not support cloning");
140
244
  }
141
245
 
142
- void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
143
- if (p >= 1.0f) {
246
+ void llama_sampler_free(struct llama_sampler * smpl) {
247
+ if (smpl == nullptr) {
144
248
  return;
145
249
  }
146
250
 
147
- llama_sample_softmax_impl(smpl, candidates);
251
+ if (smpl->iface->free) {
252
+ smpl->iface->free(smpl);
253
+ }
254
+
255
+ delete smpl;
256
+ }
257
+
258
+ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
259
+ const auto * logits = llama_get_logits_ith(ctx, idx);
260
+
261
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
262
+
263
+ // TODO: do not allocate each time
264
+ std::vector<llama_token_data> cur;
265
+ cur.reserve(n_vocab);
266
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
267
+ cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
268
+ }
269
+
270
+ llama_token_data_array cur_p = {
271
+ /* .data = */ cur.data(),
272
+ /* .size = */ cur.size(),
273
+ /* .selected = */ -1,
274
+ /* .sorted = */ false,
275
+ };
276
+
277
+ llama_sampler_apply(smpl, &cur_p);
278
+
279
+ GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
280
+
281
+ auto token = cur_p.data[cur_p.selected].id;
282
+
283
+ llama_sampler_accept(smpl, token);
284
+
285
+ return token;
286
+ }
287
+
288
+ // sampler chain
289
+
290
+ static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
291
+ return "chain";
292
+ }
293
+
294
+ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
295
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
296
+
297
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
298
+
299
+ for (auto * smpl : chain->samplers) {
300
+ llama_sampler_accept(smpl, token);
301
+ }
302
+
303
+ chain->n_sample++;
304
+ }
305
+
306
+ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
307
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
308
+
309
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
310
+
311
+ for (auto * smpl : chain->samplers) {
312
+ llama_sampler_apply(smpl, cur_p);
313
+ }
314
+ }
315
+
316
+ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
317
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
318
+
319
+ for (auto * smpl : chain->samplers) {
320
+ llama_sampler_reset(smpl);
321
+ }
322
+
323
+ chain->t_sample_us = 0;
324
+ chain->n_sample = 0;
325
+ }
326
+
327
+ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
328
+ const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
329
+
330
+ auto * result = llama_sampler_chain_init(chain_src->params);
331
+
332
+ for (auto * smpl : chain_src->samplers) {
333
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl));
334
+ }
335
+
336
+ return result;
337
+ }
338
+
339
+ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
340
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
341
+
342
+ for (auto * smpl : chain->samplers) {
343
+ llama_sampler_free(smpl);
344
+ }
345
+
346
+ delete chain;
347
+ }
348
+
349
+ static struct llama_sampler_i llama_sampler_chain_i = {
350
+ /* .name = */ llama_sampler_chain_name,
351
+ /* .accept = */ llama_sampler_chain_accept,
352
+ /* .apply = */ llama_sampler_chain_apply,
353
+ /* .reset = */ llama_sampler_chain_reset,
354
+ /* .clone = */ llama_sampler_chain_clone,
355
+ /* .free = */ llama_sampler_chain_free,
356
+ };
357
+
358
+ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
359
+ return new llama_sampler {
360
+ /* .iface = */ &llama_sampler_chain_i,
361
+ /* .ctx = */ new llama_sampler_chain {
362
+ /* .params = */ params,
363
+ /* .samplers = */ {},
364
+ /* .t_sample_us = */ 0,
365
+ /* .n_sample = */ 0,
366
+ },
367
+ };
368
+ }
369
+
370
+ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
371
+ auto * p = (llama_sampler_chain *) chain->ctx;
372
+ p->samplers.push_back(smpl);
373
+ }
374
+
375
+ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
376
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
377
+
378
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
379
+ return nullptr;
380
+ }
381
+
382
+ return p->samplers[i];
383
+ }
384
+
385
+ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
386
+ auto * p = (llama_sampler_chain *) chain->ctx;
387
+
388
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
389
+ return nullptr;
390
+ }
391
+
392
+ auto * result = p->samplers[i];
393
+ p->samplers.erase(p->samplers.begin() + i);
394
+
395
+ return result;
396
+ }
397
+
398
+ int llama_sampler_chain_n(const struct llama_sampler * chain) {
399
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
400
+
401
+ return p->samplers.size();
402
+ }
403
+
404
+ //
405
+ // samplers
406
+ //
407
+
408
+ // greedy
409
+
410
+ static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
411
+ return "greedy";
412
+ }
413
+
414
+ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
415
+ cur_p->selected = 0;
416
+ for (size_t i = 1; i < cur_p->size; ++i) {
417
+ if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
418
+ cur_p->selected = i;
419
+ }
420
+ }
421
+ }
422
+
423
+ static struct llama_sampler_i llama_sampler_greedy_i = {
424
+ /* .name = */ llama_sampler_greedy_name,
425
+ /* .accept = */ nullptr,
426
+ /* .apply = */ llama_sampler_greedy_apply,
427
+ /* .reset = */ nullptr,
428
+ /* .clone = */ nullptr,
429
+ /* .free = */ nullptr,
430
+ };
431
+
432
+ struct llama_sampler * llama_sampler_init_greedy() {
433
+ return new llama_sampler {
434
+ /* .iface = */ &llama_sampler_greedy_i,
435
+ /* .ctx = */ nullptr,
436
+ };
437
+ }
438
+
439
+ // dist
440
+
441
+ struct llama_sampler_dist {
442
+ const uint32_t seed;
443
+ uint32_t seed_cur;
444
+
445
+ std::mt19937 rng;
446
+ };
447
+
448
+ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
449
+ return "dist";
450
+ }
451
+
452
+ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
453
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
454
+
455
+ llama_sampler_softmax_impl(cur_p);
456
+
457
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
458
+ }
459
+
460
+ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
461
+ const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
462
+ auto * result = llama_sampler_init_dist(ctx->seed);
463
+
464
+ // copy the state
465
+ {
466
+ auto * result_ctx = (llama_sampler_dist *) result->ctx;
467
+
468
+ result_ctx->rng = ctx->rng;
469
+ }
470
+
471
+ return result;
472
+ }
473
+
474
+ static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
475
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
476
+ ctx->seed_cur = get_rng_seed(ctx->seed);
477
+ ctx->rng.seed(ctx->seed_cur);
478
+ }
479
+
480
+ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
481
+ delete (llama_sampler_dist *) smpl->ctx;
482
+ }
483
+
484
+ static struct llama_sampler_i llama_sampler_dist_i = {
485
+ /* .name = */ llama_sampler_dist_name,
486
+ /* .accept = */ nullptr,
487
+ /* .apply = */ llama_sampler_dist_apply,
488
+ /* .reset = */ llama_sampler_dist_reset,
489
+ /* .clone = */ llama_sampler_dist_clone,
490
+ /* .free = */ llama_sampler_dist_free,
491
+ };
492
+
493
+ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
494
+ auto seed_cur = get_rng_seed(seed);
495
+ return new llama_sampler {
496
+ /* .iface = */ &llama_sampler_dist_i,
497
+ /* .ctx = */ new llama_sampler_dist {
498
+ /* .seed = */ seed,
499
+ /* .seed_cur = */ seed_cur,
500
+ /* .rng = */ std::mt19937(seed_cur),
501
+ },
502
+ };
503
+ }
504
+
505
+ // softmax
506
+
507
+ static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
508
+ return "softmax";
509
+ }
510
+
511
+ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
512
+ llama_sampler_softmax_impl(cur_p);
513
+ }
514
+
515
+ static struct llama_sampler_i llama_sampler_softmax_i = {
516
+ /* .name = */ llama_sampler_softmax_name,
517
+ /* .accept = */ nullptr,
518
+ /* .apply = */ llama_sampler_softmax_apply,
519
+ /* .reset = */ nullptr,
520
+ /* .clone = */ nullptr,
521
+ /* .free = */ nullptr,
522
+ };
523
+
524
+ struct llama_sampler * llama_sampler_init_softmax() {
525
+ return new llama_sampler {
526
+ /* .iface = */ &llama_sampler_softmax_i,
527
+ /* .ctx = */ nullptr,
528
+ };
529
+ }
530
+
531
+ // top-k
532
+
533
+ struct llama_sampler_top_k {
534
+ const int32_t k;
535
+ };
536
+
537
+ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
538
+ return "top-k";
539
+ }
540
+
541
+ static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
542
+ const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
543
+ llama_sampler_top_k_impl(cur_p, ctx->k);
544
+ }
545
+
546
+ static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
547
+ const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
548
+ return llama_sampler_init_top_k(ctx->k);
549
+ }
550
+
551
+ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
552
+ delete (llama_sampler_top_k *) smpl->ctx;
553
+ }
554
+
555
+ static struct llama_sampler_i llama_sampler_top_k_i = {
556
+ /* .name = */ llama_sampler_top_k_name,
557
+ /* .accept = */ nullptr,
558
+ /* .apply = */ llama_sampler_top_k_apply,
559
+ /* .reset = */ nullptr,
560
+ /* .clone = */ llama_sampler_top_k_clone,
561
+ /* .free = */ llama_sampler_top_k_free,
562
+ };
563
+
564
+ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
565
+ return new llama_sampler {
566
+ /* .iface = */ &llama_sampler_top_k_i,
567
+ /* .ctx = */ new llama_sampler_top_k {
568
+ /* .k = */ k,
569
+ },
570
+ };
571
+ }
572
+
573
+ // top-p
574
+
575
+ struct llama_sampler_top_p {
576
+ const float p;
577
+ const size_t min_keep;
578
+ };
579
+
580
+ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
581
+ return "top-p";
582
+ }
583
+
584
+ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
585
+ const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
586
+
587
+ if (ctx->p >= 1.0f) {
588
+ return;
589
+ }
148
590
 
149
- const int64_t t_start_sample_us = ggml_time_us();
591
+ llama_sampler_softmax_impl(cur_p);
150
592
 
151
593
  // Compute the cumulative probabilities
152
594
  float cum_sum = 0.0f;
153
- size_t last_idx = candidates->size;
595
+ size_t last_idx = cur_p->size;
154
596
 
155
- for (size_t i = 0; i < candidates->size; ++i) {
156
- cum_sum += candidates->data[i].p;
597
+ for (size_t i = 0; i < cur_p->size; ++i) {
598
+ cum_sum += cur_p->data[i].p;
157
599
 
158
600
  // Check if the running sum is at least p or if we have kept at least min_keep tokens
159
601
  // we set the last index to i+1 to indicate that the current iterate should be included in the set
160
- if (cum_sum >= p && i + 1 >= min_keep) {
602
+ if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
161
603
  last_idx = i + 1;
162
604
  break;
163
605
  }
164
606
  }
165
607
 
166
608
  // Resize the output vector to keep only the top-p tokens
167
- candidates->size = last_idx;
609
+ cur_p->size = last_idx;
610
+ }
168
611
 
169
- if (smpl) {
170
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
171
- }
612
+ static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
613
+ const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
614
+ return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
615
+ }
616
+
617
+ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
618
+ delete (llama_sampler_top_p *) smpl->ctx;
619
+ }
620
+
621
+ static struct llama_sampler_i llama_sampler_top_p_i = {
622
+ /* .name = */ llama_sampler_top_p_name,
623
+ /* .accept = */ nullptr,
624
+ /* .apply = */ llama_sampler_top_p_apply,
625
+ /* .reset = */ nullptr,
626
+ /* .clone = */ llama_sampler_top_p_clone,
627
+ /* .free = */ llama_sampler_top_p_free,
628
+ };
629
+
630
+ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
631
+ return new llama_sampler {
632
+ /* .iface = */ &llama_sampler_top_p_i,
633
+ /* .ctx = */ new llama_sampler_top_p {
634
+ /* .p = */ p,
635
+ /* .min_keep = */ min_keep,
636
+ },
637
+ };
172
638
  }
173
639
 
174
- void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
175
- if (p <= 0.0f || !candidates->size) {
640
+ // min-p
641
+
642
+ struct llama_sampler_min_p {
643
+ const float p;
644
+ const size_t min_keep;
645
+ };
646
+
647
+ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
648
+ return "min-p";
649
+ }
650
+
651
+ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
652
+ const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
653
+
654
+ if (ctx->p <= 0.0f || !cur_p->size) {
176
655
  return;
177
656
  }
178
657
 
179
- const int64_t t_start_sample_us = ggml_time_us();
180
-
181
658
  bool min_p_applied = false;
182
659
 
183
- // if the candidates aren't sorted, try the unsorted implementation first
184
- if (!candidates->sorted) {
660
+ // if the cur_p aren't sorted, try the unsorted implementation first
661
+ if (!cur_p->sorted) {
185
662
  std::vector<llama_token_data> filtered_tokens;
186
663
 
187
664
  float max_logit = -FLT_MAX;
188
- for (size_t i = 0; i < candidates->size; ++i) {
189
- max_logit = std::max(max_logit, candidates->data[i].logit);
665
+ for (size_t i = 0; i < cur_p->size; ++i) {
666
+ max_logit = std::max(max_logit, cur_p->data[i].logit);
190
667
  }
191
- const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
668
+ const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
192
669
 
193
- for (size_t i = 0; i < candidates->size; ++i) {
194
- if (candidates->data[i].logit >= min_logit) {
195
- filtered_tokens.push_back(candidates->data[i]);
670
+ for (size_t i = 0; i < cur_p->size; ++i) {
671
+ if (cur_p->data[i].logit >= min_logit) {
672
+ filtered_tokens.push_back(cur_p->data[i]);
196
673
  }
197
674
  }
198
675
 
199
676
  // if we have enough values the operation was a success
200
- if (filtered_tokens.size() >= min_keep) {
201
- memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
202
- candidates->size = filtered_tokens.size();
677
+ if (filtered_tokens.size() >= ctx->min_keep) {
678
+ memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
679
+ cur_p->size = filtered_tokens.size();
203
680
  min_p_applied = true;
204
681
  }
205
682
  }
206
683
 
207
- // if the candidates are sorted or the unsorted implementation failed, use this implementation
684
+ // if the cur_p are sorted or the unsorted implementation failed, use this implementation
208
685
  if (!min_p_applied) {
209
686
  // Sort the logits in descending order
210
- if (!candidates->sorted) {
211
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
687
+ if (!cur_p->sorted) {
688
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
212
689
  return a.logit > b.logit;
213
690
  });
214
- candidates->sorted = true;
691
+ cur_p->sorted = true;
215
692
  }
216
693
 
217
- const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
694
+ const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
218
695
  size_t i = 1; // first token always matches
219
696
 
220
- for (; i < candidates->size; ++i) {
221
- if (candidates->data[i].logit < min_logit && i >= min_keep) {
697
+ for (; i < cur_p->size; ++i) {
698
+ if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
222
699
  break; // prob too small
223
700
  }
224
701
  }
225
702
 
226
703
  // Resize the output vector to keep only the matching tokens
227
- candidates->size = i;
228
- }
229
-
230
- if (smpl) {
231
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
704
+ cur_p->size = i;
232
705
  }
233
706
  }
234
707
 
235
- void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
236
- if (z >= 1.0f || candidates->size <= 2) {
237
- return;
238
- }
239
-
240
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
241
- const int64_t t_start_sample_us = ggml_time_us();
242
-
243
- // Compute the first and second derivatives
244
- std::vector<float> first_derivatives(candidates->size - 1);
245
- std::vector<float> second_derivatives(candidates->size - 2);
246
-
247
- for (size_t i = 0; i < first_derivatives.size(); ++i) {
248
- first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
249
- }
250
- for (size_t i = 0; i < second_derivatives.size(); ++i) {
251
- second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
252
- }
253
-
254
- // Calculate absolute value of second derivatives
255
- for (size_t i = 0; i < second_derivatives.size(); ++i) {
256
- second_derivatives[i] = std::abs(second_derivatives[i]);
257
- }
258
-
259
- // Normalize the second derivatives
260
- {
261
- const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
708
+ static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
709
+ const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
710
+ return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
711
+ }
262
712
 
263
- if (second_derivatives_sum > 1e-6f) {
264
- for (float & value : second_derivatives) {
265
- value /= second_derivatives_sum;
266
- }
267
- } else {
268
- for (float & value : second_derivatives) {
269
- value = 1.0f / second_derivatives.size();
270
- }
271
- }
272
- }
713
+ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
714
+ delete (llama_sampler_min_p *) smpl->ctx;
715
+ }
273
716
 
274
- float cum_sum = 0.0f;
275
- size_t last_idx = candidates->size;
276
- for (size_t i = 0; i < second_derivatives.size(); ++i) {
277
- cum_sum += second_derivatives[i];
717
+ static struct llama_sampler_i llama_sampler_min_p_i = {
718
+ /* .name = */ llama_sampler_min_p_name,
719
+ /* .accept = */ nullptr,
720
+ /* .apply = */ llama_sampler_min_p_apply,
721
+ /* .reset = */ nullptr,
722
+ /* .clone = */ llama_sampler_min_p_clone,
723
+ /* .free = */ llama_sampler_min_p_free,
724
+ };
725
+
726
+ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
727
+ return new llama_sampler {
728
+ /* .iface = */ &llama_sampler_min_p_i,
729
+ /* .ctx = */ new llama_sampler_min_p {
730
+ /* .p = */ p,
731
+ /* .min_keep = */ min_keep,
732
+ },
733
+ };
734
+ }
278
735
 
279
- // Check if the running sum is greater than z or if we have kept at least min_keep tokens
280
- if (cum_sum > z && i >= min_keep) {
281
- last_idx = i;
282
- break;
283
- }
284
- }
736
// typical

// Configuration of the locally typical sampler.
struct llama_sampler_typical {
    const float  p;        // cumulative probability mass after which filtering stops
    const size_t min_keep; // minimum number of candidates to retain
};

static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) { return "typical"; }
293
746
 
294
- void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
747
+ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
748
+ const auto * ctx = (llama_sampler_typical *) smpl->ctx;
749
+
295
750
  // Reference implementation:
296
751
  // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
297
- if (p >= 1.0f) {
752
+ if (ctx->p >= 1.0f) {
298
753
  return;
299
754
  }
300
755
 
301
756
  // Compute the softmax of logits and calculate entropy
302
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
303
-
304
- const int64_t t_start_sample_us = ggml_time_us();
757
+ llama_sampler_softmax_impl(cur_p);
305
758
 
306
759
  float entropy = 0.0f;
307
- for (size_t i = 0; i < candidates->size; ++i) {
308
- entropy += -candidates->data[i].p * logf(candidates->data[i].p);
760
+ for (size_t i = 0; i < cur_p->size; ++i) {
761
+ entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
309
762
  }
310
763
 
311
764
  // Compute the absolute difference between negative log probability and entropy for each candidate
312
765
  std::vector<float> shifted_scores;
313
- for (size_t i = 0; i < candidates->size; ++i) {
314
- float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
766
+ for (size_t i = 0; i < cur_p->size; ++i) {
767
+ float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
315
768
  shifted_scores.push_back(shifted_score);
316
769
  }
317
770
 
318
771
  // Sort tokens based on the shifted_scores and their corresponding indices
319
- std::vector<size_t> indices(candidates->size);
772
+ std::vector<size_t> indices(cur_p->size);
320
773
  std::iota(indices.begin(), indices.end(), 0);
321
774
 
322
775
  std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
@@ -329,197 +782,340 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
329
782
 
330
783
  for (size_t i = 0; i < indices.size(); ++i) {
331
784
  size_t idx = indices[i];
332
- cum_sum += candidates->data[idx].p;
785
+ cum_sum += cur_p->data[idx].p;
333
786
 
334
787
  // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
335
- if (cum_sum > p && i >= min_keep - 1) {
788
+ if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
336
789
  last_idx = i + 1;
337
790
  break;
338
791
  }
339
792
  }
340
793
 
341
794
  // Resize the output vector to keep only the locally typical tokens
342
- std::vector<llama_token_data> new_candidates;
795
+ std::vector<llama_token_data> cur_p_new;
343
796
  for (size_t i = 0; i < last_idx; ++i) {
344
797
  size_t idx = indices[i];
345
- new_candidates.push_back(candidates->data[idx]);
798
+ cur_p_new.push_back(cur_p->data[idx]);
346
799
  }
347
800
 
348
- // Replace the data in candidates with the new_candidates data
349
- std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
350
- candidates->size = new_candidates.size();
351
- candidates->sorted = false;
352
-
353
- if (smpl) {
354
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
355
- }
801
+ // Replace the data in cur_p with the cur_p_new data
802
+ std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
803
+ cur_p->size = cur_p_new.size();
804
+ cur_p->sorted = false;
356
805
  }
357
806
 
358
- void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
359
- const int64_t t_start_sample_us = ggml_time_us();
360
-
361
- // no need to do anything if there is only one (or zero) candidates
362
- if(candidates->size <= 1) {
363
- return;
364
- }
807
+ static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
808
+ const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
809
+ return llama_sampler_init_typical(ctx->p, ctx->min_keep);
810
+ }
365
811
 
366
- // Calculate maximum possible entropy
367
- float max_entropy = -logf(1.0f / candidates->size);
812
+ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
813
+ delete (llama_sampler_typical *) smpl->ctx;
814
+ }
368
815
 
369
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
816
+ static struct llama_sampler_i llama_sampler_typical_i = {
817
+ /* .name = */ llama_sampler_typical_name,
818
+ /* .accept = */ nullptr,
819
+ /* .apply = */ llama_sampler_typical_apply,
820
+ /* .reset = */ nullptr,
821
+ /* .clone = */ llama_sampler_typical_clone,
822
+ /* .free = */ llama_sampler_typical_free,
823
+ };
824
+
825
+ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
826
+ return new llama_sampler {
827
+ /* .iface = */ &llama_sampler_typical_i,
828
+ /* .ctx = */ new llama_sampler_typical {
829
+ /* .p = */ p,
830
+ /* .min_keep = */ min_keep,
831
+ },
832
+ };
833
+ }
370
834
 
371
- // Calculate entropy of the softmax probabilities
372
- float entropy = 0.0f;
373
- for (size_t i = 0; i < candidates->size; ++i) {
374
- float prob = candidates->data[i].p;
375
- if (prob > 0.0f) { // Ensure no log(0)
376
- entropy -= prob * logf(prob);
377
- }
378
- }
835
+ // temp
379
836
 
380
- // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
381
- float normalized_entropy = entropy / max_entropy;
837
+ struct llama_sampler_temp {
838
+ const float temp;
839
+ };
382
840
 
383
- // Map the normalized entropy to the desired temperature range using the power function
384
- float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
841
+ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
842
+ return "temp";
843
+ }
385
844
 
386
- #ifdef DEBUG
387
- LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
388
- LLAMA_LOG_INFO("Entropy: %f\n", entropy);
389
- LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
390
- LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
391
- LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
392
- LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
393
- #endif
845
+ static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
846
+ const auto * ctx = (llama_sampler_temp *) smpl->ctx;
394
847
 
395
- // Apply the dynamically calculated temperature scaling
396
- for (size_t i = 0; i < candidates->size; ++i) {
397
- candidates->data[i].logit /= dyn_temp;
398
- }
848
+ llama_sampler_temp_impl(cur_p, ctx->temp);
849
+ }
399
850
 
400
- // Re-compute softmax probabilities after scaling logits with dynamic temperature
401
- double max_l_double = candidates->data[0].logit;
402
- double cum_sum_double = 0.0;
403
- for (size_t i = 0; i < candidates->size; ++i) {
404
- double p = exp(candidates->data[i].logit - max_l_double);
405
- candidates->data[i].p = p; // Store the scaled probability
406
- cum_sum_double += p;
407
- }
408
- for (size_t i = 0; i < candidates->size; ++i) {
409
- candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
410
- }
851
+ static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
852
+ const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
853
+ return llama_sampler_init_temp(ctx->temp);
854
+ }
411
855
 
412
- #ifdef DEBUG
413
- // Print the updated top 25 probabilities after temperature scaling
414
- LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
415
- for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
416
- LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
417
- }
418
- #endif
856
+ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
857
+ delete (llama_sampler_temp *) smpl->ctx;
858
+ }
419
859
 
420
- if (smpl) {
421
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
422
- }
860
+ static struct llama_sampler_i llama_sampler_temp_i = {
861
+ /* .name = */ llama_sampler_temp_name,
862
+ /* .accept = */ nullptr,
863
+ /* .apply = */ llama_sampler_temp_apply,
864
+ /* .reset = */ nullptr,
865
+ /* .clone = */ llama_sampler_temp_clone,
866
+ /* .free = */ llama_sampler_temp_free,
867
+ };
868
+
869
+ struct llama_sampler * llama_sampler_init_temp(float temp) {
870
+ return new llama_sampler {
871
+ /* .iface = */ &llama_sampler_temp_i,
872
+ /* .ctx = */ new llama_sampler_temp {
873
+ /*.temp = */ temp,
874
+ },
875
+ };
423
876
  }
424
877
 
425
- void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
426
- const int64_t t_start_sample_us = ggml_time_us();
878
+ // temp-ext
427
879
 
428
- for (size_t i = 0; i < candidates->size; ++i) {
429
- candidates->data[i].logit /= temp;
430
- }
880
+ struct llama_sampler_temp_ext {
881
+ const float temp;
882
+ const float delta;
883
+ const float exponent;
884
+ };
431
885
 
432
- if (smpl) {
433
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
434
- }
886
+ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
887
+ return "temp-ext";
435
888
  }
436
889
 
437
- void llama_sample_repetition_penalties_impl(
438
- struct llama_sampling * smpl,
439
- llama_token_data_array * candidates,
440
- const llama_token * last_tokens,
441
- size_t penalty_last_n,
442
- float penalty_repeat,
443
- float penalty_freq,
444
- float penalty_present) {
445
- if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
446
- return;
447
- }
448
-
449
- const int64_t t_start_sample_us = ggml_time_us();
890
+ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
891
+ const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
892
+ if (ctx->delta > 0) {
893
+ const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
894
+ const float max_temp = ctx->temp + ctx->delta;
450
895
 
451
- // Create a frequency map to count occurrences of each token in last_tokens
452
- std::unordered_map<llama_token, int> token_count;
453
- for (size_t i = 0; i < penalty_last_n; ++i) {
454
- token_count[last_tokens[i]]++;
455
- }
896
+ float exponent_val = ctx->exponent;
456
897
 
457
- // Apply frequency and presence penalties to the candidates
458
- for (size_t i = 0; i < candidates->size; ++i) {
459
- const auto token_iter = token_count.find(candidates->data[i].id);
460
- if (token_iter == token_count.end()) {
461
- continue;
898
+ // no need to do anything if there is only one (or zero) candidates
899
+ if (cur_p->size <= 1) {
900
+ return;
462
901
  }
463
902
 
464
- const int count = token_iter->second;
903
+ // Calculate maximum possible entropy
904
+ float max_entropy = -logf(1.0f / cur_p->size);
465
905
 
466
- // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
467
- // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
468
- if (candidates->data[i].logit <= 0) {
469
- candidates->data[i].logit *= penalty_repeat;
470
- } else {
471
- candidates->data[i].logit /= penalty_repeat;
906
+ llama_sampler_softmax_impl(cur_p);
907
+
908
+ // Calculate entropy of the softmax probabilities
909
+ float entropy = 0.0f;
910
+ for (size_t i = 0; i < cur_p->size; ++i) {
911
+ float prob = cur_p->data[i].p;
912
+ if (prob > 0.0f) { // Ensure no log(0)
913
+ entropy -= prob * logf(prob);
914
+ }
472
915
  }
473
916
 
474
- candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
475
- }
917
+ // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
918
+ float normalized_entropy = entropy / max_entropy;
476
919
 
477
- candidates->sorted = false;
920
+ // Map the normalized entropy to the desired temperature range using the power function
921
+ float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
478
922
 
479
- if (smpl) {
480
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
481
- }
482
- }
923
+ #ifdef DEBUG
924
+ LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
925
+ LLAMA_LOG_INFO("Entropy: %f\n", entropy);
926
+ LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
927
+ LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
928
+ LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
929
+ LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
930
+ #endif
483
931
 
484
- void llama_sample_apply_guidance_impl(
485
- struct llama_sampling * smpl,
486
- float * logits,
487
- float * logits_guidance,
488
- float scale) {
489
- GGML_ASSERT(smpl);
932
+ // Apply the dynamically calculated temperature scaling
933
+ llama_sampler_temp_impl(cur_p, dyn_temp);
490
934
 
491
- const auto t_start_sample_us = ggml_time_us();
492
- const auto n_vocab = smpl->n_vocab;
935
+ // Re-compute softmax probabilities after scaling logits with dynamic temperature
936
+ const double max_l_double = cur_p->data[0].logit;
493
937
 
494
- llama_log_softmax(logits, n_vocab);
495
- llama_log_softmax(logits_guidance, n_vocab);
938
+ double cum_sum_double = 0.0;
939
+ for (size_t i = 0; i < cur_p->size; ++i) {
940
+ double p = exp(cur_p->data[i].logit - max_l_double);
941
+ cur_p->data[i].p = p; // Store the scaled probability
942
+ cum_sum_double += p;
943
+ }
496
944
 
497
- for (int i = 0; i < n_vocab; ++i) {
498
- auto & l = logits[i];
499
- const auto & g = logits_guidance[i];
945
+ for (size_t i = 0; i < cur_p->size; ++i) {
946
+ cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
947
+ }
500
948
 
501
- l = scale * (l - g) + g;
949
+ #ifdef DEBUG
950
+ // Print the updated top 25 probabilities after temperature scaling
951
+ LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
952
+ for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
953
+ LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
954
+ }
955
+ #endif
956
+ } else {
957
+ llama_sampler_temp_impl(cur_p, ctx->temp);
502
958
  }
959
+ }
503
960
 
504
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
961
+ static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
962
+ const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
963
+ return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
505
964
  }
506
965
 
507
- llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
508
- GGML_ASSERT(smpl);
966
+ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
967
+ delete (llama_sampler_temp_ext *) smpl->ctx;
968
+ }
509
969
 
510
- const int32_t n_vocab = float(smpl->n_vocab);
970
+ static struct llama_sampler_i llama_sampler_temp_ext_i = {
971
+ /* .name = */ llama_sampler_temp_ext_name,
972
+ /* .accept = */ nullptr,
973
+ /* .apply = */ llama_sampler_temp_ext_apply,
974
+ /* .reset = */ nullptr,
975
+ /* .clone = */ llama_sampler_temp_ext_clone,
976
+ /* .free = */ llama_sampler_temp_ext_free,
977
+ };
978
+
979
+ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
980
+ return new llama_sampler {
981
+ /* .iface = */ &llama_sampler_temp_ext_i,
982
+ /* .ctx = */ new llama_sampler_temp_ext {
983
+ /* .temp = */ temp,
984
+ /* .delta = */ delta,
985
+ /* .exponent = */ exponent,
986
+ },
987
+ };
988
+ }
511
989
 
512
- int64_t t_start_sample_us = ggml_time_us();
990
+ // xtc
513
991
 
514
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
992
+ struct llama_sampler_xtc {
993
+ const float probability;
994
+ const float threshold;
995
+ const size_t min_keep;
996
+
997
+ const uint32_t seed;
998
+ uint32_t seed_cur;
999
+
1000
+ std::mt19937 rng;
1001
+ };
1002
+
1003
+ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
1004
+ return "xtc";
1005
+ }
1006
+
1007
+ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1008
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1009
+
1010
+ if (ctx->probability <= 0.0f
1011
+ || ctx->threshold > 0.5f
1012
+ || cur_p->size < 2) {
1013
+ return;
1014
+ }
1015
+
1016
+ std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
1017
+ float chance = distribution(ctx->rng);
1018
+ if (chance > ctx->probability) return;
1019
+
1020
+ // in case it's not sorted/recalculated yet
1021
+ llama_sampler_softmax_impl(cur_p);
1022
+
1023
+ int pos_last = 0;
1024
+
1025
+ for (size_t i = 0; i < cur_p->size; ++i) {
1026
+ if (cur_p->data[i].p >= ctx->threshold) {
1027
+ pos_last = i;
1028
+ } else break;
1029
+ }
1030
+
1031
+ if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
1032
+ cur_p->data += pos_last;
1033
+ cur_p->size -= pos_last;
1034
+ }
1035
+ }
1036
+
1037
+ static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
1038
+ const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1039
+ auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
1040
+
1041
+ // copy the state
1042
+ {
1043
+ auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1044
+
1045
+ result_ctx->rng = ctx->rng;
1046
+ }
1047
+
1048
+ return result;
1049
+ }
1050
+
1051
+ static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
1052
+ delete (llama_sampler_xtc *) smpl->ctx;
1053
+ }
1054
+
1055
+ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1056
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1057
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1058
+ ctx->rng.seed(ctx->seed_cur);
1059
+ }
1060
+
1061
+ static struct llama_sampler_i llama_sampler_xtc_i = {
1062
+ /* .name = */ llama_sampler_xtc_name,
1063
+ /* .accept = */ nullptr,
1064
+ /* .apply = */ llama_sample_xtc_apply,
1065
+ /* .reset = */ llama_sampler_xtc_reset,
1066
+ /* .clone = */ llama_sampler_xtc_clone,
1067
+ /* .free = */ llama_sampler_xtc_free,
1068
+ };
1069
+
1070
+ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1071
+ auto seed_cur = get_rng_seed(seed);
1072
+ return new llama_sampler {
1073
+ /* .iface = */ &llama_sampler_xtc_i,
1074
+ /* .ctx = */ new llama_sampler_xtc {
1075
+ /* .probability = */ p,
1076
+ /* .threshold = */ t,
1077
+ /* .min_keep = */ min_keep,
1078
+ /* .seed = */ seed,
1079
+ /* .seed_cur = */ seed_cur,
1080
+ /* .rng = */ std::mt19937(seed_cur),
1081
+ },
1082
+ };
1083
+ }
1084
+
1085
// mirostat

// Mirostat sampler state: fixed configuration plus the running `mu` estimate,
// which is updated after every sampled token.
struct llama_sampler_mirostat {
    const int32_t n_vocab;   // vocabulary size, used when deriving k from s_hat

    const uint32_t seed;     // seed as requested by the caller
    uint32_t       seed_cur; // resolved seed; presumably mirrors dist's seed_cur — confirm in llama_sampler_init_mirostat

    const float tau;         // target surprise value
    const float eta;         // learning rate applied to the surprise error

    const int32_t m;         // number of most probable tokens used to estimate s_hat

    float mu;                // running state updated by the apply step

    std::mt19937 rng;
};

static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { return "mirostat"; }
1106
+
1107
+ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1108
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1109
+
1110
+ llama_sampler_softmax_impl(cur_p);
515
1111
 
516
1112
  // Estimate s_hat using the most probable m tokens
517
1113
  float s_hat = 0.0;
518
1114
  float sum_ti_bi = 0.0;
519
1115
  float sum_ti_sq = 0.0;
520
- for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
1116
+ for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
521
1117
  float t_i = logf(float(i + 2) / float(i + 1));
522
- float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
1118
+ float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
523
1119
  sum_ti_bi += t_i * b_i;
524
1120
  sum_ti_sq += t_i * t_i;
525
1121
  }
@@ -527,109 +1123,1225 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
527
1123
 
528
1124
  // Compute k from the estimated s_hat and target surprise value
529
1125
  float epsilon_hat = s_hat - 1;
530
- float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
1126
+ float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
531
1127
 
532
- // Sample the next word X using top-k sampling
533
- llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
534
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
535
- llama_token X = llama_sample_token_impl(smpl, candidates);
536
- t_start_sample_us = ggml_time_us();
1128
+ llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
1129
+ llama_sampler_softmax_impl(cur_p);
537
1130
 
538
- // Compute error as the difference between observed surprise and target surprise value
539
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
540
- return candidate.id == X;
541
- }));
542
- float observed_surprise = -log2f(candidates->data[X_idx].p);
543
- float e = observed_surprise - tau;
1131
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1132
+
1133
+ cur_p->selected = idx;
1134
+
1135
+ float observed_surprise = -log2f(cur_p->data[idx].p);
1136
+ float e = observed_surprise - ctx->tau;
544
1137
 
545
1138
  // Update mu using the learning rate and error
546
- *mu = *mu - eta * e;
1139
+ ctx->mu = ctx->mu - ctx->eta * e;
1140
+ }
1141
+
1142
+ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
1143
+ const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
1144
+ auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
547
1145
 
548
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
549
- return X;
1146
+ // copy the state
1147
+ {
1148
+ auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
1149
+
1150
+ result_ctx->mu = ctx->mu;
1151
+ result_ctx->rng = ctx->rng;
1152
+ }
1153
+
1154
+ return result;
550
1155
  }
551
1156
 
552
- llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
553
- int64_t t_start_sample_us;
554
- t_start_sample_us = ggml_time_us();
1157
+ static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
1158
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1159
+ ctx->mu = 2.0f*ctx->tau;
1160
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1161
+ ctx->rng.seed(ctx->seed_cur);
1162
+ }
555
1163
 
556
- llama_sample_softmax_impl(smpl, candidates);
1164
+ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1165
+ delete (llama_sampler_mirostat *) smpl->ctx;
1166
+ }
1167
+
1168
+ static struct llama_sampler_i llama_sampler_mirostat_i = {
1169
+ /* .name = */ llama_sampler_mirostat_name,
1170
+ /* .accept = */ nullptr,
1171
+ /* .apply = */ llama_sampler_mirostat_apply,
1172
+ /* .reset = */ llama_sampler_mirostat_reset,
1173
+ /* .clone = */ llama_sampler_mirostat_clone,
1174
+ /* .free = */ llama_sampler_mirostat_free,
1175
+ };
1176
+
1177
+ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1178
+ auto seed_cur = get_rng_seed(seed);
1179
+ return new llama_sampler {
1180
+ /* .iface = */ &llama_sampler_mirostat_i,
1181
+ /* .ctx = */ new llama_sampler_mirostat {
1182
+ /* .n_vocab = */ n_vocab,
1183
+ /* .seed = */ seed,
1184
+ /* .seed_cur = */ seed_cur,
1185
+ /* .tau = */ tau,
1186
+ /* .eta = */ eta,
1187
+ /* .m = */ m,
1188
+ /* .mu = */ 2.0f*tau,
1189
+ /* .rng = */ std::mt19937(seed_cur),
1190
+ },
1191
+ };
1192
+ }
1193
+
1194
+ // mirostat v2
1195
+
1196
+ struct llama_sampler_mirostat_v2 {
1197
+ const uint32_t seed;
1198
+ uint32_t seed_cur;
1199
+
1200
+ const float tau;
1201
+ const float eta;
1202
+
1203
+ float mu;
1204
+
1205
+ std::mt19937 rng;
1206
+ };
1207
+
1208
+ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
1209
+ return "mirostat-v2";
1210
+ }
1211
+
1212
+ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1213
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1214
+
1215
+ llama_sampler_softmax_impl(cur_p);
557
1216
 
558
1217
  // Truncate the words with surprise values greater than mu
559
- candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
560
- return -log2f(candidate.p) > *mu;
1218
+ cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
1219
+ return -log2f(candidate.p) > ctx->mu;
561
1220
  }));
562
1221
 
563
- if (candidates->size == 0) {
564
- candidates->size = 1;
565
- }
566
-
567
- if (smpl) {
568
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
1222
+ if (cur_p->size == 0) {
1223
+ cur_p->size = 1;
569
1224
  }
570
1225
 
571
1226
  // Normalize the probabilities of the remaining words
572
- llama_sample_softmax_impl(smpl, candidates);
1227
+ llama_sampler_softmax_impl(cur_p);
573
1228
 
574
- // Sample the next word X from the remaining words
575
- llama_token X = llama_sample_token_impl(smpl, candidates);
576
- t_start_sample_us = ggml_time_us();
1229
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
577
1230
 
578
- // Compute error as the difference between observed surprise and target surprise value
579
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
580
- return candidate.id == X;
581
- }));
582
- float observed_surprise = -log2f(candidates->data[X_idx].p);
583
- float e = observed_surprise - tau;
1231
+ cur_p->selected = idx;
1232
+
1233
+ float observed_surprise = -log2f(cur_p->data[idx].p);
1234
+ float e = observed_surprise - ctx->tau;
584
1235
 
585
1236
  // Update mu using the learning rate and error
586
- *mu = *mu - eta * e;
1237
+ ctx->mu = ctx->mu - ctx->eta * e;
1238
+ }
1239
+
1240
+ static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
1241
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1242
+ ctx->mu = 2.0f*ctx->tau;
1243
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1244
+ ctx->rng.seed(ctx->seed_cur);
1245
+ }
1246
+
1247
+ static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
1248
+ const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
1249
+
1250
+ auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
587
1251
 
588
- if (smpl) {
589
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
1252
+ // copy the state
1253
+ {
1254
+ auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
1255
+
1256
+ result_ctx->mu = ctx->mu;
1257
+ result_ctx->rng = ctx->rng;
590
1258
  }
591
- return X;
1259
+
1260
+ return result;
592
1261
  }
593
1262
 
594
- llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
595
- const int64_t t_start_sample_us = ggml_time_us();
1263
+ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1264
+ delete (llama_sampler_mirostat_v2 *) smpl->ctx;
1265
+ }
596
1266
 
597
- // Find max element
598
- auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
599
- return a.logit < b.logit;
600
- });
1267
+ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1268
+ /* .name = */ llama_sampler_mirostat_v2_name,
1269
+ /* .accept = */ nullptr,
1270
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
1271
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
1272
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
1273
+ /* .free = */ llama_sampler_mirostat_v2_free,
1274
+ };
1275
+
1276
+ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1277
+ auto seed_cur = get_rng_seed(seed);
1278
+ return new llama_sampler {
1279
+ /* .iface = */ &llama_sampler_mirostat_v2_i,
1280
+ /* .ctx = */ new llama_sampler_mirostat_v2 {
1281
+ /* .seed = */ seed,
1282
+ /* .seed_cur = */ seed_cur,
1283
+ /* .tau = */ tau,
1284
+ /* .eta = */ eta,
1285
+ /* .mu = */ 2.0f*tau,
1286
+ /* .rng = */ std::mt19937(seed_cur),
1287
+ },
1288
+ };
1289
+ }
1290
+
1291
+ // grammar
1292
+
1293
+ struct llama_sampler_grammar {
1294
+ const struct llama_vocab * vocab;
1295
+
1296
+ std::string grammar_str;
1297
+ std::string grammar_root;
1298
+
1299
+ struct llama_grammar * grammar;
1300
+ };
1301
+
1302
+ static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
1303
+ return "grammar";
1304
+ }
1305
+
1306
+ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
1307
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1308
+ if (ctx->grammar) {
1309
+ llama_grammar_accept_impl(*ctx->grammar, token);
1310
+ }
1311
+ }
1312
+
1313
+ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1314
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1315
+ if (ctx->grammar) {
1316
+ llama_grammar_apply_impl(*ctx->grammar, cur_p);
1317
+ }
1318
+ }
601
1319
 
602
- llama_token result = max_iter->id;
603
- if (smpl) {
604
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
605
- smpl->n_sample++;
1320
+ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
1321
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1322
+ if (!ctx->grammar) {
1323
+ return;
606
1324
  }
1325
+
1326
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
1327
+
1328
+ llama_grammar_free_impl(ctx->grammar);
1329
+ ctx->grammar = grammar_new;
1330
+ }
1331
+
1332
+ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
1333
+ const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
1334
+
1335
+ auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
1336
+
1337
+ // copy the state
1338
+ {
1339
+ auto * result_ctx = (llama_sampler_grammar *) result->ctx;
1340
+
1341
+ if (ctx->grammar) {
1342
+ result_ctx->grammar_str = ctx->grammar_str;
1343
+ result_ctx->grammar_root = ctx->grammar_root;
1344
+
1345
+ result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
1346
+ }
1347
+ }
1348
+
607
1349
  return result;
608
1350
  }
609
1351
 
610
- llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
611
- GGML_ASSERT(smpl);
1352
+ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1353
+ const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
612
1354
 
613
- const int64_t t_start_sample_us = ggml_time_us();
614
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
1355
+ if (ctx->grammar) {
1356
+ llama_grammar_free_impl(ctx->grammar);
1357
+ }
1358
+
1359
+ delete ctx;
1360
+ }
615
1361
 
616
- std::vector<float> probs;
617
- probs.reserve(candidates->size);
618
- for (size_t i = 0; i < candidates->size; ++i) {
619
- probs.push_back(candidates->data[i].p);
1362
+ static struct llama_sampler_i llama_sampler_grammar_i = {
1363
+ /* .name = */ llama_sampler_grammar_name,
1364
+ /* .accept = */ llama_sampler_grammar_accept_impl,
1365
+ /* .apply = */ llama_sampler_grammar_apply,
1366
+ /* .reset = */ llama_sampler_grammar_reset,
1367
+ /* .clone = */ llama_sampler_grammar_clone,
1368
+ /* .free = */ llama_sampler_grammar_free,
1369
+ };
1370
+
1371
+ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
1372
+ auto * ctx = new llama_sampler_grammar;
1373
+
1374
+ if (grammar_str != nullptr && grammar_str[0] != '\0') {
1375
+ *ctx = {
1376
+ /* .vocab = */ &vocab,
1377
+ /* .grammar_str = */ grammar_str,
1378
+ /* .grammar_root = */ grammar_root,
1379
+ /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
1380
+ };
1381
+ } else {
1382
+ *ctx = {
1383
+ /* .vocab = */ &vocab,
1384
+ /* .grammar_str = */ {},
1385
+ /* .grammar_root = */ {},
1386
+ /* .grammar = */ nullptr,
1387
+ };
620
1388
  }
621
1389
 
622
- std::discrete_distribution<> dist(probs.begin(), probs.end());
623
- int idx = dist(rng);
1390
+ return new llama_sampler {
1391
+ /* .iface = */ &llama_sampler_grammar_i,
1392
+ /* .ctx = */ ctx,
1393
+ };
1394
+ }
1395
+
1396
+ // penalties
624
1397
 
625
- llama_token result = candidates->data[idx].id;
1398
+ struct llama_sampler_penalties {
1399
+ const int32_t n_vocab;
1400
+ const llama_token special_eos_id;
1401
+ const llama_token linefeed_id;
626
1402
 
627
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
628
- smpl->n_sample++;
1403
+ const int32_t penalty_last_n;
1404
+ const float penalty_repeat;
1405
+ const float penalty_freq;
1406
+ const float penalty_present;
1407
+
1408
+ const bool penalize_nl;
1409
+ const bool ignore_eos;
1410
+
1411
+ ring_buffer<llama_token> prev;
1412
+ };
1413
+
1414
+ static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
1415
+ return "penalties";
1416
+ }
1417
+
1418
+ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
1419
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1420
+ if (ctx->penalty_last_n == 0) {
1421
+ return;
1422
+ }
1423
+
1424
+ ctx->prev.push_back(token);
1425
+ }
1426
+
1427
+ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1428
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1429
+
1430
+ if (ctx->ignore_eos) {
1431
+ assert(ctx->special_eos_id >= 0);
1432
+
1433
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
1434
+ if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
1435
+ cur_p->data[ctx->special_eos_id].logit = -INFINITY;
1436
+ } else {
1437
+ // else, search for the special EOS token
1438
+ for (size_t i = 0; i < cur_p->size; ++i) {
1439
+ if (cur_p->data[i].id == ctx->special_eos_id) {
1440
+ cur_p->data[i].logit = -INFINITY;
1441
+ break;
1442
+ }
1443
+ }
1444
+ }
1445
+ }
1446
+
1447
+ if ((ctx->penalty_last_n == 0) ||
1448
+ (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1449
+ return;
1450
+ }
1451
+
1452
+ bool nl_found = false;
1453
+ size_t nl_idx = 0;
1454
+ float nl_logit = -INFINITY;
1455
+ if (!ctx->penalize_nl) {
1456
+ assert(ctx->linefeed_id >= 0);
1457
+
1458
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
1459
+ if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
1460
+ nl_found = true;
1461
+ nl_idx = ctx->linefeed_id;
1462
+ nl_logit = cur_p->data[ctx->linefeed_id].logit;
1463
+ } else {
1464
+ // else, search for the linefeed token
1465
+ for (size_t i = 0; i < cur_p->size; ++i) {
1466
+ if (cur_p->data[i].id == ctx->linefeed_id) {
1467
+ nl_found = true;
1468
+ nl_idx = i;
1469
+ nl_logit = cur_p->data[i].logit;
1470
+ break;
1471
+ }
1472
+ }
1473
+ }
1474
+ }
1475
+
1476
+ // Create a frequency map to count occurrences of each token in last_tokens
1477
+ // TODO: optimize this by maintaining the token count in the sampler context
1478
+ using llama_token_cnt = std::unordered_map<llama_token, int>;
1479
+ llama_token_cnt token_count;
1480
+
1481
+ for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1482
+ token_count[ctx->prev.rat(i)]++;
1483
+ }
1484
+
1485
+ // Apply frequency and presence penalties to the cur_p
1486
+ for (size_t i = 0; i < cur_p->size; ++i) {
1487
+ const auto token_iter = token_count.find(cur_p->data[i].id);
1488
+ if (token_iter == token_count.end()) {
1489
+ continue;
1490
+ }
1491
+
1492
+ const int count = token_iter->second;
1493
+
1494
+ // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
1495
+ // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
1496
+ if (cur_p->data[i].logit <= 0) {
1497
+ cur_p->data[i].logit *= ctx->penalty_repeat;
1498
+ } else {
1499
+ cur_p->data[i].logit /= ctx->penalty_repeat;
1500
+ }
1501
+
1502
+ cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
1503
+ }
1504
+
1505
+ cur_p->sorted = false;
1506
+
1507
+ if (!ctx->penalize_nl && nl_found) {
1508
+ // restore the logit of the newline token if it was penalized
1509
+ cur_p->data[nl_idx].logit = nl_logit;
1510
+ }
1511
+ }
1512
+
1513
+ static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
1514
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1515
+ ctx->prev.clear();
1516
+ }
1517
+
1518
+ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
1519
+ const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
1520
+ auto * result = llama_sampler_init_penalties(
1521
+ ctx->n_vocab,
1522
+ ctx->special_eos_id,
1523
+ ctx->linefeed_id,
1524
+ ctx->penalty_last_n,
1525
+ ctx->penalty_repeat,
1526
+ ctx->penalty_freq,
1527
+ ctx->penalty_present,
1528
+ ctx->penalize_nl,
1529
+ ctx->ignore_eos);
1530
+
1531
+ // copy the state
1532
+ {
1533
+ auto * result_ctx = (llama_sampler_penalties *) result->ctx;
1534
+
1535
+ result_ctx->prev = ctx->prev;
1536
+ }
629
1537
 
630
1538
  return result;
631
1539
  }
632
1540
 
633
- llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
634
- return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
1541
+ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1542
+ delete (llama_sampler_penalties *) smpl->ctx;
1543
+ }
1544
+
1545
+ static struct llama_sampler_i llama_sampler_penalties_i = {
1546
+ /* .name = */ llama_sampler_penalties_name,
1547
+ /* .accept = */ llama_sampler_penalties_accept,
1548
+ /* .apply = */ llama_sampler_penalties_apply,
1549
+ /* .reset = */ llama_sampler_penalties_reset,
1550
+ /* .clone = */ llama_sampler_penalties_clone,
1551
+ /* .free = */ llama_sampler_penalties_free,
1552
+ };
1553
+
1554
+ struct llama_sampler * llama_sampler_init_penalties(
1555
+ int32_t n_vocab,
1556
+ llama_token special_eos_id,
1557
+ llama_token linefeed_id,
1558
+ int32_t penalty_last_n,
1559
+ float penalty_repeat,
1560
+ float penalty_freq,
1561
+ float penalty_present,
1562
+ bool penalize_nl,
1563
+ bool ignore_eos) {
1564
+ if (linefeed_id == LLAMA_TOKEN_NULL) {
1565
+ penalize_nl = true;
1566
+ }
1567
+
1568
+ if (special_eos_id == LLAMA_TOKEN_NULL) {
1569
+ ignore_eos = false;
1570
+ }
1571
+
1572
+ penalty_last_n = std::max(penalty_last_n, 0);
1573
+
1574
+ return new llama_sampler {
1575
+ /* .iface = */ &llama_sampler_penalties_i,
1576
+ /* .ctx = */ new llama_sampler_penalties {
1577
+ /* .n_vocab = */ n_vocab,
1578
+ /* .special_eos_id = */ special_eos_id,
1579
+ /* .linefeed_id = */ linefeed_id,
1580
+ /* .penalty_last_n = */ penalty_last_n,
1581
+ /* .penalty_repeat = */ penalty_repeat,
1582
+ /* .penalty_freq = */ penalty_freq,
1583
+ /* .penalty_present = */ penalty_present,
1584
+ /* .penalize_nl = */ penalize_nl,
1585
+ /* .ignore_eos = */ ignore_eos,
1586
+ /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1587
+ },
1588
+ };
1589
+ }
1590
+
1591
+ // DRY
1592
+
1593
+ struct llama_sampler_dry {
1594
+ int32_t total_context_size;
1595
+
1596
+ const float dry_multiplier;
1597
+ const float dry_base;
1598
+ const int32_t dry_allowed_length;
1599
+ const int32_t dry_penalty_last_n;
1600
+
1601
+ std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
1602
+ std::vector<int> dry_repeat_count;
1603
+ std::unordered_map<llama_token, int> dry_max_token_repeat;
1604
+ ring_buffer<llama_token> last_tokens;
1605
+ };
1606
+
1607
+ // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1608
+ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
1609
+ for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
1610
+ std::string word = llama_detokenize(vocab, {token_id}, true);
1611
+ if (word.find(str) != std::string::npos) {
1612
+ token_sequences.emplace(token_id, std::vector<llama_token>());
1613
+ } else {
1614
+ size_t word_len = word.size(), str_len = str.size();
1615
+ size_t pos = -1;
1616
+ while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
1617
+ bool match = true;
1618
+ size_t i;
1619
+ for (i = 1; i < str_len && i + pos < word_len; ++i) {
1620
+ if (word[pos + i] != str[i]) {
1621
+ match = false;
1622
+ break;
1623
+ }
1624
+ }
1625
+ if (match) {
1626
+ std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
1627
+ if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
1628
+ tokenization.resize(max_tail_len);
1629
+ }
1630
+
1631
+ // Ensure we don't already have a duplicate matching tokenization
1632
+ auto its = token_sequences.equal_range(token_id);
1633
+ bool found = false;
1634
+ for (auto it = its.first; it != its.second; ++it) {
1635
+ if (tokenization == it->second) {
1636
+ found = true;
1637
+ break;
1638
+ }
1639
+ }
1640
+ if (!found) {
1641
+ token_sequences.emplace(token_id, tokenization);
1642
+ }
1643
+ }
1644
+ }
1645
+ }
1646
+ }
1647
+ }
1648
+
1649
+ static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
1650
+ return "dry";
1651
+ }
1652
+
1653
+ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
1654
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1655
+ if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1656
+ return;
1657
+ }
1658
+
1659
+ ctx->last_tokens.push_back(token);
1660
+ }
1661
+
1662
+ // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1663
+ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1664
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1665
+
1666
+ if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1667
+ return;
1668
+ }
1669
+
1670
+ int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
1671
+ int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
1672
+
1673
+ if (last_n_repeat <= ctx->dry_allowed_length) {
1674
+ return;
1675
+ }
1676
+
1677
+ ctx->dry_repeat_count.assign(last_n_repeat, 0);
1678
+ ctx->dry_max_token_repeat.clear();
1679
+
1680
+ // Step 1: Look for restart sequences to limit the maximum repetition length.
1681
+ // Work backwards through the context looking for any token that begins a restart sequence.
1682
+ //
1683
+ // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
1684
+ // sequences that together comprise a restart sequence. This allows us to quickly check
1685
+ // whether each token is the head of a complete sequence. Most restart sequences are actually
1686
+ // a single token, and for these the "tail" is an empty vector.
1687
+ //
1688
+ // If the token is a "head", test all restart sequences that begin with this token
1689
+ // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
1690
+ // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
1691
+ // longest matching sequence (if any) is used to limit the maximum repetition length.
1692
+ //
1693
+ // Note that in the case case of a short sequence contained in a longer one, this might fail to
1694
+ // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
1695
+ // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
1696
+ // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
1697
+ //
1698
+ // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
1699
+ // have already clamped the maximum tail sequence length when generating `restart_sequences`.
1700
+ // With clamping, this scan is O(N) in the context length.
1701
+
1702
+ int rep_limit = last_n_repeat;
1703
+ for (int i = 0; i < last_n_repeat; ++i) {
1704
+ llama_token token = ctx->last_tokens.rat(i);
1705
+ auto its = ctx->dry_processed_breakers.equal_range(token);
1706
+ if (its.first == ctx->dry_processed_breakers.end()) {
1707
+ continue;
1708
+ }
1709
+ int longest_match = -1;
1710
+ for (auto it = its.first; it != its.second; ++it) {
1711
+ // Note that (*it) does not contain the head character, so seq_len will be
1712
+ // the restart sequence length minus 1.
1713
+ // In the common case of a single-token restart sequence, (*it) will be empty
1714
+ // and we will trivially match.
1715
+ int seq_len = (int)it->second.size();
1716
+ if (seq_len > longest_match && seq_len <= (int)i) {
1717
+ bool match = true;
1718
+ for (int offset = 0; offset < seq_len; ++offset) {
1719
+ // The -1 when indexing `last_tokens` is because we already matched the head.
1720
+ if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
1721
+ match = false;
1722
+ break;
1723
+ }
1724
+ }
1725
+ if (match) {
1726
+ longest_match = seq_len;
1727
+ }
1728
+ }
1729
+ }
1730
+ if (longest_match >= 0) {
1731
+ // We found a restart sequence starting `i` tokens from the end and continuing for
1732
+ // `longest_match` tokens.
1733
+ rep_limit = i - longest_match;
1734
+ break;
1735
+ }
1736
+ }
1737
+ if (rep_limit < ctx->dry_allowed_length) {
1738
+ return;
1739
+ }
1740
+
1741
+ // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
1742
+ // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
1743
+ // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
1744
+ //
1745
+ // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
1746
+ // https://ivanyu.me/blog/2014/10/15/z-algorithm/
1747
+ //
1748
+ // The code below is adapted from the public domain implementation by the same author here:
1749
+ // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
1750
+ //
1751
+ // Example:
1752
+ // Last N tokens: a b c c b c y a b c
1753
+ // Repeat counts: 0 0 3 1 0 2 0 0 0 0
1754
+ // ^
1755
+ // This `3` means that the last three tokens of the context (a b c) also appear here.
1756
+ //
1757
+ // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
1758
+ // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
1759
+ // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
1760
+ // ensure that the inner while loops only examine each token in the context once as the outer
1761
+ // for loop iterates over the context.
1762
+
1763
+ {
1764
+ const int last = last_n_repeat - 1;
1765
+ int rt = 0, lt = 0;
1766
+
1767
+ for (int k = 1; k < last_n_repeat; ++k) {
1768
+ if (k > rt) {
1769
+ // If k is outside the current Z-box, do naive computation.
1770
+ int n = 0;
1771
+ while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
1772
+ ++n;
1773
+ }
1774
+ ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
1775
+ if (n > 0) {
1776
+ lt = k;
1777
+ rt = k+n-1;
1778
+ }
1779
+ } else {
1780
+ // If k is inside the current Z-box, consider two cases.
1781
+
1782
+ int p = k - lt; // Pair index.
1783
+ int right_part_len = rt - k + 1;
1784
+
1785
+ if (ctx->dry_repeat_count[last - p] < right_part_len) {
1786
+ int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
1787
+ ctx->dry_repeat_count[last - k] = n;
1788
+ } else {
1789
+ int i = rt + 1;
1790
+ while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
1791
+ i += 1;
1792
+ }
1793
+
1794
+ int n = std::min(i - k, rep_limit);
1795
+ ctx->dry_repeat_count[last - k] = n;
1796
+ lt = k;
1797
+ rt = i - 1;
1798
+ }
1799
+ }
1800
+ }
1801
+ }
1802
+
1803
+ // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
1804
+ // that would be generated by emitting each new token that would extend a sequence.
1805
+ //
1806
+ // Following the same example as above:
1807
+ // Last N tokens: a b c c b c y a b c
1808
+ // Repeat counts: 0 0 3 1 0 2 0 0 0 0
1809
+ //
1810
+ // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
1811
+ // c: 3 -> 4 (from `a b c` to `a b c c`)
1812
+ // b: 1 -> 2 (from `c` to `c b`)
1813
+ // y: 2 -> 3 (from `b c` to `b c y`)
1814
+
1815
+ for (int i = 0; i < last_n_repeat - 1; ++i) {
1816
+ int repeat_len = ctx->dry_repeat_count[i];
1817
+ if (repeat_len >= ctx->dry_allowed_length) {
1818
+ // This token ends a repeat, so the next token would continue one.
1819
+ // By convention, the value of `repeat_len` only includes the tokens currently
1820
+ // in the context, not the new token that would be added.
1821
+ llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
1822
+ // Track the maximum sequence ending in this token.
1823
+ const auto& it = ctx->dry_max_token_repeat.find(token);
1824
+ if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
1825
+ ctx->dry_max_token_repeat[token] = repeat_len;
1826
+ }
1827
+ }
1828
+ }
1829
+
1830
+ // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
1831
+
1832
+ // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
1833
+ // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
1834
+ const float FLOAT_MAX_LOG = 88.7228391f;
1835
+ int max_exponent = 0;
1836
+ if (ctx->dry_base > 1.000001f) {
1837
+ max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
1838
+ }
1839
+
1840
+ for (size_t i = 0; i < cur_p->size; ++i) {
1841
+ const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
1842
+ if (af_kvp != ctx->dry_max_token_repeat.end()) {
1843
+ // Check all sequence breakers starting with this token
1844
+ auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
1845
+ bool is_single_token_breaker = false;
1846
+
1847
+ for (auto it = range.first; it != range.second; ++it) {
1848
+ if (it->second.empty()) {
1849
+ is_single_token_breaker = true;
1850
+ break;
1851
+ }
1852
+ }
1853
+
1854
+ // Apply penalty only if it's not a single-token sequence breaker
1855
+ if (!is_single_token_breaker) {
1856
+ int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
1857
+ if (max_exponent > 0 && repeat_exp > max_exponent) {
1858
+ repeat_exp = max_exponent;
1859
+ }
1860
+ float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
1861
+ cur_p->data[i].logit -= penalty;
1862
+ }
1863
+ }
1864
+ }
1865
+
1866
+ cur_p->sorted = false;
1867
+ }
1868
+
1869
+ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
1870
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1871
+ ctx->last_tokens.clear();
1872
+ ctx->dry_repeat_count.clear();
1873
+ ctx->dry_max_token_repeat.clear();
1874
+ }
1875
+
1876
+ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
1877
+ const auto * ctx = (llama_sampler_dry *) smpl->ctx;
1878
+
1879
+ llama_vocab dummy_vocab;
1880
+
1881
+ // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
1882
+ auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
1883
+
1884
+ // Copy the state, including the processed breakers
1885
+ {
1886
+ auto * result_ctx = (llama_sampler_dry *) result->ctx;
1887
+ result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
1888
+ result_ctx->dry_repeat_count = ctx->dry_repeat_count;
1889
+ result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
1890
+ result_ctx->last_tokens = ctx->last_tokens;
1891
+ }
1892
+
1893
+ return result;
1894
+ }
1895
+
1896
+ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
1897
+ delete (llama_sampler_dry *) smpl->ctx;
1898
+ }
1899
+
1900
+ static struct llama_sampler_i llama_sampler_dry_i = {
1901
+ /* .name = */ llama_sampler_dry_name,
1902
+ /* .accept = */ llama_sampler_dry_accept,
1903
+ /* .apply = */ llama_sampler_dry_apply,
1904
+ /* .reset = */ llama_sampler_dry_reset,
1905
+ /* .clone = */ llama_sampler_dry_clone,
1906
+ /* .free = */ llama_sampler_dry_free,
1907
+ };
1908
+
1909
+ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
1910
+ int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
1911
+ std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
1912
+ const int MAX_CHAR_LEN = 40;
1913
+ const int MAX_SEQ_LEN = 20;
1914
+
1915
+ const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
1916
+
1917
+ if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
1918
+ // Process sequence breakers
1919
+ for (size_t i = 0; i < num_breakers; ++i) {
1920
+ if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
1921
+ LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
1922
+ continue;
1923
+ }
1924
+
1925
+ std::string sequence_break(seq_breakers[i]);
1926
+ if (sequence_break.empty()) {
1927
+ LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
1928
+ continue;
1929
+ }
1930
+
1931
+ if (sequence_break.size() > MAX_CHAR_LEN) {
1932
+ LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
1933
+ sequence_break.resize(MAX_CHAR_LEN);
1934
+ }
1935
+
1936
+ get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
1937
+ }
1938
+ }
1939
+
1940
+ return new llama_sampler {
1941
+ /* .iface = */ &llama_sampler_dry_i,
1942
+ /* .ctx = */ new llama_sampler_dry {
1943
+ /* .total_context_size = */ context_size,
1944
+ /* .dry_multiplier = */ dry_multiplier,
1945
+ /* .dry_base = */ dry_base,
1946
+ /* .dry_allowed_length = */ dry_allowed_length,
1947
+ /* .dry_penalty_last_n = */ dry_penalty_last_n,
1948
+ /* .dry_processed_breakers = */ std::move(processed_breakers),
1949
+ /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
1950
+ /* .dry_max_token_repeat = */ {},
1951
+ /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
1952
+ },
1953
+ };
1954
+ }
1955
+
1956
+ // wrapper for test-sampling.cpp
1957
+ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
1958
+ llama_vocab dummy_vocab;
1959
+ auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
1960
+ auto * ctx = (llama_sampler_dry *) result->ctx;
1961
+
1962
+ // Process the token-based sequence breakers
1963
+ ctx->dry_processed_breakers.clear();
1964
+ if (seq_breakers.empty()) {
1965
+ LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
1966
+ } else {
1967
+ for (const auto& breaker : seq_breakers) {
1968
+ if (breaker.empty()) {
1969
+ LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
1970
+ continue;
1971
+ }
1972
+ llama_token head_token = breaker[0];
1973
+ std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
1974
+ ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
1975
+ }
1976
+
1977
+ if (ctx->dry_processed_breakers.empty()) {
1978
+ LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
1979
+ }
1980
+ }
1981
+
1982
+ return result;
1983
+ }
1984
+
1985
+ // logit-bias
1986
+
1987
+ struct llama_sampler_logit_bias {
1988
+ const int32_t n_vocab;
1989
+
1990
+ const std::vector<llama_logit_bias> logit_bias;
1991
+
1992
+ std::vector<llama_logit_bias> to_search;
1993
+ };
1994
+
1995
+ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
1996
+ return "logit-bias";
1997
+ }
1998
+
1999
+ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2000
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
2001
+
2002
+ if (ctx->logit_bias.empty()) {
2003
+ return;
2004
+ }
2005
+
2006
+ ctx->to_search.clear();
2007
+
2008
+ // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
2009
+ for (const auto & lb : ctx->logit_bias) {
2010
+ if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
2011
+ cur_p->data[lb.token].logit += lb.bias;
2012
+ } else {
2013
+ ctx->to_search.push_back(lb);
2014
+ }
2015
+ }
2016
+
2017
+ if (ctx->to_search.empty()) {
2018
+ return;
2019
+ }
2020
+
2021
+ // search for the remaining candidates that were not found in the previous step
2022
+ for (size_t i = 0; i < cur_p->size; ++i) {
2023
+ for (const auto & lb : ctx->to_search) {
2024
+ if (cur_p->data[i].id == lb.token) {
2025
+ cur_p->data[i].logit += lb.bias;
2026
+ break;
2027
+ }
2028
+ }
2029
+ }
2030
+ }
2031
+
2032
+ static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
2033
+ const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
2034
+ return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
2035
+ }
2036
+
2037
+ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
2038
+ delete (llama_sampler_logit_bias *) smpl->ctx;
2039
+ }
2040
+
2041
+ static struct llama_sampler_i llama_sampler_logit_bias_i = {
2042
+ /* .name = */ llama_sampler_logit_bias_name,
2043
+ /* .accept = */ nullptr,
2044
+ /* .apply = */ llama_sampler_logit_bias_apply,
2045
+ /* .reset = */ nullptr,
2046
+ /* .clone = */ llama_sampler_logit_bias_clone,
2047
+ /* .free = */ llama_sampler_logit_bias_free,
2048
+ };
2049
+
2050
+ struct llama_sampler * llama_sampler_init_logit_bias(
2051
+ int32_t n_vocab,
2052
+ int32_t n_logit_bias,
2053
+ const llama_logit_bias * logit_bias) {
2054
+ return new llama_sampler {
2055
+ /* .iface = */ &llama_sampler_logit_bias_i,
2056
+ /* .ctx = */ new llama_sampler_logit_bias {
2057
+ /* .n_vocab = */ n_vocab,
2058
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2059
+ /* .to_search = */ {},
2060
+ },
2061
+ };
2062
+ }
2063
+
2064
// infill

// uncomment to log every candidate inside llama_sampler_infill_apply()
//#define GGML_DEBUG_SAMPLER_INFILL

// State of the infill sampler.
struct llama_sampler_infill {
    // vocabulary used to classify EOG tokens and detokenize candidates (not owned)
    const struct llama_vocab * vocab;

    // scratch buffers holding the text of the two tokens being compared
    std::vector<char> buf0;
    std::vector<char> buf1;
};

// Sampler name reported through the llama_sampler interface.
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "infill";
    return sampler_name;
}
2078
+
2079
+ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2080
+ auto * ctx = (llama_sampler_infill *) smpl->ctx;
2081
+
2082
+ llama_sampler_softmax_impl(cur_p);
2083
+
2084
+ #if defined(GGML_DEBUG_SAMPLER_INFILL)
2085
+ #define LOG_DBG_CUR LLAMA_LOG_DEBUG
2086
+ #else
2087
+ #define LOG_DBG_CUR(...)
2088
+ #endif
2089
+
2090
+ for (size_t i = 0; i < cur_p->size; ++i) {
2091
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2092
+ }
2093
+
2094
+ float p_txt_sum = 0.0f;
2095
+ float p_eog_sum = 0.0f;
2096
+
2097
+ for (size_t i = 0; i < cur_p->size; ++i) {
2098
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
2099
+ p_eog_sum += cur_p->data[i].p;
2100
+ } else {
2101
+ p_txt_sum += cur_p->data[i].p;
2102
+ }
2103
+ }
2104
+
2105
+ const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
2106
+
2107
+ LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
2108
+
2109
+ if (3*p_eog_sum*cur_p->size > p_txt_sum) {
2110
+ LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
2111
+
2112
+ // keep just the EOG tokens
2113
+ const auto size_org = cur_p->size;
2114
+
2115
+ cur_p->size = 0;
2116
+
2117
+ float p_sum = 0.0f;
2118
+
2119
+ for (size_t i = 0; i < size_org; ++i) {
2120
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
2121
+ p_sum += cur_p->data[i].p;
2122
+
2123
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2124
+ }
2125
+ }
2126
+
2127
+ // normalize probs
2128
+ for (size_t i = 0; i < cur_p->size; ++i) {
2129
+ cur_p->data[i].p /= p_sum;
2130
+ }
2131
+
2132
+ return;
2133
+ }
2134
+
2135
+ size_t n_combined = 0; GGML_UNUSED(n_combined);
2136
+
2137
+ // combine tokens with common prefix
2138
+ for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
2139
+ for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
2140
+ if (cur_p->data[i0].logit == -INFINITY) {
2141
+ break;
2142
+ }
2143
+
2144
+ if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
2145
+ continue;
2146
+ }
2147
+
2148
+ int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2149
+ if (len0 < 0) {
2150
+ ctx->buf0.resize(len0);
2151
+ len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2152
+ assert(len0 > 0);
2153
+ }
2154
+
2155
+ int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2156
+ if (len1 < 0) {
2157
+ ctx->buf1.resize(len1);
2158
+ len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2159
+ assert(len1 > 0);
2160
+ }
2161
+
2162
+ // token i0 is a prefix of token i1
2163
+ if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
2164
+ int dst = i0;
2165
+ int src = i1;
2166
+
2167
+ // merge into the token with higher probability
2168
+ if (cur_p->data[i1].p > cur_p->data[i0].p) {
2169
+ std::swap(dst, src);
2170
+ }
2171
+
2172
+ cur_p->data[dst].p += cur_p->data[src].p;
2173
+ cur_p->data[src].logit = -INFINITY;
2174
+ cur_p->data[src].p = 0.0f;
2175
+
2176
+ n_combined++;
2177
+ }
2178
+ }
2179
+ }
2180
+
2181
+ size_t n_non_eog = 0;
2182
+
2183
+ size_t size_org = cur_p->size;
2184
+
2185
+ float p_sum = 0.0f;
2186
+ float thold = 0.2f;
2187
+
2188
+ cur_p->size = 0;
2189
+
2190
+ LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
2191
+
2192
+ for (size_t i = 0; i < size_org; ++i) {
2193
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
2194
+
2195
+ if (cur_p->data[i].p < thold && !is_eog) {
2196
+ continue;
2197
+ }
2198
+
2199
+ if (!is_eog) {
2200
+ ++n_non_eog;
2201
+ }
2202
+
2203
+ p_sum += cur_p->data[i].p;
2204
+
2205
+ // keep this token
2206
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2207
+ }
2208
+
2209
+ LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
2210
+
2211
+ // if no non-EOG tokens are left -> reduce cur_p to single EOT token
2212
+ if (n_non_eog == 0) {
2213
+ cur_p->size = 1;
2214
+ cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
2215
+ cur_p->data[0].logit = 1.0f;
2216
+
2217
+ return;
2218
+ }
2219
+
2220
+ // normalize probs
2221
+ for (size_t i = 0; i < cur_p->size; ++i) {
2222
+ cur_p->data[i].p /= p_sum;
2223
+
2224
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2225
+ }
2226
+
2227
+ size_org = cur_p->size;
2228
+ p_sum = 0.0f;
2229
+ thold = 1.0/(n_non_eog + 1);
2230
+
2231
+ cur_p->size = 0;
2232
+
2233
+ LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
2234
+
2235
+ for (size_t i = 0; i < size_org; ++i) {
2236
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
2237
+
2238
+ if (cur_p->data[i].p < thold && !is_eog) {
2239
+ continue;
2240
+ }
2241
+
2242
+ p_sum += cur_p->data[i].p;
2243
+
2244
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2245
+ }
2246
+
2247
+ // normalize probs
2248
+ for (size_t i = 0; i < cur_p->size; ++i) {
2249
+ cur_p->data[i].p /= p_sum;
2250
+
2251
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2252
+ }
2253
+
2254
+ #undef LOG_DBG_CUR
2255
+ }
2256
+
2257
+ static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
2258
+ const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
2259
+ return llama_sampler_init_infill_impl(*ctx->vocab);
2260
+ }
2261
+
2262
+ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2263
+ delete (llama_sampler_infill *) smpl->ctx;
2264
+ }
2265
+
2266
+ static struct llama_sampler_i llama_sampler_infill_i = {
2267
+ /* .name = */ llama_sampler_infill_name,
2268
+ /* .accept = */ nullptr,
2269
+ /* .apply = */ llama_sampler_infill_apply,
2270
+ /* .reset = */ nullptr,
2271
+ /* .clone = */ llama_sampler_infill_clone,
2272
+ /* .free = */ llama_sampler_infill_free,
2273
+ };
2274
+
2275
+ struct llama_sampler * llama_sampler_init_infill_impl(
2276
+ const struct llama_vocab & vocab) {
2277
+ return new llama_sampler {
2278
+ /* .iface = */ &llama_sampler_infill_i,
2279
+ /* .ctx = */ new llama_sampler_infill {
2280
+ /* .vocab = */ &vocab,
2281
+ /* .buf0 = */ std::vector<char>(512),
2282
+ /* .buf1 = */ std::vector<char>(512),
2283
+ },
2284
+ };
2285
+ }
2286
+
2287
+ // utils
2288
+
2289
+ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
2290
+ if (smpl->iface == &llama_sampler_dist_i) {
2291
+ return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
2292
+ }
2293
+
2294
+ if (smpl->iface == &llama_sampler_mirostat_i) {
2295
+ return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
2296
+ }
2297
+
2298
+ if (smpl->iface == &llama_sampler_mirostat_v2_i) {
2299
+ return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
2300
+ }
2301
+
2302
+ if (smpl->iface == &llama_sampler_chain_i) {
2303
+ const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
2304
+ for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
2305
+ const uint32_t seed = llama_sampler_get_seed(*it);
2306
+ if (seed != LLAMA_DEFAULT_SEED) {
2307
+ return seed;
2308
+ }
2309
+ }
2310
+ }
2311
+
2312
+ return LLAMA_DEFAULT_SEED;
2313
+ }
2314
+
2315
+ // perf
2316
+
2317
+ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
2318
+ struct llama_perf_sampler_data data = {};
2319
+
2320
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
2321
+ GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
2322
+ }
2323
+
2324
+ const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
2325
+
2326
+ data.t_sample_ms = 1e-3 * ctx->t_sample_us;
2327
+ data.n_sample = std::max(0, ctx->n_sample);
2328
+
2329
+ return data;
2330
+ }
2331
+
2332
+ void llama_perf_sampler_print(const struct llama_sampler * chain) {
2333
+ const auto data = llama_perf_sampler(chain);
2334
+
2335
+ LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2336
+ __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
2337
+ }
2338
+
2339
+ void llama_perf_sampler_reset(struct llama_sampler * chain) {
2340
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
2341
+ GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
2342
+ }
2343
+
2344
+ auto * ctx = (struct llama_sampler_chain *) chain->ctx;
2345
+
2346
+ ctx->t_sample_us = ctx->n_sample = 0;
635
2347
  }