cui-llama.rn 1.0.3 → 1.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. package/README.md +35 -39
  2. package/android/src/main/CMakeLists.txt +12 -2
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +29 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +33 -1
  5. package/android/src/main/jni.cpp +62 -8
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  8. package/cpp/common.cpp +3237 -3231
  9. package/cpp/common.h +469 -468
  10. package/cpp/ggml-aarch64.c +2193 -2193
  11. package/cpp/ggml-aarch64.h +39 -39
  12. package/cpp/ggml-alloc.c +1036 -1042
  13. package/cpp/ggml-backend-impl.h +153 -153
  14. package/cpp/ggml-backend.c +2240 -2234
  15. package/cpp/ggml-backend.h +238 -238
  16. package/cpp/ggml-common.h +1833 -1829
  17. package/cpp/ggml-impl.h +755 -655
  18. package/cpp/ggml-metal.h +65 -65
  19. package/cpp/ggml-metal.m +3269 -3269
  20. package/cpp/ggml-quants.c +14872 -14860
  21. package/cpp/ggml-quants.h +132 -132
  22. package/cpp/ggml.c +22055 -22044
  23. package/cpp/ggml.h +2453 -2447
  24. package/cpp/llama-grammar.cpp +539 -0
  25. package/cpp/llama-grammar.h +39 -0
  26. package/cpp/llama-impl.h +26 -0
  27. package/cpp/llama-sampling.cpp +635 -0
  28. package/cpp/llama-sampling.h +56 -0
  29. package/cpp/llama-vocab.cpp +1721 -0
  30. package/cpp/llama-vocab.h +130 -0
  31. package/cpp/llama.cpp +19171 -21892
  32. package/cpp/llama.h +1240 -1217
  33. package/cpp/log.h +737 -737
  34. package/cpp/rn-llama.hpp +207 -29
  35. package/cpp/sampling.cpp +460 -460
  36. package/cpp/sgemm.cpp +1027 -1027
  37. package/cpp/sgemm.h +14 -14
  38. package/cpp/unicode.cpp +6 -0
  39. package/cpp/unicode.h +3 -0
  40. package/ios/RNLlama.mm +15 -6
  41. package/ios/RNLlamaContext.h +2 -8
  42. package/ios/RNLlamaContext.mm +41 -34
  43. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  44. package/lib/commonjs/chat.js +37 -0
  45. package/lib/commonjs/chat.js.map +1 -0
  46. package/lib/commonjs/index.js +14 -1
  47. package/lib/commonjs/index.js.map +1 -1
  48. package/lib/module/NativeRNLlama.js.map +1 -1
  49. package/lib/module/chat.js +31 -0
  50. package/lib/module/chat.js.map +1 -0
  51. package/lib/module/index.js +14 -1
  52. package/lib/module/index.js.map +1 -1
  53. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  54. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  55. package/lib/typescript/chat.d.ts +10 -0
  56. package/lib/typescript/chat.d.ts.map +1 -0
  57. package/lib/typescript/index.d.ts +9 -2
  58. package/lib/typescript/index.d.ts.map +1 -1
  59. package/package.json +1 -1
  60. package/src/NativeRNLlama.ts +10 -1
  61. package/src/chat.ts +44 -0
  62. package/src/index.ts +31 -4
@@ -0,0 +1,1721 @@
+ #include "llama-vocab.h"
+
+ #include "unicode.h"
+
+ #include <algorithm>
+ #include <cassert>
+ #include <cfloat>
+ #include <climits>
+ #include <cstdarg>
+ #include <cstring>
+ #include <forward_list>
+ #include <queue>
+ #include <sstream>
+
+ //
+ // helpers
+ //
+
+ static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+     std::string result;
+     for (size_t pos = 0; ; pos += search.length()) {
+         auto new_pos = s.find(search, pos);
+         if (new_pos == std::string::npos) {
+             result += s.substr(pos, s.size() - pos);
+             break;
+         }
+         result += s.substr(pos, new_pos - pos) + replace;
+         pos = new_pos;
+     }
+     s = std::move(result);
+ }
+
+ LLAMA_ATTRIBUTE_FORMAT(1, 2)
+ static std::string format(const char * fmt, ...) {
+     va_list ap;
+     va_list ap2;
+     va_start(ap, fmt);
+     va_copy(ap2, ap);
+     int size = vsnprintf(NULL, 0, fmt, ap);
+     LM_GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+     std::vector<char> buf(size + 1);
+     int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+     LM_GGML_ASSERT(size2 == size);
+     va_end(ap2);
+     va_end(ap);
+     return std::string(buf.data(), size);
+ }
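The format() helper above is the standard two-pass vsnprintf idiom: the first call with a NULL buffer only measures the required length, the second writes into a correctly sized buffer. A minimal caller sketch (the message text is illustrative, not from this diff):

    std::string msg = format("unknown token id %d (vocab size %zu)", 42, (size_t) 32000);
    // msg == "unknown token id 42 (vocab size 32000)"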
+
+ struct naive_trie {
+     naive_trie() : has_value(false), value(0) {
+     }
+     void insert(const char * key, size_t len, int32_t value = 0) {
+         if (len == 0) {
+             this->has_value = true;
+             this->value = value;
+             return;
+         }
+         char c = key[0];
+         auto res = children.find(c);
+         if (res != children.end()) {
+             res->second.insert(key + 1, len - 1, value);
+         } else {
+             auto res = children.insert(std::make_pair(c, naive_trie()));
+             res.first->second.insert(key + 1, len - 1, value);
+         }
+     }
+     std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
+         if (len == 0 || offset == len) {
+             return std::make_pair(key, offset);
+         }
+         char c = key[offset];
+         auto res = children.find(c);
+         if (res != children.end()) {
+             return res->second.get_longest_prefix(key, len, offset + 1);
+         } else {
+             return std::make_pair(key, offset);
+         }
+     }
+     struct naive_trie * traverse(const char c) {
+         auto res = children.find(c);
+         if (res != children.end()) {
+             return &res->second;
+         } else {
+             return NULL;
+         }
+     }
+     std::map<char, struct naive_trie> children;
+     bool has_value;
+     llama_token value;
+ };
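naive_trie backs the UGM tokenizer's token_matcher and user_defined_token_matcher further down. A minimal sketch of its three operations, assuming only the struct as defined above:

    naive_trie trie;
    trie.insert("in", 2, /*value=*/10);   // mark "in" as a token with id 10
    trie.insert("inner", 5, /*value=*/11);

    // longest path present in the trie, whether or not it ends on a value:
    auto p = trie.get_longest_prefix("inn", 3); // p.second == 3 ("inn" is a prefix of "inner")

    // single-step descent; NULL when there is no child for the character:
    naive_trie * node = trie.traverse('i');     // non-NULL here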
+
+ //
+ // impl
+ //
+
+ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
+     LM_GGML_ASSERT(token_left.find(' ') == std::string::npos);
+     LM_GGML_ASSERT(token_left.find('\n') == std::string::npos);
+     LM_GGML_ASSERT(token_right.find(' ') == std::string::npos);
+     LM_GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
+     auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
+     if (it == bpe_ranks.end()) {
+         return -1;
+     }
+
+     return it->second;
+ }
+
+ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
+     return vocab.type;
+ }
+
+ static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
+     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+     return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
+ }
+
+ static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
+     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+     return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
+ }
+
+ static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
+     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+     return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
+ }
+
+ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
+     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+     return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
+ }
+
+ static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) {
+     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+     return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+ }
+
+ static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) {
+     LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
+     return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
+ }
+
+ static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
+     LM_GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
+     LM_GGML_ASSERT(llama_is_byte_token(vocab, id));
+     const auto & token_data = vocab.id_to_token.at(id);
+     switch (llama_vocab_get_type(vocab)) {
+         case LLAMA_VOCAB_TYPE_SPM:
+         case LLAMA_VOCAB_TYPE_UGM: {
+             auto buf = token_data.text.substr(3, 2);
+             return strtol(buf.c_str(), NULL, 16);
+         }
+         case LLAMA_VOCAB_TYPE_BPE: {
+             LM_GGML_ABORT("fatal error");
+             //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after LM_GGML_ASSERT?
+         }
+         case LLAMA_VOCAB_TYPE_WPM: {
+             LM_GGML_ABORT("fatal error");
+         }
+         default:
+             LM_GGML_ABORT("fatal error");
+     }
+ }
+
+ static void llama_escape_whitespace(std::string & text) {
+     replace_all(text, " ", "\xe2\x96\x81");
+ }
+
+ static void llama_unescape_whitespace(std::string & word) {
+     replace_all(word, "\xe2\x96\x81", " ");
+ }
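These two helpers implement SentencePiece's convention of replacing spaces with U+2581 (LOWER ONE EIGHTH BLOCK, "▁", UTF-8 \xe2\x96\x81). The round trip is lossless:

    std::string s = "Hello world";
    llama_escape_whitespace(s);   // "Hello▁world"
    llama_unescape_whitespace(s); // back to "Hello world"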
+
+ struct llm_symbol {
+     using index = int;
+     index prev;
+     index next;
+     const char * text;
+     size_t n;
+ };
+
+ static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
+
+ //
+ // SPM tokenizer
+ // original implementation:
+ // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
+ //
+
+ struct llm_bigram_spm {
+     struct comparator {
+         bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
+             return (l.score < r.score) || (l.score == r.score && l.left > r.left);
+         }
+     };
+     using queue_storage = std::vector<llm_bigram_spm>;
+     using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
+     llm_symbol::index left;
+     llm_symbol::index right;
+     float score;
+     size_t size;
+ };
+
+ struct llm_tokenizer_spm {
+     llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
+
+     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+         // split string into utf8 chars
+         int index = 0;
+         size_t offs = 0;
+         while (offs < text.size()) {
+             llm_symbol sym;
+             size_t len = unicode_len_utf8(text[offs]);
+             sym.text = text.c_str() + offs;
+             sym.n = std::min(len, text.size() - offs);
+             offs += sym.n;
+             sym.prev = index - 1;
+             sym.next = offs == text.size() ? -1 : index + 1;
+             index++;
+             symbols.emplace_back(sym);
+         }
+
+         // seed the work queue with all possible 2-character tokens.
+         for (size_t i = 1; i < symbols.size(); ++i) {
+             try_add_bigram(i - 1, i);
+         }
+
+         // keep substituting the highest frequency pairs for as long as we can.
+         while (!work_queue.empty()) {
+             auto bigram = work_queue.top();
+             work_queue.pop();
+
+             auto & left_sym = symbols[bigram.left];
+             auto & right_sym = symbols[bigram.right];
+
+             // if one of the symbols already got merged, skip it.
+             if (left_sym.n == 0 || right_sym.n == 0 ||
+                 left_sym.n + right_sym.n != bigram.size) {
+                 continue;
+             }
+
+             // merge the right sym into the left one
+             left_sym.n += right_sym.n;
+             right_sym.n = 0;
+
+             //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
+
+             // remove the right sym from the chain
+             left_sym.next = right_sym.next;
+             if (right_sym.next >= 0) {
+                 symbols[right_sym.next].prev = bigram.left;
+             }
+
+             // find more substitutions
+             try_add_bigram(left_sym.prev, bigram.left);
+             try_add_bigram(bigram.left, left_sym.next);
+         }
+
+         for (int i = 0; i != -1; i = symbols[i].next) {
+             auto & symbol = symbols[i];
+             resegment(symbol, output);
+         }
+     }
+
+ private:
+     void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
+         auto text = std::string(symbol.text, symbol.n);
+         auto token = vocab.token_to_id.find(text);
+
+         // Do we need to support is_unused?
+         if (token != vocab.token_to_id.end()) {
+             output.push_back((*token).second);
+             return;
+         }
+
+         const auto p = rev_merge.find(text);
+
+         if (p == rev_merge.end()) {
+             // output any symbols that did not form tokens as bytes.
+             output.reserve(output.size() + symbol.n);
+             for (int j = 0; j < (int)symbol.n; ++j) {
+                 llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]);
+                 output.push_back(token_id);
+             }
+             return;
+         }
+
+         resegment(symbols[p->second.first],  output);
+         resegment(symbols[p->second.second], output);
+     }
+
+     void try_add_bigram(int left, int right) {
+         if (left == -1 || right == -1) {
+             return;
+         }
+
+         const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
+         auto token = vocab.token_to_id.find(text);
+
+         if (token == vocab.token_to_id.end()) {
+             return;
+         }
+
+         if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
+             return;
+         }
+
+         const auto & tok_data = vocab.id_to_token[(*token).second];
+
+         llm_bigram_spm bigram;
+         bigram.left  = left;
+         bigram.right = right;
+         bigram.score = tok_data.score;
+         bigram.size  = text.size();
+
+         work_queue.push(bigram);
+
+         // Do we need to support is_unused?
+         rev_merge[text] = std::make_pair(left, right);
+     }
+
+     const llama_vocab & vocab;
+
+     std::vector<llm_symbol> symbols;
+     llm_bigram_spm::queue work_queue;
+
+     std::map<std::string, std::pair<int, int>> rev_merge;
+ };
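The SPM merge loop is easiest to see on a toy vocabulary. The real code drives it with a priority queue plus lazy invalidation (the bigram.size check); the self-contained sketch below rescans for the best-scoring adjacent pair instead, which gives the same result on small inputs. The scores are made up and stand in for SentencePiece log-probabilities; nothing here uses llama_vocab.

    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
        std::map<std::string, float> score = { {"he", -1.0f}, {"ll", -0.5f}, {"hell", -0.2f} };
        std::vector<std::string> pieces = {"h", "e", "l", "l", "o"};
        bool merged = true;
        while (merged) {
            merged = false;
            int best = -1; float best_score = -1e30f;
            for (size_t i = 0; i + 1 < pieces.size(); ++i) {
                auto it = score.find(pieces[i] + pieces[i + 1]);
                if (it != score.end() && it->second > best_score) { best = (int) i; best_score = it->second; }
            }
            if (best >= 0) { // merge the best-scoring pair, as the work queue does above
                pieces[best] += pieces[best + 1];
                pieces.erase(pieces.begin() + best + 1);
                merged = true;
            }
        }
        for (auto & p : pieces) printf("'%s' ", p.c_str()); // prints: 'hell' 'o'
        printf("\n");
    }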
+
+ //
+ // BPE tokenizer
+ // adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License]
+ // tried to simplify unicode stuff, so most likely does not work 100% correctly!
+ //
+
+ // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
+
+ struct llm_bigram_bpe {
+     struct comparator {
+         bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
+             return l.rank > r.rank || (l.rank == r.rank && l.left > r.left);
+         }
+     };
+
+     using queue_storage = std::vector<llm_bigram_bpe>;
+     using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
+     llm_symbol::index left;
+     llm_symbol::index right;
+     std::string text;
+     int rank;
+     size_t size;
+ };
+
+ struct llm_tokenizer_bpe {
+     llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+         LM_GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
+         switch (vocab.type_pre) {
+             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+                 regex_exprs = {
+                     // original regex from tokenizer.json
+                     //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+
+                     // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
+                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_DBRX:
+             case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+                 regex_exprs = {
+                     // same as llama3
+                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
+                 regex_exprs = {
+                     "[\r\n]",
+                     "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
+                     "\\s?[!-/:-~！-／：-～‘-‟　-。]+",
+                     "\\s+$",
+                     "[一-龥ࠀ-一가-퟿]+",
+                     "\\p{N}+",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
+                 regex_exprs = {
+                     "[\r\n]",
+                     "\\s?\\p{L}+",
+                     "\\s?\\p{P}+",
+                     "[一-龥ࠀ-一가-퟿]+",
+                     "\\p{N}",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_FALCON:
+                 regex_exprs = {
+                     "[\\p{P}\\$\\+<=>\\^~\\|`]+",
+                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                     "[0-9][0-9][0-9]",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_STARCODER:
+             case LLAMA_VOCAB_PRE_TYPE_REFACT:
+             case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
+             case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
+             case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
+                 regex_exprs = {
+                     "\\p{N}",
+                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_GPT2:
+             case LLAMA_VOCAB_PRE_TYPE_MPT:
+             case LLAMA_VOCAB_PRE_TYPE_OLMO:
+             case LLAMA_VOCAB_PRE_TYPE_JAIS:
+                 regex_exprs = {
+                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
+             case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+                 regex_exprs = {
+                     // original regex from tokenizer.json
+                     // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_PORO:
+                 regex_exprs = {
+                     " ?[^(\\s|.,!?…。，、।۔،)]+",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
+                 regex_exprs = {
+                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_VIKING:
+                 regex_exprs = {
+                     " ?[^(\\s|.,!?…。，、।۔،)]+",
+                     "\\p{N}",
+                 };
+                 break;
+             case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
+                 // original regex from tokenizer.json
+                 // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                 regex_exprs = {
+                     "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                 };
+                 break;
+             default:
+                 // default regex for BPE tokenization pre-processing
+                 regex_exprs = {
+                     "[\\p{P}\\$\\+<=>\\^~\\|]+",
+                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                     "\\p{N}+",
+                     "[0-9][0-9][0-9]",
+                 };
+                 break;
+         }
+     }
+
+     void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
+         output.push_back(token_id);
+     }
+
+     bool append_bos(std::vector<llama_vocab::id> & output) const {
+         if (vocab.tokenizer_add_bos) {
+             LM_GGML_ASSERT(vocab.special_bos_id != -1);
+             output.push_back(vocab.special_bos_id);
+             return true;
+         }
+         return false;
+     }
+
+     bool append_eos(std::vector<llama_vocab::id> & output) const {
+         if (vocab.tokenizer_add_eos) {
+             LM_GGML_ASSERT(vocab.special_eos_id != -1);
+             output.push_back(vocab.special_eos_id);
+             return true;
+         }
+         return false;
+     }
+
+     void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
+         if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+             LLAMA_LOG_WARN(
+                 "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                 "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                 "Are you sure this is what you want?\n", __FUNCTION__);
+         }
+         if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+             LLAMA_LOG_WARN(
+                 "%s: Added a EOS token to the prompt as specified by the model but the prompt "
+                 "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
+                 "Are you sure this is what you want?\n", __FUNCTION__);
+         }
+     }
+
+     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+         int final_prev_index = -1;
+
+         const auto word_collection = unicode_regex_split(text, regex_exprs);
+
+         symbols_final.clear();
+
+         for (auto & word : word_collection) {
+             work_queue = llm_bigram_bpe::queue();
+             symbols.clear();
+
+             int index = 0;
+             size_t offset = 0;
+
+             if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+                 symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
+                 offset = word.size();
+             }
+
+             while (offset < word.size()) {
+                 llm_symbol sym;
+                 size_t char_len = std::min(word.size() - offset, (size_t) unicode_len_utf8(word[offset]));
+                 sym.text = word.c_str() + offset;
+                 sym.n = char_len;
+                 offset += sym.n;
+                 sym.prev = index - 1;
+                 sym.next = offset == word.size() ? -1 : index + 1;
+                 index++;
+                 symbols.emplace_back(sym);
+             }
+             for (size_t i = 1; i < symbols.size(); ++i) {
+                 add_new_bigram(i - 1, i);
+             }
+
+             // build token(s)
+             while (!work_queue.empty()) {
+                 auto bigram = work_queue.top();
+                 work_queue.pop();
+
+                 auto & left_symbol = symbols[bigram.left];
+                 auto & right_symbol = symbols[bigram.right];
+
+                 if (left_symbol.n == 0 || right_symbol.n == 0) {
+                     continue;
+                 }
+                 std::string left_token = std::string(left_symbol.text, left_symbol.n);
+                 std::string right_token = std::string(right_symbol.text, right_symbol.n);
+                 if (left_token + right_token != bigram.text) {
+                     continue; // Skip this bigram if it's outdated
+                 }
+
+                 // merge the right sym into the left one
+                 left_symbol.n += right_symbol.n;
+                 right_symbol.n = 0;
+
+                 // remove the right sym from the chain
+                 left_symbol.next = right_symbol.next;
+                 if (right_symbol.next >= 0) {
+                     symbols[right_symbol.next].prev = bigram.left;
+                 }
+
+                 add_new_bigram(left_symbol.prev, bigram.left);  // left side of current symbol
+                 add_new_bigram(bigram.left, left_symbol.next);  // right side of current symbol
+             }
+
+             // add the finished tokens to the final list keeping correct order for next and prev
+             for (auto & sym : symbols) {
+                 if (sym.n > 0) {
+                     sym.prev = final_prev_index;
+                     sym.next = -1;
+                     if (final_prev_index != -1) {
+                         symbols_final[final_prev_index].next = symbols_final.size();
+                     }
+                     symbols_final.emplace_back(sym);
+                     final_prev_index = symbols_final.size() - 1;
+                 }
+             }
+         }
+
+         symbols = symbols_final;
+
+         if (!symbols.empty()) {
+             for (int i = 0; i != -1; i = symbols[i].next) {
+                 auto & symbol = symbols[i];
+                 if (symbol.n == 0) {
+                     continue;
+                 }
+
+                 const std::string str = std::string(symbol.text, symbol.n);
+                 const auto token = vocab.token_to_id.find(str);
+
+                 if (token == vocab.token_to_id.end()) {
+                     for (auto j = str.begin(); j != str.end(); ++j) {
+                         std::string byte_str(1, *j);
+                         auto token_multibyte = vocab.token_to_id.find(byte_str);
+                         if (token_multibyte != vocab.token_to_id.end()) {
+                             output.push_back(token_multibyte->second);
+                         }
+                     }
+                 } else {
+                     output.push_back((*token).second);
+                 }
+             }
+         }
+     }
+
+ private:
+     void add_new_bigram(int left, int right) {
+         if (left == -1 || right == -1) {
+             return;
+         }
+
+         std::string left_token = std::string(symbols[left].text, symbols[left].n);
+         std::string right_token = std::string(symbols[right].text, symbols[right].n);
+
+         int rank_found = -1;
+
+         rank_found = vocab.find_bpe_rank(left_token, right_token);
+
+         if (rank_found < 0) {
+             return;
+         }
+
+         llm_bigram_bpe bigram;
+
+         bigram.left  = left;
+         bigram.right = right;
+         bigram.text  = left_token + right_token;
+         bigram.size  = left_token.size() + right_token.size();
+         bigram.rank  = rank_found;
+
+         work_queue.push(bigram);
+     }
+
+     const llama_vocab & vocab;
+
+     std::vector<std::string> regex_exprs;
+
+     std::vector<llm_symbol> symbols;
+     std::vector<llm_symbol> symbols_final;
+
+     llm_bigram_bpe::queue work_queue;
+ };
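Each type_pre case above only chooses which regular expressions unicode_regex_split applies before the BPE merge loop. std::regex cannot handle the \p{L}/\p{N} classes these patterns use (that is what the custom engine in unicode.cpp is for), so the sketch below substitutes an ASCII-only variant of the GPT-2 style pattern purely to show the splitting effect:

    #include <cstdio>
    #include <regex>
    #include <string>

    int main() {
        // ASCII-only stand-in for "'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+"
        std::regex re("'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\\sA-Za-z0-9]+|\\s+");
        std::string text = "I'll have 2 coffees, please.";
        for (auto it = std::sregex_iterator(text.begin(), text.end(), re);
             it != std::sregex_iterator(); ++it) {
            printf("[%s]", it->str().c_str());
        }
        printf("\n"); // [I]['ll][ have][ 2][ coffees][,][ please][.]
    }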
+
+ //
+ // WPM tokenizer
+ //
+
+ struct llm_tokenizer_wpm {
+     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
+
+     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
+         const auto & token_map = vocab.token_to_id;
+
+         // normalize and split by whitespace
+         std::vector<std::string> words = preprocess(text);
+
+         // bos token prepended already
+
+         // find the longest tokens that form the words
+         for (const std::string & word : words) {
+             // skip empty words
+             if (word.size() == 0) {
+                 continue;
+             }
+
+             // prepend phantom space
+             const std::string word1 = "\xe2\x96\x81" + word;
+             const int n = word1.size();
+
+             const size_t current_tokens = output.size();
+
+             // we're at the start of a new word
+             // move through character position in word
+             for (int i = 0; i < n; ++i) {
+                 // loop through possible match length
+                 bool match = false;
+                 for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
+                     auto it = token_map.find(word1.substr(i, j - i));
+                     if (it != token_map.end()) {
+                         output.push_back(it->second);
+                         match = true;
+                         i = j - 1;
+                         break;
+                     }
+                 }
+
+                 if (!match) { // discard all
+                     output.resize(current_tokens);
+                     break;  // and discard next tokens
+                 }
+             }
+
+             // we didn't find any matches for this word
+             if (current_tokens == output.size()) {
+                 output.push_back(vocab.special_unk_id);
+             }
+         }
+     }
+
+     // TODO: reduce string copies by using cpts_offs array
+     std::vector<std::string> preprocess(const std::string & text) const {
+         const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
+         std::vector<std::string> words(1, "");
+
+         for (const uint32_t cpt : cpts_nfd) {
+             const auto flags = unicode_cpt_flags(cpt);
+
+             if (flags.is_whitespace) {
+                 if (words.back().size()) {  // finish previous word if any
+                     words.emplace_back();
+                 }
+                 continue;
+             }
+
+             assert (!flags.is_separator);
+             if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
+                 continue;
+             }
+
+             const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
+             if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
+                 if (words.back().size()) {  // finish previous word if any
+                     words.emplace_back();
+                 }
+                 words.back() = s;       // single char word
+                 words.emplace_back();   // start a new word
+             } else {
+                 words.back() += s;  // append char to word
+             }
+         }
+
+         if (!words.back().size()) {
+             words.pop_back();
+         }
+
+         return words;
+     }
+
+     static bool is_chinese_char(uint32_t cpt) {
+         return
+             (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
+             (cpt >= 0x03400 && cpt <= 0x04DBF) ||
+             (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
+             (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
+             (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
+             (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
+             (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
+             (cpt >= 0x2F800 && cpt <= 0x2FA1F);
+             //(cpt >= 0x3000  && cpt <= 0x303F)  ||
+             //(cpt >= 0xFF00  && cpt <= 0xFFEF);
+     }
+
+     const llama_vocab & vocab;
+ };
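The WPM loop above is greedy longest-match-first over each normalized word, and discards the whole word in favor of the unk token if any position cannot be matched. Stripped of normalization, the phantom-space prefix and the max_token_len bound, the core idea is (toy vocabulary, self-contained):

    #include <cstdio>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
        std::set<std::string> vocab = {"un", "want", "wanted", "ed"};
        std::string word = "unwanted";
        std::vector<std::string> out;
        size_t i = 0;
        while (i < word.size()) {
            size_t j = word.size();
            for (; j > i; --j) {                           // longest candidate first
                if (vocab.count(word.substr(i, j - i))) break;
            }
            if (j == i) { out.assign(1, "[UNK]"); break; } // no match: whole word becomes unk
            out.push_back(word.substr(i, j - i));
            i = j;
        }
        for (auto & t : out) printf("%s ", t.c_str());     // prints: un wanted
        printf("\n");
    }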
+
+ //
+ // UGM tokenizer
+ //
+
+ struct llm_tokenizer_ugm {
+     llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
+         if (vocab.precompiled_charsmap.size() > 0) {
+             size_t charsmap_offset = 0;
+
+             // The first four bytes of precompiled_charsmap contain the length of the binary
+             // blob containing XOR-compressed compact double array (XCDA) entries
+             uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0];
+             charsmap_offset += sizeof(xcda_blob_size);
+             if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
+                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
+             }
+
+             // Next xcda_blob_size bytes contain entries of XOR-compressed compact
+             // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
+             xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset];
+             xcda_array_size = xcda_blob_size / sizeof(uint32_t);
+             charsmap_offset += xcda_blob_size;
+
+             // Remaining bytes of precompiled charsmap contain null-terminated
+             // replacement strings for prefixes matched by the XCDA.
+             prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
+             prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
+         }
+
+         for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+             const auto &token_data = vocab.id_to_token[id];
+
+             if (llama_is_normal_token(vocab, id)) {
+                 min_score = std::min<float>(min_score, token_data.score);
+                 max_score = std::max<float>(max_score, token_data.score);
+             }
+
+             if (llama_is_normal_token(vocab, id) ||
+                 llama_is_user_defined_token(vocab, id) ||
+                 llama_is_unused_token(vocab, id)) {
+                 token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
+             }
+
+             if (llama_is_user_defined_token(vocab, id)) {
+                 user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
+             }
+         }
+
+         unknown_token_score = min_score - unknown_token_score_penalty;
+     }
+
+     /* This implementation is based on SentencePiece optimized Viterbi algorithm for
+      * unigram language models. The general idea is to:
+      * - move along the input sequence in steps of one UTF code point,
+      * - at each step find all possible tokenizations of the prefix by
+      *   traversing the tokens trie,
+      * - for each tokenization store the best one so far (by higher score)
+      * - use the position in sequence after given token as an index to store
+      *   results
+      * - if there was no valid tokenization of the current UTF code point
+      *   then use unknown token with additional score penalty
+      * After processing the whole sequence we backtrack from the end to get
+      * the best tokenization.
+      */
+     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+         // normalize the input first
+         std::string normalized;
+         normalize(text, &normalized);
+         size_t input_len = normalized.size();
+         if (input_len == 0) {
+             return;
+         }
+
+         // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
+         std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
+         // at the beginning tokenization score is zero
+         tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
+
+         for (size_t input_offset = 0; input_offset < input_len;) {
+             size_t prefix_offset = input_offset;
+             // calculate how many code units are in the currently processed UTF code point
+             size_t n_utf8_code_units = std::min<size_t>(unicode_len_utf8(normalized[input_offset]), input_len - input_offset);
+
+             // traverse the token matcher trie to find a matching token
+             bool single_codepoint_token_found = false;
+             const struct best_tokenization & current_best = tokenization_results[input_offset];
+             struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+
+             while (prefix_offset <= input_len && node != NULL) {
+                 // check if we found valid token in prefix
+                 if (node->has_value) {
+                     // check if it corresponds to the whole UTF code point
+                     if (prefix_offset - input_offset == n_utf8_code_units) {
+                         single_codepoint_token_found = true;
+                     }
+                     llama_token token_id = node->value;
+                     const auto & token_data = vocab.id_to_token[token_id];
+
+                     // we set the user-defined token scores to 0 to make them more likely to be selected
+                     // (normal token scores are log probabilities, so they are negative)
+                     // score type is double here to make tokenization results exactly
+                     // the same as in the HF tokenizer using SentencePiece
+                     const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score;
+                     const double challenger_score = current_best.score_sum + token_score;
+                     struct best_tokenization & current_champ = tokenization_results[prefix_offset];
+                     if (challenger_score > current_champ.score_sum) {
+                         struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
+                         current_champ = challenger;
+                     }
+                 }
+                 node = node->traverse(normalized[prefix_offset++]);
+             }
+
+             // if we didn't find a valid token corresponding to the whole UTF code point
+             // then use unknown token as the tokenization of this UTF code point
+             if (!single_codepoint_token_found) {
+                 const double challenger_score = current_best.score_sum + unknown_token_score;
+                 prefix_offset = input_offset + n_utf8_code_units;
+                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
+                 if (challenger_score > current_champ.score_sum) {
+                     struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score };
+                     current_champ = challenger;
+                 }
+             }
+
+             // move to the next UTF code point
+             input_offset += n_utf8_code_units;
+         }
+
+         // now backtrack from the end to gather token ids of the best tokenization
+         // merge sequences of consecutive unknown tokens into single unknown tokens
+         bool is_prev_unknown = false;
+         for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) {
+             bool is_unknown = tokenization.token_id == vocab.special_unk_id;
+             if (!(is_prev_unknown && is_unknown)) {
+                 output.push_back(tokenization.token_id);
+             }
+             if (tokenization.input_offset == 0) {
+                 break;
+             }
+             is_prev_unknown = is_unknown;
+         }
+
+         // reverse the output since we added tokens starting from the end of the input
+         std::reverse(output.begin(), output.end());
+     }
+
+ private:
+     const llama_vocab & vocab;
+
+     // helper structure for returning normalization results
+     struct normalization_result {
+         const char * normalized;
+         size_t normalized_len;
+         size_t consumed_input;
+     };
+
+     void normalize(const std::string& input, std::string * normalized) {
+         normalized->clear();
+         normalized->reserve(input.size() * 3);
+
+         const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
+
+         bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
+         bool shall_append_space  =  vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
+         bool shall_merge_spaces  =  vocab.tokenizer_remove_extra_whitespaces;
+
+         bool is_space_prepended = false;
+         bool processing_non_ws  = false;
+
+         size_t input_len = input.size();
+
+         for (size_t input_offset = 0; input_offset < input_len; ) {
+             auto norm_res = normalize_prefix(input, input_offset);
+             for (size_t i = 0; i < norm_res.normalized_len; i++) {
+                 char c = norm_res.normalized[i];
+                 if (c != ' ') {
+                     if (!processing_non_ws) {
+                         processing_non_ws = true;
+                         if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) {
+                             normalized->append(space);
+                             is_space_prepended = true;
+                         }
+                     }
+                     normalized->push_back(c);
+                 } else {
+                     if (processing_non_ws) {
+                         processing_non_ws = false;
+                     }
+                     if (!shall_merge_spaces) {
+                         normalized->append(space);
+                     }
+                 }
+             }
+
+             input_offset += norm_res.consumed_input;
+         }
+
+         if (shall_append_space) {
+             normalized->append(space);
+         }
+     }
+
+     /*
+      * This structure is a view wrapper for XOR-compressed double array (XCDA)
+      * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
+      * Each bit-packed entry contains:
+      * - BASE array value in bits 10-30
+      * - LCHECK array value in bits 0-7
+      * - LEAF array value in bit 9
+      * Entries containing indexes of replacement sequences have set bit 31
+      */
+     struct xcda_array_view {
+     public:
+         xcda_array_view(const uint32_t * xcda_array, size_t xcda_array_size) : xcda_array(xcda_array), xcda_array_size(xcda_array_size) {
+         }
+         uint32_t get_base(size_t index) {
+             uint32_t packed_node = get_node(index);
+             return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6);
+         }
+         uint32_t get_lcheck(size_t index) {
+             uint32_t packed_node = get_node(index);
+             return packed_node & ((1U << 31) | 0xff);
+         }
+         bool get_leaf(size_t index) {
+             uint32_t packed_node = get_node(index);
+             return (packed_node >> 8) & 1;
+         }
+         uint32_t get_value(size_t index) {
+             uint32_t packed_node = get_node(index);
+             return packed_node & ((1U << 31) - 1);
+         }
+     private:
+         uint32_t get_node(size_t index) {
+             if (index > xcda_array_size) {
+                 throw std::runtime_error("Index out of array bounds in XCDA array!");
+             }
+             return xcda_array[index];
+         }
+         const uint32_t * xcda_array;
+         size_t xcda_array_size;
+     };
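For reference, a worked unpacking of one 32-bit XCDA entry, following the accessors above (the value is illustrative):

    // packed = 0x00012345
    // get_lcheck: packed & ((1u << 31) | 0xff)                -> 0x45
    // get_leaf:   (packed >> 8) & 1                           -> 1
    // get_base:   (packed >> 10) << ((packed & (1u << 9)) >> 6)
    //             = 0x48 << 8 = 0x4800   (bit 9 set widens BASE by 8 bits)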
+
+     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
+         if (input_offset == input.size()) {
+             return { &input[input_offset], 0, 0 };
+         }
+
+         // if input prefix matches some user-defined token return this token as normalization result
+         auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+         if (user_defined_token_match.second > 0) {
+             return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
+         }
+
+         size_t longest_prefix_length = 0;
+         size_t longest_prefix_offset = 0;
+
+         if (xcda_array_size > 0) {
+             struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
+
+             // Find the longest normalized sequence matching the input prefix by walking
+             // the XOR-compressed compact double array (XCDA) starting from the root node
+             // We find the index of the next node by calculating BASE[s] ^ c where s is
+             // the index of the previous node and c is a numerical character value
+             uint32_t node_index = 0;
+             // get BASE of the root node
+             node_index = xcda_view.get_base(node_index);
+             for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) {
+                 unsigned char c = input[prefix_offset];
+                 if (c == 0) {
+                     break;
+                 }
+                 node_index ^= c;
+                 // if value of LCHECK is not c it means that this is not a child of
+                 // the previous node, so we stop matching
+                 if (xcda_view.get_lcheck(node_index) != c) {
+                     break;
+                 }
+                 bool is_leaf = xcda_view.get_leaf(node_index);
+                 // get BASE of the current node
+                 node_index ^= xcda_view.get_base(node_index);
+                 // if LEAF of the current node is true, it means that its BASE points to the node
+                 // containing index of replacement sequence for currently matched input prefix
+                 if (is_leaf)
+                 {
+                     longest_prefix_length = prefix_offset - input_offset + 1;
+                     // get index of replacement sequence for currently matched input prefix
+                     longest_prefix_offset = xcda_view.get_value(node_index);
+                 }
+             }
+         }
+
+         if (longest_prefix_length > 0) {
+             // we have a match, so return the replacement sequence
+             if (longest_prefix_offset >= prefix_replacements_size) {
+                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
+             }
+             const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
+             return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
+         } else {
+             // check if the input prefix contains a valid sequence of UTF-8 code units
+             try {
+                 // if yes, return this sequence unmodified
+                 size_t prefix_offset = input_offset;
+                 unicode_cpt_from_utf8(input, prefix_offset);
+                 return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
+             } catch (std::invalid_argument & /*ex*/) {
+                 // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
+                 return { "\xEF\xBF\xBD", 3, 1 };
+             }
+         }
+     }
+
+     // escaped space symbol - U+2581 (Lower One Eighth Block)
+     const std::string escaped_space = "\xE2\x96\x81";
+
+     const char * prefix_replacements = NULL;
+     size_t prefix_replacements_size = 0;
+
+     const uint32_t * xcda_array = NULL;
+     size_t xcda_array_size = 0;
+
+     struct naive_trie user_defined_token_matcher;
+
+     // this structure stores the best tokenization so far at input_offset
+     struct best_tokenization {
+         llama_token token_id;
+         size_t input_offset;
+         float score_sum;
+     };
+
+     float min_score = FLT_MAX;
+     float max_score = -FLT_MAX;
+
+     float unknown_token_score_penalty = 10.0;
+     float unknown_token_score;
+
+     struct naive_trie token_matcher;
+ };
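Stripped of the trie, normalization and the unknown-token fallback, the Viterbi recurrence documented above is just best[j] = max over tokens t ending at j of best[j - len(t)] + score(t), followed by backtracking. A self-contained toy version (scores are made up; "hello" is deliberately scored worse than its parts):

    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
        std::map<std::string, double> score = {
            {"h", -3}, {"e", -3}, {"l", -3}, {"o", -3}, {"he", -2.5}, {"ll", -2}, {"hello", -8},
        };
        std::string text = "hello";
        const size_t n = text.size();
        std::vector<double> best(n + 1, -1e30);
        std::vector<size_t> prev(n + 1, 0);
        best[0] = 0;
        for (size_t i = 0; i < n; ++i) {
            for (size_t j = i + 1; j <= n; ++j) {
                auto it = score.find(text.substr(i, j - i));
                if (it != score.end() && best[i] + it->second > best[j]) {
                    best[j] = best[i] + it->second;
                    prev[j] = i; // remember where the best token ending at j starts
                }
            }
        }
        std::vector<std::string> out; // backtrack, as the tokenization_results loop does above
        for (size_t j = n; j > 0; j = prev[j]) out.push_back(text.substr(prev[j], j - prev[j]));
        for (auto it = out.rbegin(); it != out.rend(); ++it) printf("'%s' ", it->c_str());
        printf("\n"); // prints: 'he' 'll' 'o'
    }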
+
+ //
+ // (de-) tokenize
+ //
+
+ typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
+     FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
+     FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
+ } FRAGMENT_BUFFER_VARIANT_TYPE;
+
+ struct fragment_buffer_variant {
+     fragment_buffer_variant(llama_vocab::id _token)
+     :
+         type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
+         token(_token),
+         raw_text(_dummy),
+         offset(0),
+         length(0) {}
+
+     fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
+     :
+         type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
+         token((llama_vocab::id) - 1),
+         raw_text(_raw_text),
+         offset(_offset),
+         length(_length){
+             LM_GGML_ASSERT(_offset >= 0);
+             LM_GGML_ASSERT(_length >= 1);
+             LM_GGML_ASSERT(offset + length <= raw_text.length());
+         }
+
+     const FRAGMENT_BUFFER_VARIANT_TYPE type;
+     const llama_vocab::id token;
+     const std::string _dummy;
+     const std::string & raw_text;
+     const uint64_t offset;
+     const uint64_t length;
+ };
+
+ // #define PRETOKENIZERDEBUG
+
+ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
+     // for each special token
+     for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
+         const auto & data = vocab.id_to_token[special_id];
+         const auto & special_token = data.text;
+
+         if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
+             // Ignore control and unknown tokens when parse_special == false
+             continue;
+             // User-defined tokens are still pre-tokenized before everything else
+             // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+             // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+         }
+
+         // for each text fragment
+         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
+         while (it != buffer.end()) {
+             auto & fragment = (*it);
+
+             // if a fragment is text ( not yet processed )
+             if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                 auto & raw_text = fragment.raw_text;
+
+                 auto raw_text_base_offset = fragment.offset;
+                 auto raw_text_base_length = fragment.length;
+
+                 // loop over the text
+                 while (true) {
+                     // find the first occurrence of a given special token in this fragment
+                     // passing the offset argument only limits the "search area" but match coordinates
+                     // are still relative to the source full raw_text
+                     auto match = raw_text.find(special_token, raw_text_base_offset);
+
+                     // no occurrences found, stop processing this fragment for a given special token
+                     if (match == std::string::npos) break;
+
+                     // check if match is within bounds of offset <-> length
+                     if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+
+ #ifdef PRETOKENIZERDEBUG
+                     LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+ #endif
+                     auto source = std::distance(buffer.begin(), it);
+
+                     // if match is further than base offset
+                     // then we have some text to the left of it
+                     if (match > raw_text_base_offset) {
+                         // left
+                         const int64_t left_reminder_offset = raw_text_base_offset + 0;
+                         int64_t left_reminder_length = match - raw_text_base_offset;
+
+                         if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                             while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                                 left_reminder_length--;
+                             }
+                         }
+
+                         if (left_reminder_length > 0) {
+                             buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                             it++;
+                         }
+
+ #ifdef PRETOKENIZERDEBUG
+                         LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+ #endif
+                     }
+
+                     // special token
+                     buffer.emplace_after(it, special_id);
+                     it++;
+
+                     // right
+                     if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
+                         int64_t right_reminder_offset = match + special_token.length();
+                         int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+
+                         if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                             while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                                 right_reminder_offset++;
+                                 right_reminder_length--;
+                             }
+                         }
+
+                         if (right_reminder_length > 0) {
+                             buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                             it++;
+                         }
+
+ #ifdef PRETOKENIZERDEBUG
+                         LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+ #endif
+
+                         if (source == 0) {
+                             buffer.erase_after(buffer.before_begin());
+                         } else {
+                             buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                         }
+
+                         // repeat for the right side
+                         raw_text_base_offset = right_reminder_offset;
+                         raw_text_base_length = right_reminder_length;
+
+ #ifdef PRETOKENIZERDEBUG
+                         LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+ #endif
+                     } else {
+                         if (source == 0) {
+                             buffer.erase_after(buffer.before_begin());
+                         } else {
+                             buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                         }
+                         break;
+                     }
+                 }
+             }
+             it++;
+         }
+     }
+ }
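tokenizer_st_partition splits every raw-text fragment around occurrences of each special token, so the sub-word tokenizers never see a special token's text. Ignoring LSTRIP/RSTRIP and the forward_list bookkeeping, the effect is this (self-contained sketch; "<|eot|>" is a hypothetical special token):

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        const std::string special = "<|eot|>";
        std::string text = "hello<|eot|>world";
        std::vector<std::string> fragments;
        size_t pos = 0, match;
        while ((match = text.find(special, pos)) != std::string::npos) {
            if (match > pos) fragments.push_back(text.substr(pos, match - pos)); // raw text on the left
            fragments.push_back(special);                                        // the special token itself
            pos = match + special.size();
        }
        if (pos < text.size()) fragments.push_back(text.substr(pos));            // trailing raw text
        for (auto & f : fragments) printf("[%s]", f.c_str());                    // [hello][<|eot|>][world]
        printf("\n");
    }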
+
+ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
+     std::vector<llama_vocab::id> output;
+     std::forward_list<fragment_buffer_variant> fragment_buffer;
+
+     if (!raw_text.empty()) {
+         fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+         tokenizer_st_partition(vocab, fragment_buffer, parse_special);
+     }
+
+     switch (vocab.type) {
+         case LLAMA_VOCAB_TYPE_SPM:
+             {
+                 // OG tokenizer behavior:
+                 //
+                 // tokenizer.encode('', add_special_tokens=True) returns [1]
+                 // tokenizer.encode('', add_special_tokens=False) returns []
+
+                 bool is_prev_special = true;  // prefix with space if first token
+
+                 if (add_special && vocab.tokenizer_add_bos) {
+                     LM_GGML_ASSERT(vocab.special_bos_id != -1);
+                     output.push_back(vocab.special_bos_id);
+                     is_prev_special = true;
+                 }
+
+                 for (const auto & fragment : fragment_buffer) {
+                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+                         // prefix with space if previous is special
+                         if (vocab.tokenizer_add_space_prefix && is_prev_special) {
+                             raw_text = " " + raw_text;
+                         }
+
+ #ifdef PRETOKENIZERDEBUG
+                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+ #endif
+                         llm_tokenizer_spm tokenizer(vocab);
+                         llama_escape_whitespace(raw_text);
+                         tokenizer.tokenize(raw_text, output);
+                         is_prev_special = false;
+                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                         output.push_back(fragment.token);
+                         is_prev_special = true;
+                     }
+                 }
+
+                 if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                     LLAMA_LOG_WARN(
+                         "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                         "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                         "Are you sure this is what you want?\n", __FUNCTION__);
+                 }
+
+                 if (add_special && vocab.tokenizer_add_eos) {
+                     LM_GGML_ASSERT(vocab.special_eos_id != -1);
+                     output.push_back(vocab.special_eos_id);
+                 }
+             } break;
+         case LLAMA_VOCAB_TYPE_BPE:
+             {
+                 llm_tokenizer_bpe tokenizer(vocab);
+
+                 if (add_special) {
+                     tokenizer.append_bos(output);
+                 }
+                 for (const auto & fragment : fragment_buffer) {
+                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+ #ifdef PRETOKENIZERDEBUG
+                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+ #endif
+                         tokenizer.tokenize(raw_text, output);
+                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                         tokenizer.append(fragment.token, output);
+                     }
+                 }
+
+                 if (add_special) {
+                     tokenizer.append_eos(output);
+                     tokenizer.check_double_bos_eos(output);
+                 }
+             } break;
+         case LLAMA_VOCAB_TYPE_WPM:
+             {
+                 if (add_special) {
+                     LM_GGML_ASSERT(vocab.special_cls_id != -1);
+                     output.push_back(vocab.special_cls_id);
+                 }
+
+                 llm_tokenizer_wpm tokenizer(vocab);
+
+                 for (const auto & fragment : fragment_buffer) {
+                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+ #ifdef PRETOKENIZERDEBUG
+                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+ #endif
+                         tokenizer.tokenize(raw_text, output);
+                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                         output.push_back(fragment.token);
+                     }
+                 }
+
+                 if (add_special) {
+                     LM_GGML_ASSERT(vocab.special_sep_id != -1);
+                     output.push_back(vocab.special_sep_id);
+                 }
+             } break;
+         case LLAMA_VOCAB_TYPE_UGM:
+             {
+                 llm_tokenizer_ugm tokenizer(vocab);
+
+                 if (add_special && vocab.tokenizer_add_bos != 0) {
+                     LM_GGML_ASSERT(vocab.special_bos_id != -1);
+                     output.push_back(vocab.special_bos_id);
+                 }
+
+                 for (const auto & fragment : fragment_buffer) {
+                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+ #ifdef PRETOKENIZERDEBUG
+                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+ #endif
+                         tokenizer.tokenize(raw_text, output);
+                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                         output.push_back(fragment.token);
+                     }
+                 }
+
+                 if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                     LLAMA_LOG_WARN(
+                         "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                         "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                         "Are you sure this is what you want?\n", __FUNCTION__);
+                 }
+
+                 if (add_special && vocab.tokenizer_add_eos == 1) {
+                     LM_GGML_ASSERT(vocab.special_eos_id != -1);
+                     output.push_back(vocab.special_eos_id);
+                 }
+             } break;
+         case LLAMA_VOCAB_TYPE_NONE:
+             LM_GGML_ABORT("fatal error");
+     }
+
+     return output;
+ }
+
+ llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
+     LM_GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
+     static const char * hex = "0123456789ABCDEF";
+     switch (llama_vocab_get_type(vocab)) {
+         case LLAMA_VOCAB_TYPE_SPM:
+         case LLAMA_VOCAB_TYPE_UGM: {
+             const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+             auto token = vocab.token_to_id.find(buf);
+             if (token != vocab.token_to_id.end()) {
+                 return (*token).second;
+             }
+             // Try to fall back to just the byte as a string
+             const char buf2[2] = { (char)ch, 0 };
+             return vocab.token_to_id.at(buf2);
+         }
+         case LLAMA_VOCAB_TYPE_WPM:
+         case LLAMA_VOCAB_TYPE_BPE: {
+             return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
+         }
+         default:
+             LM_GGML_ABORT("fatal error");
+     }
+ }
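For SPM/UGM vocabularies the byte fallback above uses the SentencePiece <0xXX> convention before trying the raw byte itself; BPE/WPM instead map the byte through unicode_byte_to_utf8. For example, for ch = 0x41 the first lookup key is built as:

    const char buf[7] = { '<', '0', 'x', '4', '1', '>', 0 }; // "<0x41>", then fall back to "A"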
1428
+
1429
+ const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
1430
+ LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
1431
+ return vocab.id_to_token[token].text.c_str();
1432
+ }
1433
+
1434
+ float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
1435
+ LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
1436
+ return vocab.id_to_token[token].score;
1437
+ }
1438
+
1439
+ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
1440
+ LM_GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
1441
+ return vocab.id_to_token[token].attr;
1442
+ }
1443
+
1444
+ bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
1445
+ return token != -1 && (
1446
+ token == llama_token_eos_impl(vocab) ||
1447
+ token == llama_token_eot_impl(vocab)
1448
+ );
1449
+ }
1450
+
+ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
+     return llama_is_control_token(vocab, token);
+ }
+
+ llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
+     return vocab.special_bos_id;
+ }
+
+ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
+     return vocab.special_eos_id;
+ }
+
+ llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
+     return vocab.special_cls_id;
+ }
+
+ llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
+     return vocab.special_sep_id;
+ }
+
+ llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
+     return vocab.linefeed_id;
+ }
+
+ llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
+     return vocab.special_pad_id;
+ }
+
+ int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
+     return vocab.tokenizer_add_bos;
+ }
+
+ int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
+     return vocab.tokenizer_add_eos;
+ }
+
+ llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
+     return vocab.special_prefix_id;
+ }
+
+ llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
+     return vocab.special_middle_id;
+ }
+
+ llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
+     return vocab.special_suffix_id;
+ }
+
+ llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
+     return vocab.special_eot_id;
+ }
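
The prefix/middle/suffix/EOT accessors above exist for infill (fill-in-the-middle) models. A hedged sketch of assembling a CodeLlama-style infill prompt with the public wrappers; the token order shown is an assumption and varies by model:

    // Assumed "<PRE> {before cursor} <SUF> {after cursor} <MID>" layout:
    std::vector<llama_token> fim;
    fim.push_back(llama_token_prefix(model));
    // ... append tokens for the code before the cursor ...
    fim.push_back(llama_token_suffix(model));
    // ... append tokens for the code after the cursor ...
    fim.push_back(llama_token_middle(model));  // the model completes here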
1502
+
+ int32_t llama_tokenize_impl(
+         const struct llama_vocab & vocab,
+         const char * text,
+         int32_t text_len,
+         llama_token * tokens,
+         int32_t n_tokens_max,
+         bool add_special,
+         bool parse_special) {
+     auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
+     if (n_tokens_max < (int) res.size()) {
+         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+         return -((int) res.size());
+     }
+
+     for (size_t i = 0; i < res.size(); i++) {
+         tokens[i] = res[i];
+     }
+
+     return res.size();
+ }
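
Because `llama_tokenize_impl` returns the negated required count when the buffer is too small, callers can size the output in two passes. A minimal sketch via the public wrapper (assumes a loaded `model` and a `std::string text`):

    std::vector<llama_token> toks(text.size() + 2);  // heuristic first guess
    int32_t n = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                               toks.data(), (int32_t) toks.size(),
                               /*add_special=*/true, /*parse_special=*/false);
    if (n < 0) {          // buffer too small: -n is the exact token count
        toks.resize(-n);
        n = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                           toks.data(), (int32_t) toks.size(), true, false);
    }
    toks.resize(n);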
1523
+
+ static std::string llama_decode_text(const std::string & text) {
+     std::string decoded_text;
+
+     const auto cpts = unicode_cpts_from_utf8(text);
+     for (const auto cpt : cpts) {
+         const auto utf8 = unicode_cpt_to_utf8(cpt);
+         try {
+             decoded_text += unicode_utf8_to_byte(utf8);
+         } catch (const std::out_of_range & /*e*/) {
+             decoded_text += "[UNK_BYTE_0x";
+             for (const auto c : utf8) {
+                 decoded_text += format("%02x", (uint8_t) c);
+             }
+             decoded_text += text + "]";
+         }
+     }
+
+     return decoded_text;
+ }
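
`llama_decode_text` undoes the GPT-2 style byte-to-codepoint mapping used by byte-level BPE: a leading space (byte 0x20) is stored in token text as U+0120 ("Ġ"). Illustratively (the function is static, so this only works inside this translation unit):

    std::string piece = "Ġhello";                  // token text as stored
    std::string text  = llama_decode_text(piece);  // " hello"
    // Codepoints with no byte mapping surface as "[UNK_BYTE_0x..]" instead,
    // per the catch branch above.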
1543
+
+ // does not write null-terminator to buf
+ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
+     // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
+     const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
+     if (!special && (attr & attr_special)) {
+         return 0;
+     }
+
+     // copy piece chars to output text buffer
+     // skip up to 'lstrip' leading spaces before copying
+     auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+         for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
+             token++;
+             size--;
+         }
+         if (length < (int32_t)size) {
+             return -(int32_t) size;
+         }
+         memcpy(buf, token, size);
+         return (int32_t) size;
+     };
+
+     // if we have a cache - use it
+     {
+         const auto & cache = vocab.cache_token_to_piece;
+
+         if (!cache.empty()) {
+             const auto & result = cache.at(token);
+             return _try_copy(result.data(), result.size());
+         }
+     }
+
+     if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
+         const std::string & token_text = vocab.id_to_token[token].text;
+         switch (llama_vocab_get_type(vocab)) {
+             case LLAMA_VOCAB_TYPE_WPM:
+             case LLAMA_VOCAB_TYPE_SPM:
+             case LLAMA_VOCAB_TYPE_UGM: {
+                 // NOTE: we accept all unsupported token types,
+                 // suppressing them like CONTROL tokens.
+                 if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                     return _try_copy(token_text.data(), token_text.size());
+                 } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                     std::string result = token_text;
+                     llama_unescape_whitespace(result);
+                     return _try_copy(result.data(), result.size());
+                 } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                     char byte = (char) llama_token_to_byte(vocab, token);
+                     return _try_copy((char*) &byte, 1);
+                 }
+                 break;
+             }
+             case LLAMA_VOCAB_TYPE_BPE: {
+                 // NOTE: we accept all unsupported token types,
+                 // suppressing them like CONTROL tokens.
+                 if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                     return _try_copy(token_text.data(), token_text.size());
+                 } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                     std::string result = llama_decode_text(token_text);
+                     return _try_copy(result.data(), result.size());
+                 }
+                 break;
+             }
+             default:
+                 LM_GGML_ABORT("fatal error");
+         }
+     }
+
+     return 0;
+ }
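
The same grow-on-negative convention applies here: `_try_copy` returns `-size` when `buf` is too short. A sketch with the public wrapper (assumes a loaded `model` and a valid `token`; `lstrip = 0` keeps leading spaces):

    std::string piece(8, '\0');
    int32_t n = llama_token_to_piece(model, token, piece.data(),
                                     (int32_t) piece.size(), 0, /*special=*/true);
    if (n < 0) {          // -n is the required buffer length
        piece.resize(-n);
        n = llama_token_to_piece(model, token, piece.data(),
                                 (int32_t) piece.size(), 0, true);
    }
    piece.resize(n);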
1615
+
+ int32_t llama_detokenize_impl(
+         const struct llama_vocab & vocab,
+         const llama_token * tokens,
+         int32_t n_tokens,
+         char * text,
+         int32_t text_len_max,
+         bool remove_special,
+         bool unparse_special) {
+     int32_t avail = text_len_max;
+     int32_t total = 0;
+
+     // remove the leading space
+     bool remove_space = vocab.tokenizer_add_space_prefix;
+
+     if (remove_special && vocab.tokenizer_add_bos) {
+         if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
+             remove_space = false;
+             n_tokens--;
+             tokens++;
+         }
+     }
+
+     if (remove_special && vocab.tokenizer_add_eos) {
+         if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
+             n_tokens--;
+         }
+     }
+
+     for (int32_t i = 0; i < n_tokens; ++i) {
+         LM_GGML_ASSERT(avail >= 0);
+         int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
+         remove_space = false;
+         if (n_chars < 0) {
+             avail = 0;
+             total -= n_chars;
+         } else if (n_chars > 0) {
+             avail -= n_chars;
+             text += n_chars;
+             total += n_chars;
+         }
+     }
+
+     if (total > text_len_max) {
+         return -total;
+     }
+
+     if (vocab.tokenizer_clean_spaces) {
+         text -= total; // restart text
+
+         // first pass: characters ?!., //TODO: where do these characters come from?
+         const int32_t total1 = total;
+         total = total ? 1 : 0;
+         for (int32_t i = 1; i < total1; ++i) {
+             const char x = text[i];
+             if (text[i - 1] == ' ') {
+                 if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
+                     total--;  // remove space
+                 }
+             }
+             text[total++] = x;
+         }
+
+         // second pass: strip single apostrophe between spaces
+         const int32_t total2 = total;
+         total = total ? 1 : 0;
+         for (int32_t i = 1; i < total2; ++i) {
+             const char x = text[i];
+             if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
+                 total--;           // remove prev space
+                 text[++i] = '\0';  // remove next space
+             }
+             text[total++] = x;
+         }
+
+         // third pass: apostrophe contractions  //NOTE: this makes sense?
+         const int32_t total3 = total;
+         total = total ? 1 : 0;
+         for (int32_t i = 1; i < total3; ++i) {
+             const char x = text[i];
+             if (text[i - 1] == ' ') {
+                 if (x == '\'' && i + 1 < total3) {
+                     const char x1 = text[i + 1];
+                     if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
+                         //total--;  // remove space
+                     } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
+                         total--;  // remove space
+                     } else if (i + 2 < total3) {
+                         const char x2 = text[i + 2];
+                         if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
+                             //total--;  // remove space
+                         } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
+                             total--;  // remove space
+                         } else {
+                             //total--;  // remove space
+                         }
+                     } else {
+                         //total--;  // remove space
+                     }
+                 }
+             }
+             text[total++] = x;
+         }
+     }
+
+     return total <= text_len_max ? total : -total;
+ }
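
Detokenization follows the same contract: a negative return is the negated length the caller should retry with. A sketch using the public wrapper this `_impl` backs (assumes a loaded `model` and a `std::vector<llama_token> toks`):

    std::string out(toks.size() * 4, '\0');  // rough initial capacity
    int32_t n = llama_detokenize(model, toks.data(), (int32_t) toks.size(),
                                 out.data(), (int32_t) out.size(),
                                 /*remove_special=*/false, /*unparse_special=*/false);
    if (n < 0) {
        out.resize(-n);
        n = llama_detokenize(model, toks.data(), (int32_t) toks.size(),
                             out.data(), (int32_t) out.size(), false, false);
    }
    out.resize(n);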