@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
63
63
  }
64
64
  */
65
65
 
66
+ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
67
+ if (temp <= 0.0f) {
68
+ // find the token with the highest logit and set the rest to -inf
69
+ size_t max_i = 0;
70
+ float max_l = cur_p->data[0].logit;
71
+
72
+ for (size_t i = 1; i < cur_p->size; ++i) {
73
+ if (cur_p->data[i ].logit > max_l) {
74
+ cur_p->data[max_i].logit = -INFINITY;
75
+ max_i = i;
76
+ max_l = cur_p->data[i].logit;
77
+ } else {
78
+ cur_p->data[i].logit = -INFINITY;
79
+ }
80
+ }
81
+
82
+ return;
83
+ }
84
+
85
+ for (size_t i = 0; i < cur_p->size; ++i) {
86
+ cur_p->data[i].logit /= temp;
87
+ }
88
+ }
89
+
66
90
  static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
67
91
  GGML_ASSERT(cur_p->size > 0);
68
92
 
@@ -89,7 +113,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
89
113
  }
90
114
 
91
115
  static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
92
- // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
116
+ // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
93
117
  // if (k >= (int32_t)cur_p->size) {
94
118
  // return;
95
119
  // }
@@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
427
451
 
428
452
  static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
429
453
  auto * ctx = (llama_sampler_dist *) smpl->ctx;
454
+
455
+ llama_sampler_softmax_impl(cur_p);
456
+
430
457
  cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
431
458
  }
432
459
 
@@ -706,101 +733,6 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
706
733
  };
707
734
  }
708
735
 
709
- // tail-free
710
-
711
- struct llama_sampler_tail_free {
712
- const float z;
713
- const size_t min_keep;
714
- };
715
-
716
- static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
717
- return "tail-free";
718
- }
719
-
720
- static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
721
- const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
722
-
723
- if (ctx->z >= 1.0f || cur_p->size <= 2) {
724
- return;
725
- }
726
-
727
- llama_sampler_softmax_impl(cur_p);
728
-
729
- // Compute the first and second derivatives
730
- std::vector<float> first_derivatives(cur_p->size - 1);
731
- std::vector<float> second_derivatives(cur_p->size - 2);
732
-
733
- for (size_t i = 0; i < first_derivatives.size(); ++i) {
734
- first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
735
- }
736
- for (size_t i = 0; i < second_derivatives.size(); ++i) {
737
- second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
738
- }
739
-
740
- // Calculate absolute value of second derivatives
741
- for (size_t i = 0; i < second_derivatives.size(); ++i) {
742
- second_derivatives[i] = std::abs(second_derivatives[i]);
743
- }
744
-
745
- // Normalize the second derivatives
746
- {
747
- const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
748
-
749
- if (second_derivatives_sum > 1e-6f) {
750
- for (float & value : second_derivatives) {
751
- value /= second_derivatives_sum;
752
- }
753
- } else {
754
- for (float & value : second_derivatives) {
755
- value = 1.0f / second_derivatives.size();
756
- }
757
- }
758
- }
759
-
760
- float cum_sum = 0.0f;
761
- size_t last_idx = cur_p->size;
762
- for (size_t i = 0; i < second_derivatives.size(); ++i) {
763
- cum_sum += second_derivatives[i];
764
-
765
- // Check if the running sum is greater than z or if we have kept at least min_keep tokens
766
- if (cum_sum > ctx->z && i >= ctx->min_keep) {
767
- last_idx = i;
768
- break;
769
- }
770
- }
771
-
772
- // Resize the output vector to keep only the tokens above the tail location
773
- cur_p->size = last_idx;
774
- }
775
-
776
- static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
777
- const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
778
- return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
779
- }
780
-
781
- static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
782
- delete (llama_sampler_tail_free *) smpl->ctx;
783
- }
784
-
785
- static struct llama_sampler_i llama_sampler_tail_free_i = {
786
- /* .name = */ llama_sampler_tail_free_name,
787
- /* .accept = */ nullptr,
788
- /* .apply = */ llama_sampler_tail_free_apply,
789
- /* .reset = */ nullptr,
790
- /* .clone = */ llama_sampler_tail_free_clone,
791
- /* .free = */ llama_sampler_tail_free_free,
792
- };
793
-
794
- struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
795
- return new llama_sampler {
796
- /* .iface = */ &llama_sampler_tail_free_i,
797
- /* .ctx = */ new llama_sampler_tail_free {
798
- /* .z = */ z,
799
- /*. min_keep = */ min_keep,
800
- },
801
- };
802
- }
803
-
804
736
  // typical
805
737
 
806
738
  struct llama_sampler_typical {
@@ -912,9 +844,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
912
844
 
913
845
  static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
914
846
  const auto * ctx = (llama_sampler_temp *) smpl->ctx;
915
- for (size_t i = 0; i < cur_p->size; ++i) {
916
- cur_p->data[i].logit /= ctx->temp;
917
- }
847
+
848
+ llama_sampler_temp_impl(cur_p, ctx->temp);
918
849
  }
