@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/tests/test-sampling.cpp

@@ -10,181 +10,208 @@
 #include <string>
 #include <vector>
 
-static void dump(const llama_token_data_array * candidates) {
-    for (size_t i = 0; i < candidates->size; i++) {
-        printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
+extern struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers);
+
+static void dump(const llama_token_data_array * cur_p) {
+    for (size_t i = 0; i < cur_p->size; i++) {
+        printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
 }
 
-#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
+#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
 
-static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
-    const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+struct sampler_tester {
+    sampler_tester(size_t n_vocab) {
+        cur.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+            const float logit = logf(token_id);
+            cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        }
+
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
-    llama_sample_top_k(nullptr, &candidates_p, k, 1);
-    DUMP(&candidates_p);
+    sampler_tester(const std::vector<float> & probs, const std::vector<float> & probs_expected) : probs_expected(probs_expected) {
+        cur.reserve(probs.size());
+        for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
+            const float logit = logf(probs[token_id]);
+            cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]});
+        }
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
     }
-}
 
-static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    void apply(llama_sampler * sampler) {
+        llama_sampler_apply(sampler, &cur_p);
+        llama_sampler_free(sampler);
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
-    llama_sample_top_p(nullptr, &candidates_p, p, 1);
-    DUMP(&candidates_p);
-
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    void check() {
+        GGML_ASSERT(cur_p.size == probs_expected.size());
+        for (size_t i = 0; i < cur_p.size; i++) {
+            GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
+        }
     }
+
+    llama_token_data_array cur_p;
+
+  private:
+    const std::vector<float> probs_expected;
+
+    std::vector<llama_token_data> cur;
+};
+
+static void test_temp(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp(temp));
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
-static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
-    const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
+    sampler_tester tester(probs, probs_expected);
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    DUMP(&candidates_p);
-    llama_sample_tail_free(nullptr, &candidates_p, z, 1);
-    DUMP(&candidates_p);
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
+    sampler_tester tester(probs, probs_expected);
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    DUMP(&candidates_p);
-    llama_sample_min_p(nullptr, &candidates_p, p, 1);
-    DUMP(&candidates_p);
-    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_k(k));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+static void test_top_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    DUMP(&candidates_p);
-    llama_sample_typical(nullptr, &candidates_p, p, 1);
-    DUMP(&candidates_p);
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_p(p, 1));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
+}
+
+static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_min_p(p, 1));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_xtc(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p, float t) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
-static void test_repetition_penalties(
+static void test_typical(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_typical(p, 1));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_penalties(
     const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
-    const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
+    const std::vector<float> & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
 ) {
-    GGML_ASSERT(probs.size() == expected_probs.size());
+    GGML_ASSERT(probs.size() == probs_expected.size());
+
+    sampler_tester tester(probs, probs_expected);
 
     const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        llama_sampler_accept(sampler, last_tokens[i]);
    }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
-    llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void test_sampler_queue(
-    const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
+static void test_dry(
+    const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
+    const std::vector<float> & expected_probs, float dry_multiplier, float dry_base,
+    int dry_allowed_length, int dry_penalty_last_n,
+    const std::vector<std::vector<llama_token>> & seq_breakers
 ) {
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(token_id);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    GGML_ASSERT(probs.size() == expected_probs.size());
+
+    sampler_tester tester(probs, expected_probs);
+
+    auto * sampler = llama_sampler_init_dry_testing(1024, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        llama_sampler_accept(sampler, last_tokens[i]);
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+    tester.check();
+}
+
+static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
+) {
+    sampler_tester tester(n_vocab);
 
     llama_token min_token_id = 0;
     const llama_token max_token_id = n_vocab-1;
 
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
-            case 'f': GGML_ABORT("tail_free test not implemented"); break;
-            case 'y': GGML_ABORT("typical test not implemented"); break;
-            case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
-            case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
-            case 't': GGML_ABORT("temperature test not implemented"); break;
-            default : GGML_ABORT("Unknown sampler"); break;
+            case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
+            case 'y': GGML_ABORT("typical test not implemented");
+            case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
+            case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
+            case 't': GGML_ABORT("temperature test not implemented");
+            default : GGML_ABORT("Unknown sampler");
         }
 
-        llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
+        tester.apply(llama_sampler_init_dist(0));
+
+        auto & cur_p = tester.cur_p;
 
-        const int size = candidates_p.size;
+        const int size = cur_p.size;
 
         if (s == 'k') {
             const int expected_size = std::min(size, top_k);
             min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k));
 
             GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
-            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else if (s == 'p') {
             const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2;
             const int softmax_numerator_target = ceilf(top_p * softmax_divisor);
@@ -206,8 +233,8 @@ static void test_sampler_queue(
             }
 
             GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
-            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else if (s == 'm') {
             int expected_size = ceilf((1.0f-min_p) * n_vocab);
             expected_size = std::max(expected_size, 1);
@@ -219,29 +246,73 @@ static void test_sampler_queue(
             min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
 
             GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
-            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else {
             GGML_ABORT("fatal error");
         }
     }
 
-    printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
+    printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n",
           samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
+static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
+    std::vector<llama_token_data> cur(data.size());
+    std::copy(data.begin(), data.end(), cur.begin());
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_reset(cnstr);
+    const int64_t t_start = ggml_time_us();
+    for (int i = 0; i < n_iter; i++) {
+        std::copy(data.begin(), data.end(), cur.begin());
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_reset(cnstr);
+    }
+    const int64_t t_end = ggml_time_us();
+    llama_sampler_free(cnstr);
+    printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+}
+
+#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
+
+static void test_perf() {
+    const int n_vocab = 1 << 17;
+
+    std::vector<llama_token_data> data;
+
+    data.reserve(n_vocab);
+    for (int i = 0; i < n_vocab; i++) {
+        const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);
+        data.emplace_back(llama_token_data{i, logit, 0.0f});
+    }
+
+    BENCH(llama_sampler_init_top_k (40), data, 32);
+    BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
+    BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
+    BENCH(llama_sampler_init_typical(0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
+}
+
 int main(void) {
     ggml_time_init();
 
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
+
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
+
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
 
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
 
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
@@ -252,20 +323,31 @@ int main(void) {
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
 
-    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
-    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
-    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
+    printf("XTC should:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.09f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.19f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.29f);
+
+    printf("XTC should not:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
 
     test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
     test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
 
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
+
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {});
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
 
     test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
     test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
@@ -297,5 +379,7 @@ int main(void) {
 
     printf("OK\n");
 
+    test_perf();
+
     return 0;
 }
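
Note on the test-sampling.cpp rewrite above: it tracks llama.cpp's replacement of the old llama_sample_* free functions (which mutated a candidates array in place and took a context pointer) with self-contained llama_sampler objects that are created, applied, and freed. A minimal standalone sketch of that pattern, assuming only the APIs visible in this diff (llama_sampler_init_top_k, llama_sampler_apply, llama_sampler_free, and the four-field llama_token_data_array initializer { data, size, selected, sorted }):

// Sketch: build a candidate array by hand and apply a top-k sampler to it.
#include <cmath>
#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    const std::vector<float> probs = {0.1f, 0.2f, 0.3f, 0.4f};

    std::vector<llama_token_data> cur;
    cur.reserve(probs.size());
    for (llama_token id = 0; id < (llama_token) probs.size(); id++) {
        cur.emplace_back(llama_token_data{id, logf(probs[id]), 0.0f});
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), /*selected*/ -1, /*sorted*/ false };

    llama_sampler * smpl = llama_sampler_init_top_k(3); // keep the 3 highest-logit tokens
    llama_sampler_apply(smpl, &cur_p);
    llama_sampler_free(smpl);                           // samplers are standalone; free directly

    for (size_t i = 0; i < cur_p.size; i++) {
        printf("%d: logit=%f\n", cur_p.data[i].id, cur_p.data[i].logit);
    }
    return 0;
}

Chaining works the same way: apply several samplers to the same cur_p in sequence, which is exactly what sampler_tester::apply does in the rewritten test.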
package/src/llama.cpp/tests/test-tokenizer-0.cpp

@@ -7,6 +7,7 @@
 #include <map>
 #include <vector>
 #include <fstream>
+#include <thread>
 
 //static const std::map<std::string, std::vector<llama_token>> & k_tests() {
 //    static std::map<std::string, std::vector<llama_token>> _k_tests = {
@@ -194,45 +195,64 @@ int main(int argc, char **argv) {
 
     const bool add_special = false;
 
-    for (const auto & test_kv : k_tests) {
-        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
-
-        printf("\n");
-        printf("src: '%s'\n", test_kv.first.c_str());
-        printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
-        printf("tok: ");
-        for (const auto & tok : res) {
-            printf("%d ", tok);
-        }
-        printf("\n");
-
-        bool correct = res.size() == test_kv.second.size();
-        for (int i = 0; i < (int) res.size() && correct; ++i) {
-            if (test_kv.second[i] != res[i]) {
-                correct = false;
+    // multi-threaded tokenization
+    const int nthread = std::thread::hardware_concurrency();
+    std::vector<std::thread> threads(nthread);
+
+    for (int i = 0; i < nthread; i++) {
+        threads[i] = std::thread([&, i]() {
+            for (const auto & test_kv : k_tests) {
+                const std::vector<llama_token> res = common_tokenize(ctx, test_kv.first, add_special, false);
+
+                // here only print the result of the first thread
+                // because the other threads are running the same tests
+                if (i != 0) {
+                    continue;
+                }
+
+                printf("\n");
+                printf("src: '%s'\n", test_kv.first.c_str());
+                printf("res: '%s'\n", common_detokenize(ctx, res).c_str());
+                printf("tok: ");
+                for (const auto & tok : res) {
+                    printf("%d ", tok);
+                }
+                printf("\n");
+
+                bool correct = res.size() == test_kv.second.size();
+                for (int i = 0; i < (int) res.size() && correct; ++i) {
+                    if (test_kv.second[i] != res[i]) {
+                        correct = false;
+                    }
+                }
+
+                if (!correct) {
+                    fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
+                    fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                        common_detokenize(ctx, res).c_str(),
+                        common_detokenize(ctx, test_kv.second).c_str());
+                    fprintf(stderr, "%s : expected tokens: ", __func__);
+                    for (const auto & t : test_kv.second) {
+                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+                    fprintf(stderr, "%s : got tokens: ", __func__);
+                    for (const auto & t : res) {
+                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+
+                    success = false;
+                }
             }
-        }
-
-        if (!correct) {
-            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
-            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
-                llama_detokenize(ctx, res).c_str(),
-                llama_detokenize(ctx, test_kv.second).c_str());
-            fprintf(stderr, "%s : expected tokens: ", __func__);
-            for (const auto & t : test_kv.second) {
-                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
-            }
-            fprintf(stderr, "\n");
-            fprintf(stderr, "%s : got tokens: ", __func__);
-            for (const auto & t : res) {
-                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
-            }
-            fprintf(stderr, "\n");
+        });
+    }
 
-            success = false;
-        }
+    for (int i = 0; i < nthread; i++) {
+        threads[i].join();
     }
 
+    // single threaded tokenization
     if (!fname_text.empty()) {
         fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
 
@@ -253,7 +273,7 @@ int main(int argc, char **argv) {
         {
             const auto t_start = ggml_time_us();
 
-            res = llama_tokenize(ctx, text, add_special, false);
+            res = common_tokenize(ctx, text, add_special, false);
 
             const auto t_end = ggml_time_us();
 
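
The test-tokenizer-0.cpp change above does two things: it switches to the renamed common helpers (llama_tokenize → common_tokenize, llama_detokenize → common_detokenize) and fans the same test loop out across one thread per hardware core, with only thread 0 printing. A stripped-down sketch of just that fan-out/join pattern (standalone C++, no llama.cpp types):

#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int nthread = std::thread::hardware_concurrency();
    std::vector<std::thread> threads(nthread);

    for (int i = 0; i < nthread; i++) {
        // capture the loop index by value so each thread keeps its own id
        threads[i] = std::thread([&, i]() {
            // ... run the same (thread-safe) work on every thread ...
            if (i == 0) {
                printf("thread 0 of %d reporting\n", nthread); // only one thread prints
            }
        });
    }
    for (int i = 0; i < nthread; i++) {
        threads[i].join();
    }
    return 0;
}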
package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp

@@ -78,10 +78,10 @@ int main(int argc, char **argv) {
     const int n_vocab = llama_n_vocab(model);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string str = llama_detokenize(ctx, std::vector<int>(1, i));
+        std::string str = common_detokenize(ctx, std::vector<int>(1, i));
         try {
             auto cps = unicode_cpts_from_utf8(str);
-            std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
+            std::vector<llama_token> tokens = common_tokenize(ctx, str, false, true);
             if (ignore_merges && tokens.size() > 1) {
                 fprintf(stderr,
                     "%s : error: token %d detokenizes to '%s'(%zu) but "
@@ -94,7 +94,7 @@ int main(int argc, char **argv) {
                 fprintf(stderr, "]\n");
                 return 2;
             }
-            std::string check = llama_detokenize(ctx, tokens);
+            std::string check = common_detokenize(ctx, tokens);
             if (check != str) {
                 fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                     __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -123,8 +123,8 @@ int main(int argc, char **argv) {
         }
 
         std::string str = unicode_cpt_to_utf8(cp);
-        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-        std::string check = llama_detokenize(ctx, tokens);
+        std::vector<llama_token> tokens = common_tokenize(ctx, str, false);
+        std::string check = common_detokenize(ctx, tokens);
         if (cp != 9601 && str != check) {
             fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                 cp, check.c_str(), check.length(), str.c_str(), str.length());
package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp

@@ -66,9 +66,9 @@ int main(int argc, char ** argv) {
     const int n_vocab = llama_n_vocab(model);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string str = llama_detokenize(ctx, std::vector<int>(1, i), true);
-        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
-        std::string check = llama_detokenize(ctx, tokens);
+        std::string str = common_detokenize(ctx, std::vector<int>(1, i), true);
+        std::vector<llama_token> tokens = common_tokenize(ctx, str, false, true);
+        std::string check = common_detokenize(ctx, tokens);
         if (check != str) {
             fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                 __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -93,8 +93,8 @@ int main(int argc, char ** argv) {
         }
 
         std::string str = unicode_cpt_to_utf8(cp);
-        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
-        std::string check = llama_detokenize(ctx, tokens);
+        std::vector<llama_token> tokens = common_tokenize(ctx, str, false, true);
+        std::string check = common_detokenize(ctx, tokens);
         if (cp != 9601 && str != check) {
             fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                 cp, check.c_str(), check.length(), str.c_str(), str.length());
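
Across test-tokenizer-0.cpp, test-tokenizer-1-bpe.cpp, and test-tokenizer-1-spm.cpp the substantive change is the same rename of the common helpers from the llama_ prefix to common_. The invariant these calls verify is a detokenize/tokenize round trip; a hedged sketch of that check, assuming the common_tokenize/common_detokenize signatures visible in the hunks above (an initialized llama_context * is taken as given, setup omitted):

#include <string>
#include <vector>
#include "common.h" // common_tokenize / common_detokenize from this llama.cpp revision

// Round-trip check in the spirit of the tests above: detokenize a single
// token id, re-tokenize the resulting string, and detokenize again; both
// strings should match for a well-behaved vocab entry.
static bool roundtrip_ok(llama_context * ctx, llama_token id) {
    const std::string str = common_detokenize(ctx, std::vector<llama_token>(1, id));
    const std::vector<llama_token> tokens = common_tokenize(ctx, str, /*add_special*/ false, /*parse_special*/ true);
    const std::string check = common_detokenize(ctx, tokens);
    return check == str;
}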