cui-llama.rn 1.2.0 → 1.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +2 -0
  2. package/android/src/main/CMakeLists.txt +2 -2
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +39 -0
  5. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  6. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/cpp/common.cpp +36 -1
  8. package/cpp/common.h +5 -1
  9. package/cpp/ggml-aarch64.c +2 -11
  10. package/cpp/ggml-alloc.h +1 -1
  11. package/cpp/ggml-backend-impl.h +151 -78
  12. package/cpp/{ggml-backend.c → ggml-backend.cpp} +565 -269
  13. package/cpp/ggml-backend.h +147 -62
  14. package/cpp/ggml-impl.h +15 -0
  15. package/cpp/ggml-metal.h +8 -9
  16. package/cpp/ggml-metal.m +2428 -2111
  17. package/cpp/ggml-quants.c +2 -2
  18. package/cpp/ggml-quants.h +0 -4
  19. package/cpp/ggml.c +799 -1121
  20. package/cpp/ggml.h +79 -72
  21. package/cpp/llama-vocab.cpp +189 -106
  22. package/cpp/llama-vocab.h +18 -9
  23. package/cpp/llama.cpp +736 -341
  24. package/cpp/llama.h +9 -4
  25. package/cpp/unicode-data.cpp +6 -4
  26. package/cpp/unicode-data.h +4 -4
  27. package/cpp/unicode.cpp +14 -7
  28. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  29. package/lib/commonjs/index.js +4 -0
  30. package/lib/commonjs/index.js.map +1 -1
  31. package/lib/module/NativeRNLlama.js.map +1 -1
  32. package/lib/module/index.js +3 -0
  33. package/lib/module/index.js.map +1 -1
  34. package/lib/typescript/NativeRNLlama.d.ts +6 -0
  35. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  36. package/lib/typescript/index.d.ts +2 -1
  37. package/lib/typescript/index.d.ts.map +1 -1
  38. package/package.json +1 -1
  39. package/src/NativeRNLlama.ts +7 -0
  40. package/src/index.ts +5 -0
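
Only the diff for package/cpp/llama-vocab.cpp (item 21 above) is reproduced below. Its dominant change is a split of the tokenizer classes: immutable per-vocab state (regex lists, tries, scores, precompiled charsmaps) moves into a polymorphic llm_tokenizer that llama_vocab owns and builds once via the new init_tokenizer(), while mutable per-call scratch state (symbols, work queues) moves into lightweight *_session objects constructed on each tokenize call. A minimal sketch of the pattern, using simplified stand-in types rather than the actual cui-llama.rn sources:

#include <string>
#include <vector>

// Shared, read-only tokenizer state, built once per vocab.
struct llm_tokenizer {
    virtual ~llm_tokenizer() = default;
};

struct llm_tokenizer_bpe : llm_tokenizer {
    std::vector<std::string> regex_exprs;  // filled once by the real constructor
};

// Stand-in for llama_vocab: owns the tokenizer and frees it in its destructor.
struct vocab_like {
    llm_tokenizer * tokenizer = nullptr;
    void init_tokenizer() { tokenizer = new llm_tokenizer_bpe(); }
    ~vocab_like() { delete tokenizer; }
};

// Cheap per-call object: mutable scratch plus a non-owning pointer to the
// shared state, so repeated tokenize calls stop paying for construction.
struct llm_tokenizer_bpe_session {
    explicit llm_tokenizer_bpe_session(const vocab_like & vocab)
        : bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}

    void tokenize(const std::string & text, std::vector<int> & output) {
        scratch.clear();                      // per-session work state
        (void) bpe_tokenizer->regex_exprs;    // reads shared state only
        output.push_back((int) text.size());  // placeholder for the real BPE loop
    }

private:
    const llm_tokenizer_bpe * bpe_tokenizer;
    std::vector<int> scratch;
};

int main() {
    vocab_like vocab;
    vocab.init_tokenizer();  // must run before any session is created
    std::vector<int> out;
    llm_tokenizer_bpe_session(vocab).tokenize("hello", out);
    return (int) out.size();
}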
@@ -50,7 +50,7 @@ struct naive_trie {
             res.first->second.insert(key + 1, len - 1, value);
         }
     }
-    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
+    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
         if (len == 0 || offset == len) {
             return std::make_pair(key, offset);
         }
@@ -79,6 +79,15 @@ struct naive_trie {
 // impl
 //
 
+struct llm_tokenizer {
+    llm_tokenizer() {}
+    virtual ~llm_tokenizer() = default;
+};
+
+llama_vocab::~llama_vocab() {
+    delete tokenizer;
+}
+
 int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
     LM_GGML_ASSERT(token_left.find(' ') == std::string::npos);
     LM_GGML_ASSERT(token_left.find('\n') == std::string::npos);
@@ -187,10 +196,15 @@ struct llm_bigram_spm {
     size_t size;
 };
 
-struct llm_tokenizer_spm {
-    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
+struct llm_tokenizer_spm : llm_tokenizer {
+    llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
+
+struct llm_tokenizer_spm_session {
+    llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
@@ -271,7 +285,7 @@ private:
             return;
         }
 
-        resegment(symbols[p->second.first], output);
+        resegment(symbols[p->second.first],  output);
         resegment(symbols[p->second.second], output);
     }
 
@@ -279,7 +293,6 @@ private:
         if (left == -1 || right == -1) {
             return;
         }
-
         const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
         auto token = vocab.token_to_id.find(text);
 
@@ -306,10 +319,11 @@ private:
     }
 
     const llama_vocab & vocab;
+    // currently unused
+    // const llm_tokenizer_spm * spm_tokenizer;
 
     std::vector<llm_symbol> symbols;
     llm_bigram_spm::queue work_queue;
-
     std::map<std::string, std::pair<int, int>> rev_merge;
 };
 
@@ -352,8 +366,8 @@ struct llm_bigram_bpe {
     size_t size;
 };
 
-struct llm_tokenizer_bpe {
-    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+struct llm_tokenizer_bpe : llm_tokenizer {
+    llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
         LM_GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
         switch (vocab.type_pre) {
             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
@@ -450,6 +464,20 @@ struct llm_tokenizer_bpe {
                     "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_CHAMELEON:
+                // Note: in theory, the special token (sentinel and image token) regex_exprs below
+                // are unnecessary, as they are split in `tokenizer_st_partition` anyway.
+                // However, since the upstream pre-tokenizer uses them, they are also
+                // included here (see https://huggingface.co/facebook/chameleon-7b).
+                regex_exprs = {
+                    "<sentinel:[0-9]+>",  // Sentinel tokens
+                    "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z",  // Image tokens
+                    "([\\t\\n]|    |  )",  // directly from tokenizer.json
+                    "\\p{N}", // Individual digits
+                    "[\\p{P}!-/:-@\\[-`{-~]",  // Punctuation, Isolated
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -462,7 +490,14 @@ struct llm_tokenizer_bpe {
         }
     }
 
-    void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
+    std::vector<std::string> regex_exprs;
+};
+
+struct llm_tokenizer_bpe_session {
+    llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
+        bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
+
+    static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) {
         output.push_back(token_id);
     }
 
@@ -501,12 +536,11 @@ struct llm_tokenizer_bpe {
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         int final_prev_index = -1;
-
-        const auto word_collection = unicode_regex_split(text, regex_exprs);
+        const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
 
         symbols_final.clear();
 
-        for (auto & word : word_collection) {
+        for (const auto & word : word_collection) {
             work_queue = llm_bigram_bpe::queue();
             symbols.clear();
 
@@ -609,7 +643,6 @@ private:
         if (left == -1 || right == -1) {
             return;
         }
-
        std::string left_token = std::string(symbols[left].text, symbols[left].n);
        std::string right_token = std::string(symbols[right].text, symbols[right].n);
 
@@ -633,12 +666,10 @@ private:
     }
 
     const llama_vocab & vocab;
-
-    std::vector<std::string> regex_exprs;
+    const llm_tokenizer_bpe * bpe_tokenizer;
 
     std::vector<llm_symbol> symbols;
     std::vector<llm_symbol> symbols_final;
-
     llm_bigram_bpe::queue work_queue;
 };
 
@@ -646,15 +677,17 @@ private:
 // WPM tokenizer
 //
 
-struct llm_tokenizer_wpm {
-    llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
+struct llm_tokenizer_wpm : llm_tokenizer {
+    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
-        const auto & token_map = vocab.token_to_id;
+struct llm_tokenizer_wpm_session {
+    llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        const auto & token_map = vocab.token_to_id;
         // normalize and split by whitespace
         std::vector<std::string> words = preprocess(text);
-
         // bos token prepended already
 
         // find the longest tokens that form the words
@@ -699,7 +732,7 @@ struct llm_tokenizer_wpm {
     }
 
     // TODO: reduce string copies by using cpts_offs array
-    std::vector<std::string> preprocess(const std::string & text) const {
+    static std::vector<std::string> preprocess(const std::string & text) {
         const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
         std::vector<std::string> words(1, "");
 
@@ -751,15 +784,18 @@ struct llm_tokenizer_wpm {
         //(cpt >= 0xFF00 && cpt <= 0xFFEF);
     }
 
+private:
     const llama_vocab & vocab;
+    // currently unused
+    // const llm_tokenizer_wpm * wpm_tokenizer;
 };
 
 //
 // UGM tokenizer
 //
 
-struct llm_tokenizer_ugm {
-    llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
+struct llm_tokenizer_ugm : llm_tokenizer {
+    llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
         if (vocab.precompiled_charsmap.size() > 0) {
             size_t charsmap_offset = 0;
 
@@ -805,6 +841,30 @@ struct llm_tokenizer_ugm {
         unknown_token_score = min_score - unknown_token_score_penalty;
     }
 
+    // escaped space symbol - U+2581 (Lower One Eighth Block)
+    const std::string escaped_space = "\xE2\x96\x81";
+
+    const char * prefix_replacements = NULL;
+    size_t prefix_replacements_size = 0;
+
+    const uint32_t * xcda_array = NULL;
+    size_t xcda_array_size = 0;
+
+    struct naive_trie user_defined_token_matcher;
+
+    float min_score = FLT_MAX;
+    float max_score = -FLT_MAX;
+
+    float unknown_token_score_penalty = 10.0;
+    float unknown_token_score;
+
+    struct naive_trie token_matcher;
+};
+
+struct llm_tokenizer_ugm_session {
+    llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
+        ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
+
     /* This implementation is based on SentencePiece optimized Viterbi algorithm for
      * unigram language models. The general idea is to:
      * - move along the input sequence in steps of one UTF code point,
@@ -843,7 +903,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -873,7 +933,7 @@ struct llm_tokenizer_ugm {
             // if we didn't find a valid token corresponding to the whole UTF code point
             // then use unknown token as the tokenization of this UTF code point
             if (!single_codepoint_token_found) {
-                const double challenger_score = current_best.score_sum + unknown_token_score;
+                const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
                 prefix_offset = input_offset + n_utf8_code_units;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
@@ -905,7 +965,6 @@ struct llm_tokenizer_ugm {
     }
 
 private:
-    const llama_vocab & vocab;
 
     // helper structure for returning normalization results
     struct normalization_result {
@@ -918,7 +977,7 @@ private:
         normalized->clear();
         normalized->reserve(input.size() * 3);
 
-        const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
+        const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
 
         bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
         bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
@@ -1000,13 +1059,21 @@ private:
         size_t xcda_array_size;
     };
 
+    // this structure stores the best tokenization so far at input_offset
+    struct best_tokenization {
+        llama_token token_id;
+        size_t input_offset;
+        float score_sum;
+    };
+
     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
         if (input_offset == input.size()) {
             return { &input[input_offset], 0, 0 };
         }
 
         // if input prefix matches some user-defined token return this token as normalization result
-        auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+        auto user_defined_token_match =
+            ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
         if (user_defined_token_match.second > 0) {
             return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
         }
@@ -1014,8 +1081,8 @@ private:
         size_t longest_prefix_length = 0;
         size_t longest_prefix_offset = 0;
 
-        if (xcda_array_size > 0) {
-            struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
+        if (ugm_tokenizer->xcda_array_size > 0) {
+            struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
 
             // Find the longest normalized sequence matching the input prefix by walking
             // the XOR-compressed compact double array (XCDA) starting from the root node
@@ -1051,50 +1118,27 @@ private:
 
         if (longest_prefix_length > 0) {
             // we have a match, so return the replacement sequence
-            if (longest_prefix_offset >= prefix_replacements_size) {
+            if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
-            const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
+            const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
             return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
-        } else {
-            // check if the input prefix contains a valid sequence of UTF-8 code units
-            try {
-                // if yes, return this sequence unmodified
-                size_t prefix_offset = input_offset;
-                unicode_cpt_from_utf8(input, prefix_offset);
-                return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
-            } catch (std::invalid_argument & /*ex*/) {
-                // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
-                return { "\xEF\xBF\xBD", 3, 1 };
-            }
         }
-    }
-
-    // escaped space symbol - U+2581 (Lower One Eighth Block)
-    const std::string escaped_space = "\xE2\x96\x81";
 
-    const char * prefix_replacements = NULL;
-    size_t prefix_replacements_size = 0;
-
-    const uint32_t * xcda_array = NULL;
-    size_t xcda_array_size = 0;
-
-    struct naive_trie user_defined_token_matcher;
-
-    // this structure stores the best tokenization so far at input_offset
-    struct best_tokenization {
-        llama_token token_id;
-        size_t input_offset;
-        float score_sum;
-    };
-
-    float min_score = FLT_MAX;
-    float max_score = -FLT_MAX;
-
-    float unknown_token_score_penalty = 10.0;
-    float unknown_token_score;
+        // check if the input prefix contains a valid sequence of UTF-8 code units
+        try {
+            // if yes, return this sequence unmodified
+            size_t prefix_offset = input_offset;
+            unicode_cpt_from_utf8(input, prefix_offset);
+            return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
+        } catch (std::invalid_argument & /*ex*/) {
+            // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
+            return { "\xEF\xBF\xBD", 3, 1 };
+        }
+    }
 
-    struct naive_trie token_matcher;
+    const llama_vocab & vocab;
+    const llm_tokenizer_ugm * ugm_tokenizer;
 };
 
 //
@@ -1155,8 +1199,8 @@ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escape
     return output;
 }
 
-struct llm_tokenizer_rwkv {
-    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+struct llm_tokenizer_rwkv : llm_tokenizer {
+    llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
 
@@ -1168,11 +1212,17 @@ struct llm_tokenizer_rwkv {
         }
     }
 
+    struct naive_trie token_matcher;
+};
+
+struct llm_tokenizer_rwkv_session {
+    llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
+        rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
+
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         uint32_t position = 0;
-
         while (position < text.size()) {
-            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
             if (node == NULL) {
                 // no matching token found, add unknown token
                 output.push_back(vocab.special_unk_id);
@@ -1197,11 +1247,33 @@ struct llm_tokenizer_rwkv {
         }
     }
 
+private:
     const llama_vocab & vocab;
-
-    struct naive_trie token_matcher;
+    const llm_tokenizer_rwkv & rwkv_tokenizer;
 };
 
+void llama_vocab::init_tokenizer() {
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = new llm_tokenizer_spm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = new llm_tokenizer_bpe(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = new llm_tokenizer_wpm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = new llm_tokenizer_ugm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = new llm_tokenizer_rwkv(*this);
+            break;
+        default:
+            LM_GGML_ABORT("unsupported vocab type");
+    }
+}
+
 //
 // (de-) tokenize
 //
@@ -1263,7 +1335,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
         // if a fragment is text ( not yet processed )
         if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-            auto & raw_text = fragment.raw_text;
+            const auto & raw_text = fragment.raw_text;
 
             auto raw_text_base_offset = fragment.offset;
             auto raw_text_base_length = fragment.length;
@@ -1362,7 +1434,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
     }
 }
 
-std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
+std::vector<llama_vocab::id> llama_tokenize_internal(
+        const llama_vocab & vocab,
+        std::string raw_text,
+        bool add_special,
+        bool parse_special) {
+    LM_GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
     std::vector<llama_vocab::id> output;
     std::forward_list<fragment_buffer_variant> fragment_buffer;
 
@@ -1399,9 +1477,9 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                             LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                            llm_tokenizer_spm tokenizer(vocab);
                             llama_escape_whitespace(raw_text);
-                            tokenizer.tokenize(raw_text, output);
+                            llm_tokenizer_spm_session session(vocab);
+                            session.tokenize(raw_text, output);
                             is_prev_special = false;
                         } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                             output.push_back(fragment.token);
@@ -1423,10 +1501,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                 } break;
             case LLAMA_VOCAB_TYPE_BPE:
                 {
-                    llm_tokenizer_bpe tokenizer(vocab);
-
+                    llm_tokenizer_bpe_session session(vocab);
+                    // it calls some other methods that are not exist in llm_tokenizer,
+                    // here just cast it to bpe tokenizer object
                     if (add_special) {
-                        tokenizer.append_bos(output);
+                        session.append_bos(output);
                     }
                     for (const auto & fragment : fragment_buffer) {
                         if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
1435
1514
  #ifdef PRETOKENIZERDEBUG
1436
1515
  LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
1437
1516
  #endif
1438
- tokenizer.tokenize(raw_text, output);
1517
+ session.tokenize(raw_text, output);
1439
1518
  } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
1440
- tokenizer.append(fragment.token, output);
1519
+ session.append(fragment.token, output);
1441
1520
  }
1442
1521
  }
1443
1522
 
1444
1523
  if (add_special) {
1445
- tokenizer.append_eos(output);
1446
- tokenizer.check_double_bos_eos(output);
1524
+ session.append_eos(output);
1525
+ session.check_double_bos_eos(output);
1447
1526
  }
1448
1527
  } break;
1449
1528
  case LLAMA_VOCAB_TYPE_WPM:
@@ -1453,7 +1532,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                         output.push_back(vocab.special_cls_id);
                     }
 
-                    llm_tokenizer_wpm tokenizer(vocab);
+                    llm_tokenizer_wpm_session session(vocab);
 
                     for (const auto & fragment : fragment_buffer) {
                         if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1462,7 +1541,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                             LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                            tokenizer.tokenize(raw_text, output);
+                            session.tokenize(raw_text, output);
                         } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                             output.push_back(fragment.token);
                         }
@@ -1475,12 +1554,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                 } break;
             case LLAMA_VOCAB_TYPE_UGM:
                 {
-                    llm_tokenizer_ugm tokenizer(vocab);
-
-                    if (add_special && vocab.tokenizer_add_bos != 0) {
+                    if (add_special && vocab.tokenizer_add_bos) {
                         LM_GGML_ASSERT(vocab.special_bos_id != -1);
                         output.push_back(vocab.special_bos_id);
                     }
+                    llm_tokenizer_ugm_session session(vocab);
 
                     for (const auto & fragment : fragment_buffer) {
                         if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1488,26 +1566,27 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                             LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                            tokenizer.tokenize(raw_text, output);
+                            session.tokenize(raw_text, output);
                         } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                             output.push_back(fragment.token);
                         }
                     }
 
-                    if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                    if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
                         LLAMA_LOG_WARN(
                             "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                             "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                             "Are you sure this is what you want?\n", __FUNCTION__);
                     }
 
-                    if (add_special && vocab.tokenizer_add_eos == 1) {
+                    if (add_special && vocab.tokenizer_add_eos) {
                         LM_GGML_ASSERT(vocab.special_eos_id != -1);
                         output.push_back(vocab.special_eos_id);
                     }
                 } break;
             case LLAMA_VOCAB_TYPE_RWKV:
                 {
+                    llm_tokenizer_rwkv_session session(vocab);
                     for (const auto & fragment : fragment_buffer) {
                         if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                             auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -1516,8 +1595,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                             LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
 
-                            llm_tokenizer_rwkv tokenizer(vocab);
-                            tokenizer.tokenize(raw_text, output);
+                            session.tokenize(raw_text, output);
                         } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                             output.push_back(fragment.token);
                         }
@@ -1630,13 +1708,13 @@ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
 }
 
 int32_t llama_tokenize_impl(
-    const struct llama_vocab & vocab,
-    const char * text,
-    int32_t text_len,
-    llama_token * tokens,
-    int32_t n_tokens_max,
-    bool add_special,
-    bool parse_special) {
+        const struct llama_vocab & vocab,
+        const char * text,
+        int32_t text_len,
+        llama_token * tokens,
+        int32_t n_tokens_max,
+        bool add_special,
+        bool parse_special) {
     auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
     if (n_tokens_max < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
@@ -1713,11 +1791,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
             // suppressing them like CONTROL tokens.
             if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
                 return _try_copy(token_text.data(), token_text.size());
-            } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+            }
+            if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
                 std::string result = token_text;
                 llama_unescape_whitespace(result);
                 return _try_copy(result.data(), result.size());
-            } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+            }
+            if (attr & LLAMA_TOKEN_ATTR_BYTE) {
                 char byte = (char) llama_token_to_byte(vocab, token);
                 return _try_copy((char*) &byte, 1);
             }
@@ -1728,7 +1808,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
             // suppressing them like CONTROL tokens.
             if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
                 return _try_copy(token_text.data(), token_text.size());
-            } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+            }
+            if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
                 std::string result = llama_decode_text(token_text);
                 return _try_copy(result.data(), result.size());
             }
@@ -1761,6 +1842,8 @@ int32_t llama_detokenize_impl(
         int32_t text_len_max,
         bool remove_special,
         bool unparse_special) {
+    LM_GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
     int32_t avail = text_len_max;
     int32_t total = 0;
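
Two practical consequences of this refactor for downstream callers are visible in the hunks above. First, llama_vocab::init_tokenizer() must now run (normally as part of model loading) before any tokenize or detokenize call; both llama_tokenize_internal and llama_detokenize_impl enforce this with the newly added LM_GGML_ASSERT guards. Second, the expensive tokenizer construction (regex tables, tries, precompiled charsmaps) that 1.2.0 performed inside llama_tokenize_internal on every call, and for RWKV once per text fragment, now happens exactly once, with each call creating only a trivial session object.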