919
850
 
920
851
  static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@@ -961,6 +892,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
961
892
  if (ctx->delta > 0) {
962
893
  const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
963
894
  const float max_temp = ctx->temp + ctx->delta;
895
+
964
896
  float exponent_val = ctx->exponent;
965
897
 
966
898
  // no need to do anything if there is only one (or zero) candidates
@@ -998,9 +930,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
998
930
  #endif
999
931
 
1000
932
  // Apply the dynamically calculated temperature scaling
1001
- for (size_t i = 0; i < cur_p->size; ++i) {
1002
- cur_p->data[i].logit /= dyn_temp;
1003
- }
933
+ llama_sampler_temp_impl(cur_p, dyn_temp);
1004
934
 
1005
935
  // Re-compute softmax probabilities after scaling logits with dynamic temperature
1006
936
  const double max_l_double = cur_p->data[0].logit;
@@ -1024,9 +954,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
1024
954
  }
1025
955
  #endif
1026
956
  } else {
1027
- for (size_t i = 0; i < cur_p->size; ++i) {
1028
- cur_p->data[i].logit /= ctx->temp;
1029
- }
957
+ llama_sampler_temp_impl(cur_p, ctx->temp);
1030
958
  }
1031
959
  }
1032
960
 
@@ -1059,6 +987,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
1059
987
  };
1060
988
  }
1061
989
 
990
+ // xtc
991
+
992
// State for the "xtc" sampler.
// Field order matters: the struct is brace-initialised positionally.
struct llama_sampler_xtc {
    const float  probability; // the sampler is skipped when a uniform draw in [0, 1) exceeds this
    const float  threshold;   // candidates with p >= threshold are eligible for removal
    const size_t min_keep;    // truncation only happens if at least this many candidates remain

    const uint32_t seed;      // seed value as supplied by the caller
    uint32_t       seed_cur;  // seed actually fed to the RNG (result of get_rng_seed(seed))

    std::mt19937 rng;         // generator used for the skip/apply draw
};
1002
+
1003
// Name callback for the xtc sampler; the sampler argument is not used.
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
    static const char * const k_xtc_name = "xtc";
    return k_xtc_name;
}
1006
+
1007
+ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1008
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1009
+
1010
+ if (ctx->probability <= 0.0f
1011
+ || ctx->threshold > 0.5f
1012
+ || cur_p->size < 2) {
1013
+ return;
1014
+ }
1015
+
1016
+ std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
1017
+ float chance = distribution(ctx->rng);
1018
+ if (chance > ctx->probability) return;
1019
+
1020
+ // in case it's not sorted/recalculated yet
1021
+ llama_sampler_softmax_impl(cur_p);
1022
+
1023
+ int pos_last = 0;
1024
+
1025
+ for (size_t i = 0; i < cur_p->size; ++i) {
1026
+ if (cur_p->data[i].p >= ctx->threshold) {
1027
+ pos_last = i;
1028
+ } else break;
1029
+ }
1030
+
1031
+ if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
1032
+ cur_p->data += pos_last;
1033
+ cur_p->size -= pos_last;
1034
+ }
1035
+ }
1036
+
1037
+ static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
1038
+ const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1039
+ auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
1040
+
1041
+ // copy the state
1042
+ {
1043
+ auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1044
+
1045
+ result_ctx->rng = ctx->rng;
1046
+ }
1047
+
1048
+ return result;
1049
+ }
1050
+
1051
+ static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
1052
+ delete (llama_sampler_xtc *) smpl->ctx;
1053
+ }
1054
+
1055
+ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1056
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1057
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1058
+ ctx->rng.seed(ctx->seed_cur);
1059
+ }
1060
+
1061
+ static struct llama_sampler_i llama_sampler_xtc_i = {
1062
+ /* .name = */ llama_sampler_xtc_name,
1063
+ /* .accept = */ nullptr,
1064
+ /* .apply = */ llama_sample_xtc_apply,
1065
+ /* .reset = */ llama_sampler_xtc_reset,
1066
+ /* .clone = */ llama_sampler_xtc_clone,
1067
+ /* .free = */ llama_sampler_xtc_free,
1068
+ };
1069
+
1070
+ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1071
+ auto seed_cur = get_rng_seed(seed);
1072
+ return new llama_sampler {
1073
+ /* .iface = */ &llama_sampler_xtc_i,
1074
+ /* .ctx = */ new llama_sampler_xtc {
1075
+ /* .probability = */ p,
1076
+ /* .threshold = */ t,
1077
+ /* .min_keep = */ min_keep,
1078
+ /* .seed = */ seed,
1079
+ /* .seed_cur = */ seed_cur,
1080
+ /* .rng = */ std::mt19937(seed_cur),
1081
+ },
1082
+ };
1083
+ }
1084
+
1062
1085
  // mirostat
1063
1086
 
