@fugood/llama.node 0.3.1 → 0.3.3

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -16,20 +16,6 @@
 // helpers
 //
 
-static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    std::string result;
-    for (size_t pos = 0; ; pos += search.length()) {
-        auto new_pos = s.find(search, pos);
-        if (new_pos == std::string::npos) {
-            result += s.substr(pos, s.size() - pos);
-            break;
-        }
-        result += s.substr(pos, new_pos - pos) + replace;
-        pos = new_pos;
-    }
-    s = std::move(result);
-}
-
 LLAMA_ATTRIBUTE_FORMAT(1, 2)
 static std::string format(const char * fmt, ...) {
     va_list ap;
@@ -64,7 +50,7 @@ struct naive_trie {
             res.first->second.insert(key + 1, len - 1, value);
         }
     }
-    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
+    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
        if (len == 0 || offset == len) {
            return std::make_pair(key, offset);
        }
@@ -72,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
         }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
@@ -93,6 +79,15 @@ struct naive_trie {
 // impl
 //
 
+struct llm_tokenizer {
+    llm_tokenizer() {}
+    virtual ~llm_tokenizer() = default;
+};
+
+llama_vocab::~llama_vocab() {
+    delete tokenizer;
+}
+
 int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
     GGML_ASSERT(token_left.find(' ') == std::string::npos);
     GGML_ASSERT(token_left.find('\n') == std::string::npos);
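The two structs added above are the heart of this refactor: `llm_tokenizer` is a polymorphic base for immutable, precomputed tokenizer state, built once per vocab and freed in `~llama_vocab()`, while the mutable scratch state moves into per-call `*_session` objects introduced in the hunks below. A minimal editorial sketch of the pattern, using hypothetical names (`tokenizer_model`, `tokenizer_session`) rather than the diff's actual types:

    #include <iostream>
    #include <string>
    #include <vector>

    // built once; holds only immutable tables, so it can be shared freely
    struct tokenizer_model {
        std::vector<std::string> merges;
    };

    // built per tokenize() call; owns every piece of mutable scratch state
    struct tokenizer_session {
        explicit tokenizer_session(const tokenizer_model & model) : model(model) {}

        void tokenize(const std::string & text, std::vector<int> & out) {
            symbols.assign(text.begin(), text.end()); // per-call scratch only
            out.push_back((int) symbols.size());      // read-only use of `model`
            out.push_back((int) model.merges.size());
        }

    private:
        const tokenizer_model & model;
        std::string symbols;
    };

    int main() {
        const tokenizer_model model{{"ab", "cd"}};
        std::vector<int> out;
        tokenizer_session(model).tokenize("hello", out); // cheap to construct per call
        std::cout << out[0] << ' ' << out[1] << '\n';    // prints: 5 2
    }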
@@ -201,10 +196,15 @@ struct llm_bigram_spm {
     size_t size;
 };
 
-struct llm_tokenizer_spm {
-    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
+struct llm_tokenizer_spm : llm_tokenizer {
+    llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
+
+struct llm_tokenizer_spm_session {
+    llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
@@ -221,7 +221,7 @@ struct llm_tokenizer_spm {
         }
 
         // seed the work queue with all possible 2-character tokens.
-        for (size_t i = 1; i < symbols.size(); ++i) {
+        for (int i = 1; i < (int) symbols.size(); ++i) {
             try_add_bigram(i - 1, i);
         }
 
@@ -285,7 +285,7 @@ private:
             return;
         }
 
-        resegment(symbols[p->second.first], output);
+        resegment(symbols[p->second.first],  output);
         resegment(symbols[p->second.second], output);
     }
 
@@ -293,7 +293,6 @@ private:
         if (left == -1 || right == -1) {
             return;
         }
-
         const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
         auto token = vocab.token_to_id.find(text);
 
@@ -320,10 +319,11 @@ private:
     }
 
     const llama_vocab & vocab;
+    // currently unused
+    // const llm_tokenizer_spm * spm_tokenizer;
 
     std::vector<llm_symbol> symbols;
     llm_bigram_spm::queue work_queue;
-
     std::map<std::string, std::pair<int, int>> rev_merge;
 };
 
@@ -335,6 +335,21 @@ private:
 
 // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
 
+template<typename T, typename Container = std::vector<T>, typename Compare = std::less<typename Container::value_type>>
+class llama_priority_queue : public std::priority_queue<T, Container, Compare> {
+public:
+    using std::priority_queue<T, Container, Compare>::priority_queue;
+
+    T pop_move() {
+        T item = std::move(this->c.front());
+        std::pop_heap(this->c.begin(), this->c.end(), this->comp);
+        this->c.pop_back();
+        return item;
+    }
+
+    void pop() = delete;
+};
+
 struct llm_bigram_bpe {
     struct comparator {
         bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
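`llama_priority_queue` exists because `std::priority_queue::top()` returns a const reference, so the idiomatic `top()` + `pop()` pair must copy the element — and `llm_bigram_bpe` carries a `std::string`. Deriving from `std::priority_queue` exposes the protected container `c` and comparator `comp`, letting `pop_move()` move the front element out before re-heapifying; deleting `pop()` keeps callers from mixing the two styles. The same trick in a self-contained sketch:

    #include <algorithm>
    #include <iostream>
    #include <queue>
    #include <string>
    #include <vector>

    template <typename T, typename Container = std::vector<T>,
              typename Compare = std::less<typename Container::value_type>>
    class movable_priority_queue : public std::priority_queue<T, Container, Compare> {
    public:
        using std::priority_queue<T, Container, Compare>::priority_queue;

        // move the max element out instead of copying it via top() + pop()
        T pop_move() {
            T item = std::move(this->c.front()); // `c` is protected, hence visible here
            std::pop_heap(this->c.begin(), this->c.end(), this->comp);
            this->c.pop_back();
            return item;
        }

        void pop() = delete; // force callers through pop_move()
    };

    int main() {
        movable_priority_queue<std::string> q;
        q.push("bb");
        q.push("aaa");
        q.push("c");
        while (!q.empty()) {
            std::cout << q.pop_move() << '\n'; // prints: c, bb, aaa
        }
    }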
@@ -343,7 +358,7 @@ struct llm_bigram_bpe {
     };
 
     using queue_storage = std::vector<llm_bigram_bpe>;
-    using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
+    using queue = llama_priority_queue<llm_bigram_bpe, queue_storage, comparator>;
     llm_symbol::index left;
     llm_symbol::index right;
     std::string text;
@@ -351,8 +366,8 @@ struct llm_bigram_bpe {
     size_t size;
 };
 
-struct llm_tokenizer_bpe {
-    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+struct llm_tokenizer_bpe : llm_tokenizer {
+    llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
         GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
         switch (vocab.type_pre) {
             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
@@ -402,6 +417,7 @@ struct llm_tokenizer_bpe {
             case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
             case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
             case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
+            case LLAMA_VOCAB_PRE_TYPE_EXAONE:
                 regex_exprs = {
                     "\\p{N}",
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
@@ -424,6 +440,8 @@ struct llm_tokenizer_bpe {
                 };
                 break;
             case LLAMA_VOCAB_PRE_TYPE_PORO:
+            case LLAMA_VOCAB_PRE_TYPE_BLOOM:
+            case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH:
                 regex_exprs = {
                     " ?[^(\\s|.,!?…。,、।۔،)]+",
                 };
@@ -446,6 +464,20 @@ struct llm_tokenizer_bpe {
                     "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_CHAMELEON:
+                // Note: in theory, the special token (sentinel and image token) regex_exprs below
+                // are unnecessary, as they are split in `tokenizer_st_partition` anyway.
+                // However, since the upstream pre-tokenizer uses them, they are also
+                // included here (see https://huggingface.co/facebook/chameleon-7b).
+                regex_exprs = {
+                    "<sentinel:[0-9]+>",  // Sentinel tokens
+                    "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z",  // Image tokens
+                    "([\\t\\n]|    |  )",  // directly from tokenizer.json
+                    "\\p{N}", // Individual digits
+                    "[\\p{P}!-/:-@\\[-`{-~]",  // Punctuation, Isolated
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -458,7 +490,14 @@ struct llm_tokenizer_bpe {
         }
     }
 
-    void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
+    std::vector<std::string> regex_exprs;
+};
+
+struct llm_tokenizer_bpe_session {
+    llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
+        bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
+
+    static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) {
         output.push_back(token_id);
     }
 
@@ -497,12 +536,11 @@ struct llm_tokenizer_bpe {
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         int final_prev_index = -1;
-
-        const auto word_collection = unicode_regex_split(text, regex_exprs);
+        const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
 
         symbols_final.clear();
 
-        for (auto & word : word_collection) {
+        for (const auto & word : word_collection) {
             work_queue = llm_bigram_bpe::queue();
             symbols.clear();
 
@@ -525,14 +563,13 @@ struct llm_tokenizer_bpe {
                 index++;
                 symbols.emplace_back(sym);
             }
-            for (size_t i = 1; i < symbols.size(); ++i) {
+            for (int i = 1; i < (int) symbols.size(); ++i) {
                 add_new_bigram(i - 1, i);
             }
 
             // build token(s)
             while (!work_queue.empty()) {
-                auto bigram = work_queue.top();
-                work_queue.pop();
+                auto bigram = work_queue.pop_move();
 
                 auto & left_symbol = symbols[bigram.left];
                 auto & right_symbol = symbols[bigram.right];
@@ -606,7 +643,6 @@ private:
        if (left == -1 || right == -1) {
            return;
        }
-
        std::string left_token = std::string(symbols[left].text, symbols[left].n);
        std::string right_token = std::string(symbols[right].text, symbols[right].n);
 
@@ -630,12 +666,10 @@ private:
     }
 
     const llama_vocab & vocab;
-
-    std::vector<std::string> regex_exprs;
+    const llm_tokenizer_bpe * bpe_tokenizer;
 
     std::vector<llm_symbol> symbols;
     std::vector<llm_symbol> symbols_final;
-
     llm_bigram_bpe::queue work_queue;
 };
 
@@ -643,15 +677,17 @@ private:
 // WPM tokenizer
 //
 
-struct llm_tokenizer_wpm {
-    llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
+struct llm_tokenizer_wpm : llm_tokenizer {
+    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
-        const auto & token_map = vocab.token_to_id;
+struct llm_tokenizer_wpm_session {
+    llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        const auto & token_map = vocab.token_to_id;
         // normalize and split by whitespace
         std::vector<std::string> words = preprocess(text);
-
         // bos token prepended already
 
         // find the longest tokens that form the words
@@ -696,7 +732,7 @@ struct llm_tokenizer_wpm {
     }
 
     // TODO: reduce string copies by using cpts_offs array
-    std::vector<std::string> preprocess(const std::string & text) const {
+    static std::vector<std::string> preprocess(const std::string & text) {
         const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
         std::vector<std::string> words(1, "");
 
@@ -748,15 +784,18 @@ struct llm_tokenizer_wpm {
         //(cpt >= 0xFF00 && cpt <= 0xFFEF);
     }
 
+private:
     const llama_vocab & vocab;
+    // currently unused
+    // const llm_tokenizer_wpm * wpm_tokenizer;
 };
 
 //
 // UGM tokenizer
 //
 
-struct llm_tokenizer_ugm {
-    llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
+struct llm_tokenizer_ugm : llm_tokenizer {
+    llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
         if (vocab.precompiled_charsmap.size() > 0) {
             size_t charsmap_offset = 0;
 
@@ -802,6 +841,30 @@ struct llm_tokenizer_ugm {
         unknown_token_score = min_score - unknown_token_score_penalty;
     }
 
+    // escaped space symbol - U+2581 (Lower One Eighth Block)
+    const std::string escaped_space = "\xE2\x96\x81";
+
+    const char * prefix_replacements = NULL;
+    size_t prefix_replacements_size = 0;
+
+    const uint32_t * xcda_array = NULL;
+    size_t xcda_array_size = 0;
+
+    struct naive_trie user_defined_token_matcher;
+
+    float min_score = FLT_MAX;
+    float max_score = -FLT_MAX;
+
+    float unknown_token_score_penalty = 10.0;
+    float unknown_token_score;
+
+    struct naive_trie token_matcher;
+};
+
+struct llm_tokenizer_ugm_session {
+    llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
+        ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
+
     /* This implementation is based on SentencePiece optimized Viterbi algorithm for
      * unigram language models. The general idea is to:
      * - move along the input sequence in steps of one UTF code point,
@@ -816,6 +879,9 @@ struct llm_tokenizer_ugm {
      * the best tokenization.
      */
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        // get current size of output (for reversal later)
+        size_t output_size = output.size();
+
         // normalize the input first
         std::string normalized;
         normalize(text, &normalized);
@@ -837,7 +903,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -867,7 +933,7 @@ struct llm_tokenizer_ugm {
             // if we didn't find a valid token corresponding to the whole UTF code point
             // then use unknown token as the tokenization of this UTF code point
             if (!single_codepoint_token_found) {
-                const double challenger_score = current_best.score_sum + unknown_token_score;
+                const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
                 prefix_offset = input_offset + n_utf8_code_units;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
@@ -895,11 +961,10 @@ struct llm_tokenizer_ugm {
         }
 
         // reverse the output since we added tokens starting from the end of the input
-        std::reverse(output.begin(), output.end());
+        std::reverse(output.begin() + output_size, output.end());
     }
 
 private:
-    const llama_vocab & vocab;
 
     // helper structure for returning normalization results
     struct normalization_result {
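The UGM session implements the SentencePiece-style unigram Viterbi search sketched in the comment above: it records, for every input offset, the best-scoring tokenization ending there, then backtracks from the end of the input. Because backtracking emits tokens last-to-first, the code appends and then reverses — and the change to `output.begin() + output_size` ensures only the tokens appended by this call are reversed, not tokens a previous fragment already wrote into `output`. A toy editorial version of the same dynamic program over a plain map vocabulary (the real code walks tries and an XCDA instead, and handles unknown tokens):

    #include <algorithm>
    #include <cfloat>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Toy unigram-LM Viterbi tokenizer: best[i] = best score of tokenizing text[0..i).
    // Assumes the text is fully tokenizable with the given vocabulary.
    static std::vector<std::string> viterbi_tokenize(
            const std::string & text, const std::map<std::string, float> & vocab) {
        const size_t n = text.size();
        std::vector<float>  best(n + 1, -FLT_MAX);
        std::vector<size_t> prev(n + 1, 0); // start offset of the token ending at i
        best[0] = 0.0f;

        for (size_t i = 0; i < n; ++i) {
            if (best[i] == -FLT_MAX) continue; // offset i is unreachable
            for (const auto & [tok, score] : vocab) {
                // compare() != 0 whenever the token does not fully fit, so the
                // index i + tok.size() below is only touched when it is <= n
                if (text.compare(i, tok.size(), tok) == 0 && best[i] + score > best[i + tok.size()]) {
                    best[i + tok.size()] = best[i] + score;
                    prev[i + tok.size()] = i;
                }
            }
        }

        // backtrack from the end, then reverse the collected tokens
        std::vector<std::string> out;
        for (size_t i = n; i > 0; i = prev[i]) {
            out.push_back(text.substr(prev[i], i - prev[i]));
        }
        std::reverse(out.begin(), out.end());
        return out;
    }

    int main() {
        const std::map<std::string, float> vocab = {
            {"un", -1.0f}, {"i", -3.0f}, {"gram", -1.5f}, {"unigram", -2.0f}};
        for (const auto & t : viterbi_tokenize("unigram", vocab)) {
            std::cout << t << ' ';
        }
        std::cout << '\n'; // "unigram" (score -2.0 beats un+i+gram = -5.5)
    }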
@@ -912,7 +977,7 @@ private:
         normalized->clear();
         normalized->reserve(input.size() * 3);
 
-        const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
+        const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
 
         bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
         bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
@@ -957,7 +1022,7 @@ private:
     /*
      * This structure is a view wrapper for XOR-compressed double array (XCDA)
      * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
-     * Eeach bit-packed entry contains:
+     * Each bit-packed entry contains:
      * - BASE array value in bits 10-30
      * - LCHECK array value in bits 0-7
     * - LEAF array value in bit 9
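For reference, the comment fixed above documents how each 32-bit XCDA entry packs three logical arrays of the double-array trie into one word. A sketch of unpacking those documented fields with plain shifts and masks (illustrative only — the real `xcda_array_view` also handles the replacement-sequence encoding and the XOR transitions):

    #include <cstdint>
    #include <cstdio>

    // Unpack the three fields the comment documents from one 32-bit XCDA entry:
    // BASE in bits 10-30, LCHECK in bits 0-7, LEAF in bit 9.
    static uint32_t xcda_base  (uint32_t e) { return (e >> 10) & 0x1FFFFF; } // 21 bits
    static uint32_t xcda_lcheck(uint32_t e) { return e & 0xFF; }             // 8 bits
    static bool     xcda_leaf  (uint32_t e) { return (e >> 9) & 1; }

    int main() {
        const uint32_t entry = (42u << 10) | (1u << 9) | 0x41u; // BASE=42, LEAF=1, LCHECK='A'
        std::printf("base=%u leaf=%d lcheck=%c\n",
                    xcda_base(entry), (int) xcda_leaf(entry), (char) xcda_lcheck(entry));
    }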
@@ -994,13 +1059,21 @@ private:
         size_t xcda_array_size;
     };
 
+    // this structure stores the best tokenization so far at input_offset
+    struct best_tokenization {
+        llama_token token_id;
+        size_t input_offset;
+        float score_sum;
+    };
+
     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
         if (input_offset == input.size()) {
             return { &input[input_offset], 0, 0 };
         }
 
         // if input prefix matches some user-defined token return this token as normalization result
-        auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+        auto user_defined_token_match =
+            ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
         if (user_defined_token_match.second > 0) {
             return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
         }
@@ -1008,8 +1081,8 @@ private:
         size_t longest_prefix_length = 0;
         size_t longest_prefix_offset = 0;
 
-        if (xcda_array_size > 0) {
-            struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
+        if (ugm_tokenizer->xcda_array_size > 0) {
+            struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
 
             // Find the longest normalized sequence matching the input prefix by walking
             // the XOR-compressed compact double array (XCDA) starting from the root node
@@ -1045,52 +1118,162 @@ private:
 
         if (longest_prefix_length > 0) {
             // we have a match, so return the replacement sequence
-            if (longest_prefix_offset >= prefix_replacements_size) {
+            if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
-            const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
+            const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
             return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
-        } else {
-            // check if the input prefix contains a valid sequence of UTF-8 code units
-            try {
-                // if yes, return this sequence unmodified
-                size_t prefix_offset = input_offset;
-                unicode_cpt_from_utf8(input, prefix_offset);
-                return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
-            } catch (std::invalid_argument & /*ex*/) {
-                // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
-                return { "\xEF\xBF\xBD", 3, 1 };
-            }
+        }
+
+        // check if the input prefix contains a valid sequence of UTF-8 code units
+        try {
+            // if yes, return this sequence unmodified
+            size_t prefix_offset = input_offset;
+            unicode_cpt_from_utf8(input, prefix_offset);
+            return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
+        } catch (std::invalid_argument & /*ex*/) {
+            // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
+            return { "\xEF\xBF\xBD", 3, 1 };
         }
     }
 
-    // escaped space symbol - U+2581 (Lower One Eighth Block)
-    const std::string escaped_space = "\xE2\x96\x81";
+    const llama_vocab & vocab;
+    const llm_tokenizer_ugm * ugm_tokenizer;
+};
 
-    const char * prefix_replacements = NULL;
-    size_t prefix_replacements_size = 0;
+//
+// RWKV tokenizer
+//
 
-    const uint32_t * xcda_array = NULL;
-    size_t xcda_array_size = 0;
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+    std::vector<uint8_t> output;
+    output.reserve(escaped.size());
+
+    // Parser state
+    bool escaping = false;
+    uint8_t hex_remaining = 0;
+    uint8_t hex_acc = 0;
+
+    // Step through characters, performing parsing
+    for (const char & c : escaped) {
+        // If we're parsing a hex code, interpret the next character
+        if (hex_remaining != 0) {
+            uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+            hex_acc = (hex_acc << 4) + value;
+
+            hex_remaining -= 1;
+            if (hex_remaining == 0) {
+                output.push_back(hex_acc);
+                hex_acc = 0;
+            }
 
-    struct naive_trie user_defined_token_matcher;
+            continue;
+        }
 
-    // this structure stores the best tokenization so far at input_offset
-    struct best_tokenization {
-        llama_token token_id;
-        size_t input_offset;
-        float score_sum;
-    };
+        // If we got an escape character, interpret it
+        if (escaping) {
+            if (c == 't') {
+                output.push_back('\t');
+            } else if (c == 'n') {
+                output.push_back('\n');
+            } else if (c == 'r') {
+                output.push_back('\r');
+            } else if (c == 'x') {
+                hex_remaining = 2;
+            } else {
+                output.push_back(c);
+            }
 
-    float min_score = FLT_MAX;
-    float max_score = -FLT_MAX;
+            escaping = false;
+            continue;
+        }
 
-    float unknown_token_score_penalty = 10.0;
-    float unknown_token_score;
+        if (c == '\\') {
+            escaping = true;
+            continue;
+        }
+
+        output.push_back(c);
+    }
+
+    return output;
+}
+
+struct llm_tokenizer_rwkv : llm_tokenizer {
+    llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
+        // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+        // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
+        }
+    }
 
     struct naive_trie token_matcher;
 };
 
+struct llm_tokenizer_rwkv_session {
+    llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
+        rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        uint32_t position = 0;
+        while (position < text.size()) {
+            const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
+            }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
+        }
+    }
+
+private:
+    const llama_vocab & vocab;
+    const llm_tokenizer_rwkv & rwkv_tokenizer;
+};
+
+void llama_vocab::init_tokenizer() {
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = new llm_tokenizer_spm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = new llm_tokenizer_bpe(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = new llm_tokenizer_wpm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = new llm_tokenizer_ugm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = new llm_tokenizer_rwkv(*this);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
+}
+
 //
 // (de-) tokenize
 //
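`init_tokenizer()` turns tokenizer construction into a one-time cost: previously `llama_tokenize_internal` rebuilt the whole tokenizer on every call (the UGM constructor re-parsed the precompiled charsmap, BPE re-selected its regex set, and the new RWKV trie would have had to be rebuilt likewise), whereas now the switch runs once and each call only creates a cheap session. The diff stores the result as a raw `llm_tokenizer *` released in `~llama_vocab()`; below is a sketch of the same factory shape using `std::unique_ptr`, shown only to make the ownership explicit (hypothetical, not what the patch does):

    #include <memory>
    #include <stdexcept>

    struct tokenizer_base { virtual ~tokenizer_base() = default; };
    struct tokenizer_spm : tokenizer_base {};
    struct tokenizer_bpe : tokenizer_base {};

    enum class vocab_type { SPM, BPE };

    // factory: pick the concrete tokenizer once, based on the vocab type
    static std::unique_ptr<tokenizer_base> make_tokenizer(vocab_type t) {
        switch (t) {
            case vocab_type::SPM: return std::make_unique<tokenizer_spm>();
            case vocab_type::BPE: return std::make_unique<tokenizer_bpe>();
        }
        throw std::runtime_error("unsupported vocab type");
    }

    int main() {
        auto tok = make_tokenizer(vocab_type::SPM); // freed automatically, no manual delete
        (void) tok;
    }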
@@ -1152,7 +1335,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
         // if a fragment is text ( not yet processed )
         if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-            auto & raw_text = fragment.raw_text;
+            const auto & raw_text = fragment.raw_text;
 
             auto raw_text_base_offset = fragment.offset;
             auto raw_text_base_length = fragment.length;
@@ -1251,7 +1434,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
     }
 }
 
-std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
+std::vector<llama_vocab::id> llama_tokenize_internal(
+        const llama_vocab & vocab,
+        std::string raw_text,
+        bool add_special,
+        bool parse_special) {
+    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
     std::vector<llama_vocab::id> output;
     std::forward_list<fragment_buffer_variant> fragment_buffer;
 
@@ -1288,9 +1477,9 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_spm tokenizer(vocab);
                         llama_escape_whitespace(raw_text);
-                        tokenizer.tokenize(raw_text, output);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(raw_text, output);
                         is_prev_special = false;
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
@@ -1312,10 +1501,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                llm_tokenizer_bpe tokenizer(vocab);
-
+                llm_tokenizer_bpe_session session(vocab);
+                // it calls some other methods that are not exist in llm_tokenizer,
+                // here just cast it to bpe tokenizer object
                 if (add_special) {
-                    tokenizer.append_bos(output);
+                    session.append_bos(output);
                 }
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1324,15 +1514,15 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        tokenizer.tokenize(raw_text, output);
+                        session.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        tokenizer.append(fragment.token, output);
+                        session.append(fragment.token, output);
                     }
                 }
 
                 if (add_special) {
-                    tokenizer.append_eos(output);
-                    tokenizer.check_double_bos_eos(output);
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
@@ -1342,7 +1532,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                     output.push_back(vocab.special_cls_id);
                 }
 
-                llm_tokenizer_wpm tokenizer(vocab);
+                llm_tokenizer_wpm_session session(vocab);
 
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1351,7 +1541,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        tokenizer.tokenize(raw_text, output);
+                        session.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                     }
@@ -1364,12 +1554,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
             } break;
         case LLAMA_VOCAB_TYPE_UGM:
             {
-                llm_tokenizer_ugm tokenizer(vocab);
-
-                if (add_special && vocab.tokenizer_add_bos != 0) {
+                if (add_special && vocab.tokenizer_add_bos) {
                     GGML_ASSERT(vocab.special_bos_id != -1);
                     output.push_back(vocab.special_bos_id);
                 }
+                llm_tokenizer_ugm_session session(vocab);
 
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1377,24 +1566,41 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        tokenizer.tokenize(raw_text, output);
+                        session.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                     }
                 }
 
-                if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
                     LLAMA_LOG_WARN(
                         "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                         "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                         "Are you sure this is what you want?\n", __FUNCTION__);
                 }
 
-                if (add_special && vocab.tokenizer_add_eos == 1) {
+                if (add_special && vocab.tokenizer_add_eos) {
                     GGML_ASSERT(vocab.special_eos_id != -1);
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                llm_tokenizer_rwkv_session session(vocab);
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+                        session.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
@@ -1442,10 +1648,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
 }
 
 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab)
-    );
+    return token != -1 && vocab.special_eog_ids.count(token) > 0;
 }
 
 bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
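`llama_token_is_eog_impl` previously recognized only EOS and EOT; testing membership in `vocab.special_eog_ids`, a set filled at model-load time, lets any number of end-of-generation markers — e.g. the end-of-message token exposed by `llama_token_eom_impl` below — share one check. A tiny sketch of the idea, with made-up token ids and a `std::set` standing in for the vocab field:

    #include <set>

    using llama_token_t = int; // stand-in for llama_token

    static bool is_eog(const std::set<llama_token_t> & eog_ids, llama_token_t token) {
        return token != -1 && eog_ids.count(token) > 0;
    }

    int main() {
        // made-up ids: eos, eot, eom for some hypothetical model
        const std::set<llama_token_t> eog_ids = {2, 32007, 32001};
        return (is_eog(eog_ids, 32007) && !is_eog(eog_ids, 5)) ? 0 : 1;
    }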
@@ -1460,6 +1663,14 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
     return vocab.special_eos_id;
 }
 
+llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
+    return vocab.special_eot_id;
+}
+
+llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
+    return vocab.special_eom_id;
+}
+
 llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
     return vocab.special_cls_id;
 }
@@ -1476,38 +1687,58 @@ llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
     return vocab.special_pad_id;
 }
 
-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
+bool llama_add_bos_token_impl(const struct llama_vocab & vocab) {
     return vocab.tokenizer_add_bos;
 }
 
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
+bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
     return vocab.tokenizer_add_eos;
 }
 
 llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_prefix_id;
+    return vocab.special_fim_pre_id;
 }
 
 llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
-    return vocab.special_middle_id;
+    return vocab.special_fim_mid_id;
 }
 
 llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_suffix_id;
+    return vocab.special_fim_suf_id;
 }
 
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eot_id;
+llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_pre_id;
+}
+
+llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_suf_id;
+}
+
+llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_mid_id;
+}
+
+llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_pad_id;
+}
+
+llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_rep_id;
+}
+
+llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_sep_id;
 }
 
 int32_t llama_tokenize_impl(
-    const struct llama_vocab & vocab,
-    const char * text,
-    int32_t text_len,
-    llama_token * tokens,
-    int32_t n_tokens_max,
-    bool add_special,
-    bool parse_special) {
+        const struct llama_vocab & vocab,
+        const char * text,
+        int32_t text_len,
+        llama_token * tokens,
+        int32_t n_tokens_max,
+        bool add_special,
+        bool parse_special) {
     auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
     if (n_tokens_max < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
@@ -1584,11 +1815,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
             // suppressing them like CONTROL tokens.
             if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
                 return _try_copy(token_text.data(), token_text.size());
-            } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+            }
+            if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
                 std::string result = token_text;
                 llama_unescape_whitespace(result);
                 return _try_copy(result.data(), result.size());
-            } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+            }
+            if (attr & LLAMA_TOKEN_ATTR_BYTE) {
                 char byte = (char) llama_token_to_byte(vocab, token);
                 return _try_copy((char*) &byte, 1);
             }
@@ -1599,12 +1832,24 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
             // suppressing them like CONTROL tokens.
             if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
                 return _try_copy(token_text.data(), token_text.size());
-            } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+            }
+            if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
                 std::string result = llama_decode_text(token_text);
                 return _try_copy(result.data(), result.size());
             }
             break;
         }
+        case LLAMA_VOCAB_TYPE_RWKV: {
+            std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+            // If we don't have enough space, return an error
+            if (result.size() > (size_t)length) {
+                return -(int)result.size();
+            }
+
+            memcpy(buf, result.data(), result.size());
+            return (int)result.size();
+        }
         default:
             GGML_ABORT("fatal error");
     }
@@ -1621,6 +1866,8 @@ int32_t llama_detokenize_impl(
         int32_t text_len_max,
         bool remove_special,
         bool unparse_special) {
+    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
     int32_t avail = text_len_max;
     int32_t total = 0;
 
@@ -1719,3 +1966,19 @@ int32_t llama_detokenize_impl(
 
     return total <= text_len_max ? total : -total;
 }
+
+std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
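The new `llama_detokenize` wrapper shows the buffer-sizing convention used throughout this API (the RWKV branch of `llama_token_to_piece_impl` above follows it too): when the destination is too small, the implementation returns the negated required length, and the caller resizes and retries exactly once. The same convention around a hypothetical `fill_text` helper:

    #include <cstdint>
    #include <cstring>
    #include <iostream>
    #include <string>

    // Hypothetical C-style API following the same convention: returns the number
    // of bytes written, or the negated required size if `len` is too small.
    static int32_t fill_text(char * buf, int32_t len) {
        const char * msg = "detokenized text";
        const int32_t need = (int32_t) std::strlen(msg);
        if (need > len) {
            return -need;
        }
        std::memcpy(buf, msg, need);
        return need;
    }

    int main() {
        std::string text;
        text.resize(8); // deliberately too small
        int32_t n = fill_text(&text[0], (int32_t) text.size());
        if (n < 0) {
            text.resize(-n);                                // grow to the required size
            n = fill_text(&text[0], (int32_t) text.size()); // retry once
        }
        text.resize(n);
        std::cout << text << '\n';
    }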