@fugood/llama.node 0.3.1 → 0.3.3

This diff shows the content of publicly available package versions as published to their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
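
The largest functional change in this release is entry 45, common/sampling.cpp, reproduced in full below: the old llama_sampling_* helpers built around llama_sampling_context are replaced by a common_sampler that pairs a grammar sampler with a llama_sampler chain. As a minimal sketch of how a caller might drive the new interface (assuming a loaded llama_model and llama_context; the decode step is elided and the loop itself is illustrative, not taken from this package):

    #include "common.h"
    #include "sampling.h"

    // illustrative only: drive the new common_sampler API for n_predict tokens
    static void generate(llama_model * model, llama_context * ctx, int n_predict) {
        common_sampler_params sparams; // member defaults apply
        sparams.top_k = 40;
        sparams.temp  = 0.8f;

        common_sampler * smpl = common_sampler_init(model, sparams);

        for (int i = 0; i < n_predict; i++) {
            // idx = -1 samples from the logits of the last decoded position;
            // grammar_first = false validates the token afterwards and
            // triggers the resampling path shown in the hunk below
            const llama_token id = common_sampler_sample(smpl, ctx, /* idx */ -1, /* grammar_first */ false);

            if (llama_token_is_eog(model, id)) {
                break;
            }

            // record the token in the history ring buffer and, with
            // accept_grammar = true, advance the grammar sampler as well
            common_sampler_accept(smpl, id, /* accept_grammar */ true);

            // ... feed `id` back through llama_decode() here ...
        }

        common_sampler_free(smpl);
    }

Note that the old two-step llama_sampling_prepare/sampler_queue flow disappears: the sampler order is fixed at init time by params.samplers, and common_sampler_print() reports the effective chain.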
package/src/llama.cpp/common/sampling.cpp
@@ -1,460 +1,466 @@
-#define LLAMA_API_INTERNAL
 #include "sampling.h"
-#include <random>

-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
-    struct llama_sampling_context * result = new llama_sampling_context();
+#include "common.h"

-    result->params = params;
-    result->grammar = nullptr;
+#include <cmath>
+#include <unordered_map>

-    // if there is a grammar, parse it
-    if (!params.grammar.empty()) {
-        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+// TODO: deduplicate with llama-impl.h
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

-        // will be empty (default) if there are parse errors
-        if (result->parsed_grammar.rules.empty()) {
-            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
-            delete result;
-            return nullptr;
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
+        return data[first];
+    }

-        // Ensure that there is a "root" node.
-        if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
-            fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
-            delete result;
-            return nullptr;
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
+        return data[first];
+    }

-        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
-
-        struct llama_grammar * grammar = llama_grammar_init(
-                grammar_rules.data(),
-                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
-        result->grammar = grammar;
+        return data[pos];
     }

-    result->prev.resize(params.n_prev);
-
-    result->n_valid = 0;
-
-    llama_sampling_set_rng_seed(result, params.seed);
-
-    return result;
-}
-
-void llama_sampling_free(struct llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
     }

-    delete ctx;
-}
-
-void llama_sampling_reset(llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
-        ctx->grammar = NULL;
+    void push_back(const T & value) {
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
     }

-    if (!ctx->parsed_grammar.rules.empty()) {
-        std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }

-        struct llama_grammar * grammar = llama_grammar_init(
-                grammar_rules.data(),
-                grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
         }
-        ctx->grammar = grammar;
+        return data[(first + sz - i - 1) % capacity];
     }

-    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
-    ctx->cur.clear();
-    ctx->n_valid = 0;
-}
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }

-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = std::random_device{}();
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
     }
-    ctx->rng.seed(seed);
-}

-void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
-    if (dst->grammar) {
-        llama_grammar_free(dst->grammar);
-        dst->grammar = nullptr;
+    bool empty() const {
+        return sz == 0;
     }

-    if (src->grammar) {
-        dst->grammar = llama_grammar_copy(src->grammar);
+    size_t size() const {
+        return sz;
     }

-    dst->prev = src->prev;
-}
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+    std::vector<T> data;
+};

-llama_token llama_sampling_last(llama_sampling_context * ctx) {
-    return ctx->prev.back();
-}
+struct common_sampler {
+    common_sampler_params params;

-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
-    const int size = ctx_sampling->prev.size();
+    struct llama_sampler * grmr;
+    struct llama_sampler * chain;

-    n = std::min(n, size);
+    ring_buffer<llama_token> prev;

-    std::string result;
+    std::vector<llama_token_data> cur;

-    for (int i = size - n; i < size; i++) {
-        result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
-    }
+    llama_token_data_array cur_p;

-    return result;
-}
+    void set_logits(struct llama_context * ctx, int idx) {
+        const auto * logits = llama_get_logits_ith(ctx, idx);
+
+        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        cur.resize(n_vocab);

-std::string llama_sampling_print(const llama_sampling_params & params) {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
+
+        cur_p = { cur.data(), cur.size(), -1, false };
+    }
+};
+
+std::string common_sampler_params::print() const {
     char result[1024];

     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
-            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
+            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
-            params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
-            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
-            params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
+            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
+            mirostat, mirostat_eta, mirostat_tau);

     return std::string(result);
 }

-std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
+
+    lparams.no_perf = params.no_perf;
+
+    auto * result = new common_sampler {
+        /* .params = */ params,
+        /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .chain = */ llama_sampler_chain_init(lparams),
+        /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur = */ {},
+        /* .cur_p = */ {},
+    };
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_logit_bias(
+                llama_n_vocab(model),
+                params.logit_bias.size(),
+                params.logit_bias.data()));
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_penalties(
+                llama_n_vocab (model),
+                llama_token_eos(model),
+                llama_token_nl (model),
+                params.penalty_last_n,
+                params.penalty_repeat,
+                params.penalty_freq,
+                params.penalty_present,
+                params.penalize_nl,
+                params.ignore_eos));
+
     if (params.mirostat == 0) {
-        for (auto sampler_type : params.samplers_sequence) {
-            const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
-            if (!sampler_type_name.empty()) {
-                result += "-> " + sampler_type_name + " ";
+        for (const auto & cnstr : params.samplers) {
+            switch (cnstr) {
+                case COMMON_SAMPLER_TYPE_DRY:
+                    {
+                        std::vector<const char*> c_breakers;
+                        c_breakers.reserve(params.dry_sequence_breakers.size());
+                        for (const auto& str : params.dry_sequence_breakers) {
+                            c_breakers.push_back(str.c_str());
+                        }
+
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    }
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_K:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_MIN_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_XTC:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    break;
+                case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                case COMMON_SAMPLER_TYPE_INFILL:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown sampler type");
             }
         }
+        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+    } else if (params.mirostat == 1) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+    } else if (params.mirostat == 2) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
-        result += "-> mirostat ";
+        GGML_ASSERT(false && "unknown mirostat version");
     }

     return result;
 }

-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
-    switch (sampler_type) {
-        case llama_sampler_type::TOP_K: return "top_k";
-        case llama_sampler_type::TFS_Z: return "tfs_z";
-        case llama_sampler_type::TYPICAL_P: return "typical_p";
-        case llama_sampler_type::TOP_P: return "top_p";
-        case llama_sampler_type::MIN_P: return "min_p";
-        case llama_sampler_type::TEMPERATURE: return "temperature";
-        default : return "";
+void common_sampler_free(struct common_sampler * gsmpl) {
+    if (gsmpl) {
+        llama_sampler_free(gsmpl->grmr);
+
+        llama_sampler_free(gsmpl->chain);
+
+        delete gsmpl;
     }
 }

-std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
-    std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
-        {"top_k", llama_sampler_type::TOP_K},
-        {"top_p", llama_sampler_type::TOP_P},
-        {"typical_p", llama_sampler_type::TYPICAL_P},
-        {"min_p", llama_sampler_type::MIN_P},
-        {"tfs_z", llama_sampler_type::TFS_Z},
-        {"temperature", llama_sampler_type::TEMPERATURE}
-    };
+void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
+    if (accept_grammar) {
+        llama_sampler_accept(gsmpl->grmr, token);
+    }

-    // since samplers names are written multiple ways
-    // make it ready for both system names and input names
-    std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
-        {"top-k", llama_sampler_type::TOP_K},
-        {"top-p", llama_sampler_type::TOP_P},
-        {"nucleus", llama_sampler_type::TOP_P},
-        {"typical-p", llama_sampler_type::TYPICAL_P},
-        {"typical", llama_sampler_type::TYPICAL_P},
-        {"min-p", llama_sampler_type::MIN_P},
-        {"tfs-z", llama_sampler_type::TFS_Z},
-        {"tfs", llama_sampler_type::TFS_Z},
-        {"temp", llama_sampler_type::TEMPERATURE}
-    };
+    llama_sampler_accept(gsmpl->chain, token);

-    std::vector<llama_sampler_type> sampler_types;
-    sampler_types.reserve(names.size());
-    for (const auto & name : names)
-    {
-        auto sampler_item = sampler_canonical_name_map.find(name);
-        if (sampler_item != sampler_canonical_name_map.end())
-        {
-            sampler_types.push_back(sampler_item->second);
-        }
-        else
-        {
-            if (allow_alt_names)
-            {
-                sampler_item = sampler_alt_name_map.find(name);
-                if (sampler_item != sampler_alt_name_map.end())
-                {
-                    sampler_types.push_back(sampler_item->second);
-                }
-            }
-        }
-    }
-    return sampler_types;
+    gsmpl->prev.push_back(token);
 }

-std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
-    std::unordered_map<char, llama_sampler_type> sampler_name_map {
-        {'k', llama_sampler_type::TOP_K},
-        {'p', llama_sampler_type::TOP_P},
-        {'y', llama_sampler_type::TYPICAL_P},
-        {'m', llama_sampler_type::MIN_P},
-        {'f', llama_sampler_type::TFS_Z},
-        {'t', llama_sampler_type::TEMPERATURE}
-    };
+void common_sampler_reset(struct common_sampler * gsmpl) {
+    llama_sampler_reset(gsmpl->grmr);

-    std::vector<llama_sampler_type> sampler_types;
-    sampler_types.reserve(names_string.size());
-    for (const auto & c : names_string) {
-        const auto sampler_item = sampler_name_map.find(c);
-        if (sampler_item != sampler_name_map.end()) {
-            sampler_types.push_back(sampler_item->second);
-        }
-    }
-    return sampler_types;
+    llama_sampler_reset(gsmpl->chain);
 }

-// no reasons to expose this function in header
-static void sampler_queue(
-        struct llama_context * ctx_main,
-        const llama_sampling_params & params,
-        llama_token_data_array & cur_p,
-        size_t min_keep) {
-    const float temp = params.temp;
-    const float dynatemp_range = params.dynatemp_range;
-    const float dynatemp_exponent = params.dynatemp_exponent;
-    const int32_t top_k = params.top_k;
-    const float top_p = params.top_p;
-    const float min_p = params.min_p;
-    const float tfs_z = params.tfs_z;
-    const float typical_p = params.typical_p;
-    const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
-
-    for (auto sampler_type : samplers_sequence) {
-        switch (sampler_type) {
-            case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
-            case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
-            case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
-            case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
-            case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
-            case llama_sampler_type::TEMPERATURE:
-                if (dynatemp_range > 0) {
-                    float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
-                    float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
-                    llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
-                } else {
-                    llama_sample_temp(ctx_main, &cur_p, temp);
-                }
-                break;
-            default : break;
-        }
-    }
+struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
+    return new common_sampler {
+        /* .params = */ gsmpl->params,
+        /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain = */ llama_sampler_clone(gsmpl->chain),
+        /* .prev = */ gsmpl->prev,
+        /* .cur = */ gsmpl->cur,
+        /* .cur_p = */ gsmpl->cur_p,
+    };
 }

-static llama_token llama_sampling_sample_impl(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool is_resampling) {
-    const llama_sampling_params & params = ctx_sampling->params;
-
-    const float temp = params.temp;
-    const int mirostat = params.mirostat;
-    const float mirostat_tau = params.mirostat_tau;
-    const float mirostat_eta = params.mirostat_eta;
-
-    std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        GGML_ASSERT(!original_logits.empty());
-    }
-    llama_token id = 0;
-
-    if (temp < 0.0) {
-        // greedy sampling, with probs
-        llama_sample_softmax(ctx_main, &cur_p);
-        id = cur_p.data[0].id;
-    } else if (temp == 0.0) {
-        // greedy sampling, no probs
-        id = llama_sample_token_greedy(ctx_main, &cur_p);
-    } else {
-        if (mirostat == 1) {
-            const int mirostat_m = 100;
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
-        } else if (mirostat == 2) {
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
-        } else {
-            // temperature sampling
-            size_t min_keep = std::max(1, params.min_keep);
-
-            sampler_queue(ctx_main, params, cur_p, min_keep);
+void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
+    // TODO: measure grammar performance

-            id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+    if (gsmpl) {
+        llama_perf_sampler_print(gsmpl->chain);
+    }
+    if (ctx) {
+        llama_perf_context_print(ctx);
+    }
+}

-            //{
-            //    const int n_top = 10;
-            //    LOG("top %d candidates:\n", n_top);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+    gsmpl->set_logits(ctx, idx);

-            //    for (int i = 0; i < n_top; i++) {
-            //        const llama_token id = cur_p.data[i].id;
-            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
-            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
-            //    }
-            //}
+    auto & grmr  = gsmpl->grmr;
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

-            //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
-        }
+    if (grammar_first) {
+        llama_sampler_apply(grmr, &cur_p);
     }

-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        // Get a pointer to the logits
-        float * logits = llama_get_logits_ith(ctx_main, idx);
+    llama_sampler_apply(chain, &cur_p);

-        // Create an array with a single token data element for the sampled id
-        llama_token_data single_token_data = {id, logits[id], 0.0f};
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

-        // Apply grammar constraints to the single token
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
+    const llama_token id = cur_p.data[cur_p.selected].id;

-        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
-        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+    if (grammar_first) {
+        return id;
+    }

-        // If the token is not valid according to the grammar, perform resampling
-        if (!is_valid) {
-            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+    // check if it the sampled token fits the grammar
+    {
+        llama_token_data single_token_data = { id, 1.0f, 0.0f };
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

-            // Restore logits from the copy
-            std::copy(original_logits.begin(), original_logits.end(), logits);
+        llama_sampler_apply(grmr, &single_token_data_array);

-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
+        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+        if (is_valid) {
+            return id;
         }
     }

-    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+    gsmpl->set_logits(ctx, idx);

-    return id;
-}
+    llama_sampler_apply(grmr, &cur_p);
+    llama_sampler_apply(chain, &cur_p);

-static llama_token_data_array llama_sampling_prepare_impl(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool apply_grammar,
-        std::vector<float> * original_logits) {
-    const llama_sampling_params & params = ctx_sampling->params;
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    return cur_p.data[cur_p.selected].id;
+}

-    const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float penalty_repeat = params.penalty_repeat;
-    const float penalty_freq = params.penalty_freq;
-    const float penalty_present = params.penalty_present;
+uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
+    return llama_sampler_get_seed(gsmpl->chain);
+}

-    const bool penalize_nl = params.penalize_nl;
+// helpers

-    auto & prev = ctx_sampling->prev;
-    auto & cur = ctx_sampling->cur;
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
+    return &gsmpl->cur_p;
+}

-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
+llama_token common_sampler_last(const struct common_sampler * gsmpl) {
+    return gsmpl->prev.rat(0);
+}

-    if (ctx_sampling->grammar != NULL && !apply_grammar) {
-        GGML_ASSERT(original_logits != NULL);
-        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = {logits, logits + n_vocab};
+std::string common_sampler_print(const struct common_sampler * gsmpl) {
+    std::string result = "logits ";
+
+    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
     }

-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
+    return result;
+}
+
+std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
+    n = std::min(n, (int) gsmpl->prev.size());
+
+    if (n <= 0) {
+        return "";
     }

-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    std::string result;
+    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
+
+    for (int i = n - 1; i >= 0; i--) {
+        const llama_token id = gsmpl->prev.rat(i);
+
+        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
+
+        result += common_token_to_piece(ctx_main, id);
     }

-    cur.resize(n_vocab);
+    return result;
+}

-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
+    switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY: return 'd';
+        case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
+        case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
+        case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
+        case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
+        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
+        case COMMON_SAMPLER_TYPE_XTC: return 'x';
+        case COMMON_SAMPLER_TYPE_INFILL: return 'i';
+        default : return '?';
     }
+}

-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
+    switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY: return "dry";
+        case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
+        case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
+        case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
+        case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
+        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        case COMMON_SAMPLER_TYPE_XTC: return "xtc";
+        case COMMON_SAMPLER_TYPE_INFILL: return "infill";
+        default : return "";
+    }
+}

-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
+        { "dry", COMMON_SAMPLER_TYPE_DRY },
+        { "top_k", COMMON_SAMPLER_TYPE_TOP_K },
+        { "top_p", COMMON_SAMPLER_TYPE_TOP_P },
+        { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "min_p", COMMON_SAMPLER_TYPE_MIN_P },
+        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { "xtc", COMMON_SAMPLER_TYPE_XTC },
+        { "infill", COMMON_SAMPLER_TYPE_INFILL },
+    };

-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+    // since samplers names are written multiple ways
+    // make it ready for both system names and input names
+    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
+        { "top-k", COMMON_SAMPLER_TYPE_TOP_K },
+        { "top-p", COMMON_SAMPLER_TYPE_TOP_P },
+        { "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
+        { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { "min-p", COMMON_SAMPLER_TYPE_MIN_P },
+        { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
+    };

-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
+    std::vector<common_sampler_type> samplers;
+    samplers.reserve(names.size());
+
+    for (const auto & name : names) {
+        auto sampler = sampler_canonical_name_map.find(name);
+        if (sampler != sampler_canonical_name_map.end()) {
+            samplers.push_back(sampler->second);
+        } else {
+            if (allow_alt_names) {
+                sampler = sampler_alt_name_map.find(name);
+                if (sampler != sampler_alt_name_map.end()) {
+                    samplers.push_back(sampler->second);
                 }
             }
         }
     }

-    // apply grammar checks before sampling logic
-    if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
-    }
-
-    return cur_p;
+    return samplers;
 }

-llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx) {
-    // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
-}
-
-llama_token_data_array llama_sampling_prepare(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        const int idx,
-        bool apply_grammar,
-        std::vector<float> * original_logits) {
-    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
-}
+std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
+    std::unordered_map<char, common_sampler_type> sampler_name_map = {
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
+    };

-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        llama_token id,
-        bool apply_grammar) {
-    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-    ctx_sampling->prev.push_back(id);
+    std::vector<common_sampler_type> samplers;
+    samplers.reserve(chars.size());

-    if (ctx_sampling->grammar != NULL && apply_grammar) {
-        llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
+    for (const auto & c : chars) {
+        const auto sampler = sampler_name_map.find(c);
+        if (sampler != sampler_name_map.end()) {
+            samplers.push_back(sampler->second);
+        }
     }
+
+    return samplers;
 }
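
For reference, a small hedged sketch of the two parsing helpers at the end of the hunk. The inputs are illustrative (upstream they come from CLI arguments such as --samplers, already split into tokens, and the single-character sampler sequence string):

    #include "sampling.h"
    #include <cassert>

    int main() {
        // canonical names, matching common_sampler_type_to_str()
        const auto a = common_sampler_types_from_names({"top_k", "typ_p", "temperature"}, /* allow_alt_names */ false);
        assert(a.size() == 3 && a[1] == COMMON_SAMPLER_TYPE_TYPICAL_P);

        // alternate spellings resolve only when allow_alt_names is true
        const auto b = common_sampler_types_from_names({"nucleus"}, /* allow_alt_names */ true);
        assert(b.size() == 1 && b[0] == COMMON_SAMPLER_TYPE_TOP_P);

        // single-character form built from common_sampler_type_to_chr();
        // unknown characters are silently skipped
        const auto c = common_sampler_types_from_chars("dkypmxt");
        assert(c.size() == 7 && c[0] == COMMON_SAMPLER_TYPE_DRY);

        return 0;
    }

Note that the tfs_z / tail-free sampler and its "f" shorthand from 0.3.1 no longer appear in either map, while "dry", "xtc", and "infill" are new.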