1064
1087
  struct llama_sampler_mirostat {
@@ -1373,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
1373
1396
  // penalties
1374
1397
 
1375
1398
  struct llama_sampler_penalties {
1376
- const int32_t n_vocab;
1377
- const llama_token special_eos_id;
1378
- const llama_token linefeed_id;
1379
-
1380
1399
  const int32_t penalty_last_n;
1381
1400
  const float penalty_repeat;
1382
1401
  const float penalty_freq;
1383
1402
  const float penalty_present;
1384
1403
 
1385
- const bool penalize_nl;
1386
- const bool ignore_eos;
1387
-
1388
1404
  ring_buffer<llama_token> prev;
1405
+
1406
+ // a frequency map to count token occurrences
1407
+ std::unordered_map<llama_token, int> token_count;
1389
1408
  };
1390
1409
 
1391
1410
  static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1398,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
1398
1417
  return;
1399
1418
  }
1400
1419
 
1401
- ctx->prev.push_back(token);
1402
- }
1403
-
1404
- static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1405
- auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1420
+ ctx->token_count[token]++;
1406
1421
 
1407
- if (ctx->ignore_eos) {
1408
- assert(ctx->special_eos_id >= 0);
1422
+ // if the ring buffer is full, remove the oldest token
1423
+ if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
1424
+ const auto old = ctx->prev.front();
1409
1425
 
1410
- // optimistically check if the candidates are not yet sorted/shuffled/truncated
1411
- if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
1412
- cur_p->data[ctx->special_eos_id].logit = -INFINITY;
1413
- } else {
1414
- // else, search for the special EOS token
1415
- for (size_t i = 0; i < cur_p->size; ++i) {
1416
- if (cur_p->data[i].id == ctx->special_eos_id) {
1417
- cur_p->data[i].logit = -INFINITY;
1418
- break;
1419
- }
1420
- }
1426
+ ctx->token_count[old]--;
1427
+ if (ctx->token_count[old] == 0) {
1428
+ ctx->token_count.erase(old);
1421
1429
  }
1422
1430
  }
1423
1431
 
1424
- if ((ctx->penalty_last_n == 0) ||
1425
- (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1426
- return;
1427
- }
1428
-
1429
- bool nl_found = false;
1430
- size_t nl_idx = 0;
1431
- float nl_logit = -INFINITY;
1432
- if (!ctx->penalize_nl) {
1433
- assert(ctx->linefeed_id >= 0);
1432
+ ctx->prev.push_back(token);
1434
1433
 
1435
- // optimistically check if the candidates are not yet sorted/shuffled/truncated
1436
- if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
1437
- nl_found = true;
1438
- nl_idx = ctx->linefeed_id;
1439
- nl_logit = cur_p->data[ctx->linefeed_id].logit;
1440
- } else {
1441
- // else, search for the linefeed token
1442
- for (size_t i = 0; i < cur_p->size; ++i) {
1443
- if (cur_p->data[i].id == ctx->linefeed_id) {
1444
- nl_found = true;
1445
- nl_idx = i;
1446
- nl_logit = cur_p->data[i].logit;
1447
- break;
1448
- }
1449
- }
1450
- }
1434
+ #if 0
1435
+ // sanity check
1436
+ std::unordered_map<llama_token, int> tmp;
1437
+ for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1438
+ tmp[ctx->prev.rat(i)]++;
1451
1439
  }
1452
1440
 
1453
- // Create a frequency map to count occurrences of each token in last_tokens
1454
- // TODO: optimize this by maintaining the token count in the sampler context
1455
- using llama_token_cnt = std::unordered_map<llama_token, int>;
1456
- llama_token_cnt token_count;
1441
+ assert(ctx->token_count == tmp);
1442
+ #endif
1443
+ }
1444
+
1445
+ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1446
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1457
1447
 
1458
- for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1459
- token_count[ctx->prev.rat(i)]++;
1448
+ if ((ctx->penalty_last_n == 0) ||
1449
+ (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1450
+ return;
1460
1451
  }
1461
1452
 
1462
1453
  // Apply frequency and presence penalties to the cur_p
1463
1454
  for (size_t i = 0; i < cur_p->size; ++i) {
1464
- const auto token_iter = token_count.find(cur_p->data[i].id);
1465
- if (token_iter == token_count.end()) {
1455
+ const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
1456
+ if (token_iter == ctx->token_count.end()) {
1466
1457
  continue;
1467
1458
  }
1468
1459
 
1469
1460
  const int count = token_iter->second;
1470
1461
 
1462
+ assert(count > 0 && count <= ctx->penalty_last_n);
1463
+
1471
1464
  // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
1472
1465
  // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
1473
1466
  if (cur_p->data[i].logit <= 0) {
@@ -1480,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
1480
1473
  }
1481
1474
 
1482
1475
  cur_p->sorted = false;
1483
-
1484
- if (!ctx->penalize_nl && nl_found) {
1485
- // restore the logit of the newline token if it was penalized
1486
- cur_p->data[nl_idx].logit = nl_logit;
1487
- }
1488
1476
  }
1489
1477
 
1490
1478
  static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
1491
1479
  auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1492
1480
  ctx->prev.clear();
1481
+ ctx->token_count.clear();
1493
1482
  }
1494
1483
 
1495
1484
  static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
1496
1485
  const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
1497
1486
  auto * result = llama_sampler_init_penalties(
1498
- ctx->n_vocab,
1499
- ctx->special_eos_id,
1500
- ctx->linefeed_id,
1501
1487
  ctx->penalty_last_n,
1502
1488
  ctx->penalty_repeat,
1503
1489
  ctx->penalty_freq,
1504
- ctx->penalty_present,
1505
- ctx->penalize_nl,
1506
- ctx->ignore_eos);
1490
+ ctx->penalty_present);
1507
1491
 
1508
1492
  // copy the state
1509
1493
  {
@@ -1529,42 +1513,420 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
1529
1513
  };
1530
1514
 
1531
1515
  struct llama_sampler * llama_sampler_init_penalties(
1532
- int32_t n_vocab,
1533
- llama_token special_eos_id,
1534
- llama_token linefeed_id,
1535
1516
  int32_t penalty_last_n,
1536
1517
  float penalty_repeat,
1537
1518
  float penalty_freq,
1538
- float penalty_present,
1539
- bool penalize_nl,
1540
- bool ignore_eos) {
1541
- if (linefeed_id == LLAMA_TOKEN_NULL) {
1542
- penalize_nl = true;
1543
- }
1544
-
1545
- if (special_eos_id == LLAMA_TOKEN_NULL) {
1546
- ignore_eos = false;
1547
- }
1548
-
1519
+ float penalty_present) {
1549
1520
  penalty_last_n = std::max(penalty_last_n, 0);
1550
1521
 
1551
1522
  return new llama_sampler {
1552
1523
  /* .iface = */ &llama_sampler_penalties_i,
1553
1524
  /* .ctx = */ new llama_sampler_penalties {
1554
- /* .n_vocab = */ n_vocab,
1555
- /* .special_eos_id = */ special_eos_id,
1556
- /* .linefeed_id = */ linefeed_id,
1557
1525
  /* .penalty_last_n = */ penalty_last_n,
1558
1526
  /* .penalty_repeat = */ penalty_repeat,
1559
1527
  /* .penalty_freq = */ penalty_freq,
1560
1528
  /* .penalty_present = */ penalty_present,
1561
- /* .penalize_nl = */ penalize_nl,
1562
- /* .ignore_eos = */ ignore_eos,
1563
1529
  /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1530
+ /* .token_count = */ {},
1564
1531
  },
1565
1532
  };
1566
1533
  }
1567
1534
 
1535
+ // DRY
1536
+
1537
+ struct llama_sampler_dry {
1538
+ int32_t total_context_size;
1539
+
1540
+ const float dry_multiplier;
1541
+ const float dry_base;
1542
+ const int32_t dry_allowed_length;
1543
+ const int32_t dry_penalty_last_n;
1544
+
1545
+ std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
1546
+ std::vector<int> dry_repeat_count;
1547
+ std::unordered_map<llama_token, int> dry_max_token_repeat;
1548
+ ring_buffer<llama_token> last_tokens;
1549
+ };
1550
+
1551
+ // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1552
+ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
1553
+ for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
1554
+ std::string word = llama_detokenize(vocab, {token_id}, true);
1555
+ if (word.find(str) != std::string::npos) {
1556
+ token_sequences.emplace(token_id, std::vector<llama_token>());
1557
+ } else {
1558
+ size_t word_len = word.size();
1559
+ size_t str_len = str.size();
1560
+ size_t pos = -1;
1561
+ while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
1562
+ bool match = true;
1563
+ size_t i;
1564
+ for (i = 1; i < str_len && i + pos < word_len; ++i) {
1565
+ if (word[pos + i] != str[i]) {
1566
+ match = false;
1567
+ break;
1568
+ }
1569
+ }
1570
+ if (match) {
1571
+ std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
1572
+ if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
1573
+ tokenization.resize(max_tail_len);
1574
+ }
1575
+
1576
+ // Ensure we don't already have a duplicate matching tokenization
1577
+ auto its = token_sequences.equal_range(token_id);
1578
+ bool found = false;
1579
+ for (auto it = its.first; it != its.second; ++it) {
1580
+ if (tokenization == it->second) {
1581
+ found = true;
1582
+ break;
1583
+ }
1584
+ }
1585
+ if (!found) {
1586
+ token_sequences.emplace(token_id, tokenization);
1587
+ }
1588
+ }
1589
+ }
1590
+ }
1591
+ }
1592
+ }
1593
+
1594
// Name callback of the DRY sampler; the sampler argument is not used.
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
    static const char * const k_name = "dry";
    return k_name;
}
1597
+
1598
+ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
1599
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1600
+ if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1601
+ return;
1602
+ }
1603
+
1604
+ ctx->last_tokens.push_back(token);
1605
+ }
1606
+
1607
+ // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1608
+ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1609
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1610
+
1611
+ if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1612
+ return;
1613
+ }
1614
+
1615
+ int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
1616
+ int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
1617
+
1618
+ if (last_n_repeat <= ctx->dry_allowed_length) {
1619
+ return;
1620
+ }
1621
+
1622
+ ctx->dry_repeat_count.assign(last_n_repeat, 0);
1623
+ ctx->dry_max_token_repeat.clear();
1624
+
1625
+ // Step 1: Look for restart sequences to limit the maximum repetition length.
1626
+ // Work backwards through the context looking for any token that begins a restart sequence.
1627
+ //
1628
+ // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
1629
+ // sequences that together comprise a restart sequence. This allows us to quickly check
1630
+ // whether each token is the head of a complete sequence. Most restart sequences are actually
1631
+ // a single token, and for these the "tail" is an empty vector.
1632
+ //
1633
+ // If the token is a "head", test all restart sequences that begin with this token
1634
+ // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
1635
+ // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
1636
+ // longest matching sequence (if any) is used to limit the maximum repetition length.
1637
+ //
1638
+ // Note that in the case case of a short sequence contained in a longer one, this might fail to
1639
+ // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
1640
+ // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
1641
+ // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
1642
+ //
1643
+ // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
1644
+ // have already clamped the maximum tail sequence length when generating `restart_sequences`.
1645
+ // With clamping, this scan is O(N) in the context length.
1646
+
1647
+ int rep_limit = last_n_repeat;
1648
+ for (int i = 0; i < last_n_repeat; ++i) {
1649
+ llama_token token = ctx->last_tokens.rat(i);
1650
+ auto its = ctx->dry_processed_breakers.equal_range(token);
1651
+ if (its.first == ctx->dry_processed_breakers.end()) {
1652
+ continue;
1653
+ }
1654
+ int longest_match = -1;
1655
+ for (auto it = its.first; it != its.second; ++it) {
1656
+ // Note that (*it) does not contain the head character, so seq_len will be
1657
+ // the restart sequence length minus 1.
1658
+ // In the common case of a single-token restart sequence, (*it) will be empty
1659
+ // and we will trivially match.
1660
+ int seq_len = (int)it->second.size();
1661
+ if (seq_len > longest_match && seq_len <= (int)i) {
1662
+ bool match = true;
1663
+ for (int offset = 0; offset < seq_len; ++offset) {
1664
+ // The -1 when indexing `last_tokens` is because we already matched the head.
1665
+ if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
1666
+ match = false;
1667
+ break;
1668
+ }
1669
+ }
1670
+ if (match) {
1671
+ longest_match = seq_len;
1672
+ }
1673
+ }
1674
+ }
1675
+ if (longest_match >= 0) {
1676
+ // We found a restart sequence starting `i` tokens from the end and continuing for
1677
+ // `longest_match` tokens.
1678
+ rep_limit = i - longest_match;
1679
+ break;
1680
+ }
1681
+ }
1682
+ if (rep_limit < ctx->dry_allowed_length) {
1683
+ return;
1684
+ }
1685
+
1686
+ // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
1687
+ // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
1688
+ // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
1689
+ //
1690
+ // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
1691
+ // https://ivanyu.me/blog/2014/10/15/z-algorithm/
1692
+ //
1693
+ // The code below is adapted from the public domain implementation by the same author here:
1694
+ // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
1695
+ //
1696
+ // Example:
1697
+ // Last N tokens: a b c c b c y a b c
1698
+ // Repeat counts: 0 0 3 1 0 2 0 0 0 0
1699
+ // ^
1700
+ // This `3` means that the last three tokens of the context (a b c) also appear here.
1701
+ //
1702
+ // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
1703
+ // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
1704
+ // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
1705
+ // ensure that the inner while loops only examine each token in the context once as the outer
1706
+ // for loop iterates over the context.
1707
+
1708
+ {
1709
+ const int last = last_n_repeat - 1;
1710
+ int rt = 0, lt = 0;
1711
+
1712
+ for (int k = 1; k < last_n_repeat; ++k) {
1713
+ if (k > rt) {
1714
+ // If k is outside the current Z-box, do naive computation.
1715
+ int n = 0;
1716
+ while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
1717
+ ++n;
1718
+ }
1719
+ ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
1720
+ if (n > 0) {
1721
+ lt = k;
1722
+ rt = k+n-1;
1723
+ }
1724
+ } else {
1725
+ // If k is inside the current Z-box, consider two cases.
1726
+
1727
+ int p = k - lt; // Pair index.
1728
+ int right_part_len = rt - k + 1;
1729
+
1730
+ if (ctx->dry_repeat_count[last - p] < right_part_len) {
1731
+ int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
1732
+ ctx->dry_repeat_count[last - k] = n;
1733
+ } else {
1734
+ int i = rt + 1;
1735
+ while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
1736
+ i += 1;
1737
+ }
1738
+
1739
+ int n = std::min(i - k, rep_limit);
1740
+ ctx->dry_repeat_count[last - k] = n;
1741
+ lt = k;
1742
+ rt = i - 1;
1743
+ }
1744
+ }
1745
+ }
1746
+ }
1747
+
1748
+ // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
1749
+ // that would be generated by emitting each new token that would extend a sequence.
1750
+ //
1751
+ // Following the same example as above:
1752
+ // Last N tokens: a b c c b c y a b c
1753
+ // Repeat counts: 0 0 3 1 0 2 0 0 0 0
1754
+ //
1755
+ // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
1756
+ // c: 3 -> 4 (from `a b c` to `a b c c`)
1757
+ // b: 1 -> 2 (from `c` to `c b`)
1758
+ // y: 2 -> 3 (from `b c` to `b c y`)
1759
+
1760
+ for (int i = 0; i < last_n_repeat - 1; ++i) {
1761
+ int repeat_len = ctx->dry_repeat_count[i];
1762
+ if (repeat_len >= ctx->dry_allowed_length) {
1763
+ // This token ends a repeat, so the next token would continue one.
1764
+ // By convention, the value of `repeat_len` only includes the tokens currently
1765
+ // in the context, not the new token that would be added.
1766
+ llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
1767
+ // Track the maximum sequence ending in this token.
1768
+ const auto& it = ctx->dry_max_token_repeat.find(token);
1769
+ if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
1770
+ ctx->dry_max_token_repeat[token] = repeat_len;
1771
+ }
1772
+ }
1773
+ }
1774
+
1775
+ // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
1776
+
1777
+ // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
1778
+ // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
1779
+ const float FLOAT_MAX_LOG = 88.7228391f;
1780
+ int max_exponent = 0;
1781
+ if (ctx->dry_base > 1.000001f) {
1782
+ max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
1783
+ }
1784
+
1785
+ for (size_t i = 0; i < cur_p->size; ++i) {
1786
+ const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
1787
+ if (af_kvp != ctx->dry_max_token_repeat.end()) {
1788
+ // Check all sequence breakers starting with this token
1789
+ auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
1790
+ bool is_single_token_breaker = false;
1791
+
1792
+ for (auto it = range.first; it != range.second; ++it) {
1793
+ if (it->second.empty()) {
1794
+ is_single_token_breaker = true;
1795
+ break;
1796
+ }
1797
+ }
1798
+
1799
+ // Apply penalty only if it's not a single-token sequence breaker
1800
+ if (!is_single_token_breaker) {
1801
+ int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
1802
+ if (max_exponent > 0 && repeat_exp > max_exponent) {
1803
+ repeat_exp = max_exponent;
1804
+ }
1805
+ float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
1806
+ cur_p->data[i].logit -= penalty;
1807
+ }
1808
+ }
1809
+ }
1810
+
1811
+ cur_p->sorted = false;
1812
+ }
1813
+
1814
+ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
1815
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1816
+ ctx->last_tokens.clear();
1817
+ ctx->dry_repeat_count.clear();
1818
+ ctx->dry_max_token_repeat.clear();
1819
+ }
1820
+
1821
+ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
1822
+ const auto * ctx = (llama_sampler_dry *) smpl->ctx;
1823
+
1824
+ llama_vocab dummy_vocab;
1825
+
1826
+ // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
1827
+ auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
1828
+
1829
+ // Copy the state, including the processed breakers
1830
+ {
1831
+ auto * result_ctx = (llama_sampler_dry *) result->ctx;
1832
+ result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
1833
+ result_ctx->dry_repeat_count = ctx->dry_repeat_count;
1834
+ result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
1835
+ result_ctx->last_tokens = ctx->last_tokens;
1836
+ }
1837
+
1838
+ return result;
1839
+ }
1840
+
1841
+ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
1842
+ delete (llama_sampler_dry *) smpl->ctx;
1843
+ }
1844
+
1845
+ static struct llama_sampler_i llama_sampler_dry_i = {
1846
+ /* .name = */ llama_sampler_dry_name,
1847
+ /* .accept = */ llama_sampler_dry_accept,
1848
+ /* .apply = */ llama_sampler_dry_apply,
1849
+ /* .reset = */ llama_sampler_dry_reset,
1850
+ /* .clone = */ llama_sampler_dry_clone,
1851
+ /* .free = */ llama_sampler_dry_free,
1852
+ };
1853
+
1854
+ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
1855
+ int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
1856
+ std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
1857
+ const int MAX_CHAR_LEN = 40;
1858
+ const int MAX_SEQ_LEN = 20;
1859
+
1860
+ const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
1861
+
1862
+ if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
1863
+ // Process sequence breakers
1864
+ for (size_t i = 0; i < num_breakers; ++i) {
1865
+ if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
1866
+ LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
1867
+ continue;
1868
+ }
1869
+
1870
+ std::string sequence_break(seq_breakers[i]);
1871
+ if (sequence_break.empty()) {
1872
+ LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
1873
+ continue;
1874
+ }
1875
+
1876
+ if (sequence_break.size() > MAX_CHAR_LEN) {
1877
+ LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
1878
+ sequence_break.resize(MAX_CHAR_LEN);
1879
+ }
1880
+
1881
+ get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
1882
+ }
1883
+ }
1884
+
1885
+ return new llama_sampler {
1886
+ /* .iface = */ &llama_sampler_dry_i,
1887
+ /* .ctx = */ new llama_sampler_dry {
1888
+ /* .total_context_size = */ context_size,
1889
+ /* .dry_multiplier = */ dry_multiplier,
1890
+ /* .dry_base = */ dry_base,
1891
+ /* .dry_allowed_length = */ dry_allowed_length,
1892
+ /* .dry_penalty_last_n = */ dry_penalty_last_n,
1893
+ /* .dry_processed_breakers = */ std::move(processed_breakers),
1894
+ /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
1895
+ /* .dry_max_token_repeat = */ {},
1896
+ /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
1897
+ },
1898
+ };
1899
+ }
1900
+
1901
+ // wrapper for test-sampling.cpp
1902
+ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
1903
+ llama_vocab dummy_vocab;
1904
+ auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
1905
+ auto * ctx = (llama_sampler_dry *) result->ctx;
1906
+
1907
+ // Process the token-based sequence breakers
1908
+ ctx->dry_processed_breakers.clear();
1909
+ if (seq_breakers.empty()) {
1910
+ LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
1911
+ } else {
1912
+ for (const auto& breaker : seq_breakers) {
1913
+ if (breaker.empty()) {
1914
+ LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
1915
+ continue;
1916
+ }
1917
+ llama_token head_token = breaker[0];
1918
+ std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
1919
+ ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
1920
+ }
1921
+
1922
+ if (ctx->dry_processed_breakers.empty()) {
1923
+ LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
1924
+ }
1925
+ }
1926
+
1927
+ return result;
1928
+ }
1929
+
1568
1930
  // logit-bias
1569
1931
 
1570
1932
  struct llama_sampler_logit_bias {
@@ -1644,6 +2006,229 @@ struct llama_sampler * llama_sampler_init_logit_bias(
1644
2006
  };
1645
2007
  }
1646
2008
 
2009
+ // infill
2010
+
2011
+ //#define GGML_DEBUG_SAMPLER_INFILL
2012
+
2013
// State of the infill sampler.
struct llama_sampler_infill {
    // vocabulary used for the end-of-generation checks and for converting tokens to text
    const struct llama_vocab * vocab;

    // scratch buffers holding the text of the two tokens being compared
    std::vector<char> buf0;
    std::vector<char> buf1;
};
2019
+
2020
// Name callback of the infill sampler; the sampler argument is not used.
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    static const char * const k_name = "infill";
    return k_name;
}
2023
+
2024
+ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2025
+ auto * ctx = (llama_sampler_infill *) smpl->ctx;
2026
+
2027
+ llama_sampler_softmax_impl(cur_p);
2028
+
2029
+ #if defined(GGML_DEBUG_SAMPLER_INFILL)
2030
+ #define LOG_DBG_CUR LLAMA_LOG_DEBUG
2031
+ #else
2032
+ #define LOG_DBG_CUR(...)
2033
+ #endif
2034
+
2035
+ for (size_t i = 0; i < cur_p->size; ++i) {
2036
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2037
+ }
2038
+
2039
+ float p_txt_sum = 0.0f;
2040
+ float p_eog_sum = 0.0f;
2041
+
2042
+ for (size_t i = 0; i < cur_p->size; ++i) {
2043
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
2044
+ p_eog_sum += cur_p->data[i].p;
2045
+ } else {
2046
+ p_txt_sum += cur_p->data[i].p;
2047
+ }
2048
+ }
2049
+
2050
+ const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
2051
+
2052
+ LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
2053
+
2054
+ if (3*p_eog_sum*cur_p->size > p_txt_sum) {
2055
+ LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
2056
+
2057
+ // keep just the EOG tokens
2058
+ const auto size_org = cur_p->size;
2059
+
2060
+ cur_p->size = 0;
2061
+
2062
+ float p_sum = 0.0f;
2063
+
2064
+ for (size_t i = 0; i < size_org; ++i) {
2065
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
2066
+ p_sum += cur_p->data[i].p;
2067
+
2068
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2069
+ }
2070
+ }
2071
+
2072
+ // normalize probs
2073
+ for (size_t i = 0; i < cur_p->size; ++i) {
2074
+ cur_p->data[i].p /= p_sum;
2075
+ }
2076
+
2077
+ return;
2078
+ }
2079
+
2080
+ size_t n_combined = 0; GGML_UNUSED(n_combined);
2081
+
2082
+ // combine tokens with common prefix
2083
+ for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
2084
+ for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
2085
+ if (cur_p->data[i0].logit == -INFINITY) {
2086
+ break;
2087
+ }
2088
+
2089
+ if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
2090
+ continue;
2091
+ }
2092
+
2093
+ int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2094
+ if (len0 < 0) {
2095
+ ctx->buf0.resize(len0);
2096
+ len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2097
+ assert(len0 > 0);
2098
+ }
2099
+
2100
+ int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2101
+ if (len1 < 0) {
2102
+ ctx->buf1.resize(len1);
2103
+ len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2104
+ assert(len1 > 0);
2105
+ }
2106
+
2107
+ // token i0 is a prefix of token i1
2108
+ if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
2109
+ int dst = i0;
2110
+ int src = i1;
2111
+
2112
+ // merge into the token with higher probability
2113
+ if (cur_p->data[i1].p > cur_p->data[i0].p) {
2114
+ std::swap(dst, src);
2115
+ }
2116
+
2117
+ cur_p->data[dst].p += cur_p->data[src].p;
2118
+ cur_p->data[src].logit = -INFINITY;
2119
+ cur_p->data[src].p = 0.0f;
2120
+
2121
+ n_combined++;
2122
+ }
2123
+ }
2124
+ }
2125
+
2126
+ size_t n_non_eog = 0;
2127
+
2128
+ size_t size_org = cur_p->size;
2129
+
2130
+ float p_sum = 0.0f;
2131
+ float thold = 0.2f;
2132
+
2133
+ cur_p->size = 0;
2134
+
2135
+ LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
2136
+
2137
+ for (size_t i = 0; i < size_org; ++i) {
2138
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
2139
+
2140
+ if (cur_p->data[i].p < thold && !is_eog) {
2141
+ continue;
2142
+ }
2143
+
2144
+ if (!is_eog) {
2145
+ ++n_non_eog;
2146
+ }
2147
+
2148
+ p_sum += cur_p->data[i].p;
2149
+
2150
+ // keep this token
2151
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2152
+ }
2153
+
2154
+ LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
2155
+
2156
+ // if no non-EOG tokens are left -> reduce cur_p to single EOT token
2157
+ if (n_non_eog == 0) {
2158
+ cur_p->size = 1;
2159
+ cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
2160
+ cur_p->data[0].logit = 1.0f;
2161
+
2162
+ return;
2163
+ }
2164
+
2165
+ // normalize probs
2166
+ for (size_t i = 0; i < cur_p->size; ++i) {
2167
+ cur_p->data[i].p /= p_sum;
2168
+
2169
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2170
+ }
2171
+
2172
+ size_org = cur_p->size;
2173
+ p_sum = 0.0f;
2174
+ thold = 1.0/(n_non_eog + 1);
2175
+
2176
+ cur_p->size = 0;
2177
+
2178
+ LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
2179
+
2180
+ for (size_t i = 0; i < size_org; ++i) {
2181
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
2182
+
2183
+ if (cur_p->data[i].p < thold && !is_eog) {
2184
+ continue;
2185
+ }
2186
+
2187
+ p_sum += cur_p->data[i].p;
2188
+
2189
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2190
+ }
2191
+
2192
+ // normalize probs
2193
+ for (size_t i = 0; i < cur_p->size; ++i) {
2194
+ cur_p->data[i].p /= p_sum;
2195
+
2196
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2197
+ }
2198
+
2199
+ #undef LOG_DBG_CUR
2200
+ }
2201
+
2202
+ static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
2203
+ const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
2204
+ return llama_sampler_init_infill_impl(*ctx->vocab);
2205
+ }
2206
+
2207
+ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2208
+ delete (llama_sampler_infill *) smpl->ctx;
2209
+ }
2210
+
2211
+ static struct llama_sampler_i llama_sampler_infill_i = {
2212
+ /* .name = */ llama_sampler_infill_name,
2213
+ /* .accept = */ nullptr,
2214
+ /* .apply = */ llama_sampler_infill_apply,
2215
+ /* .reset = */ nullptr,
2216
+ /* .clone = */ llama_sampler_infill_clone,
2217
+ /* .free = */ llama_sampler_infill_free,
2218
+ };
2219
+
2220
+ struct llama_sampler * llama_sampler_init_infill_impl(
2221
+ const struct llama_vocab & vocab) {
2222
+ return new llama_sampler {
2223
+ /* .iface = */ &llama_sampler_infill_i,
2224
+ /* .ctx = */ new llama_sampler_infill {
2225
+ /* .vocab = */ &vocab,
2226
+ /* .buf0 = */ std::vector<char>(512),
2227
+ /* .buf1 = */ std::vector<char>(512),
2228
+ },
2229
+ };
2230
+ }
2231
+
1647
2232
  // utils
1648
2233
 
1649
2234
  uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {