react-native-executorch 0.7.0 → 0.7.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/common/rnexecutorch/TokenizerModule.cpp +3 -2
  2. package/common/rnexecutorch/TokenizerModule.h +1 -1
  3. package/lib/module/modules/computer_vision/TextToImageModule.js +8 -4
  4. package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -1
  5. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -1
  6. package/package.json +4 -3
  7. package/src/modules/computer_vision/TextToImageModule.ts +9 -4
  8. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  9. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  10. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_model.h +84 -0
  11. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_tokenizer_base.h +6 -87
  12. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/hf_tokenizer.h +28 -176
  13. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/map_utils.h +174 -0
  14. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/model.h +151 -0
  15. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/normalizer.h +55 -1
  16. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/padding.h +112 -0
  17. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/post_processor.h +101 -42
  18. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/pre_tokenizer.h +25 -9
  19. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/token_decoder.h +33 -6
  20. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/tokenizer.h +2 -2
  21. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/truncation.h +92 -0
  22. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/wordpiece_model.h +74 -0
  23. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  24. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  25. package/common/rnexecutorch/tests/CMakeLists.txt +0 -253
  26. package/common/rnexecutorch/tests/README.md +0 -73
  27. package/common/rnexecutorch/tests/integration/BaseModelTest.cpp +0 -207
  28. package/common/rnexecutorch/tests/integration/BaseModelTests.h +0 -120
  29. package/common/rnexecutorch/tests/integration/ClassificationTest.cpp +0 -117
  30. package/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +0 -122
  31. package/common/rnexecutorch/tests/integration/ImageSegmentationTest.cpp +0 -152
  32. package/common/rnexecutorch/tests/integration/LLMTest.cpp +0 -155
  33. package/common/rnexecutorch/tests/integration/OCRTest.cpp +0 -128
  34. package/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +0 -135
  35. package/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp +0 -97
  36. package/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +0 -112
  37. package/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp +0 -164
  38. package/common/rnexecutorch/tests/integration/TextToImageTest.cpp +0 -149
  39. package/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp +0 -98
  40. package/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +0 -238
  41. package/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp +0 -99
  42. package/common/rnexecutorch/tests/integration/assets/test_audio_float.raw +0 -0
  43. package/common/rnexecutorch/tests/integration/assets/we_are_software_mansion.jpg +0 -0
  44. package/common/rnexecutorch/tests/integration/libs/libfbjni.so +0 -0
  45. package/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp +0 -45
  46. package/common/rnexecutorch/tests/integration/utils/TestUtils.h +0 -36
  47. package/common/rnexecutorch/tests/run_tests.sh +0 -333
  48. package/common/rnexecutorch/tests/unit/FileUtilsTest.cpp +0 -32
  49. package/common/rnexecutorch/tests/unit/LogTest.cpp +0 -529
  50. package/common/rnexecutorch/tests/unit/NumericalTest.cpp +0 -107
@@ -38,6 +38,10 @@ public:
38
38
 
39
39
  /**
40
40
  * Process the token IDs (single sequence).
41
+ *
42
+ * NOTE: Unlike the Rust implementation which uses a single method
43
+ * taking Encoding and an Option<Encoding>, we use overloads here
44
+ * to explicitly handle single vs pair sequences while processing raw IDs.
41
45
  */
42
46
  virtual std::vector<uint64_t>
43
47
  process(const std::vector<uint64_t> &tokens,
@@ -54,27 +58,65 @@ public:
54
58
 
55
59
  // -- Factory/Common Types -----------------------------------------------------
56
60
 
61
+ // Helper macro to standardize addition of config member fields
62
+ #define POST_PROCESSOR_CONFIG_MEMBER(type, name) \
63
+ std::optional<type> name; \
64
+ PostProcessorConfig &set_##name(type arg) { \
65
+ this->name = std::move(arg); \
66
+ return *this; \
67
+ }
68
+
57
69
  enum class SequenceId { A, B };
58
70
 
71
+ struct SpecialToken {
72
+ std::string id;
73
+ std::vector<uint64_t> ids;
74
+ std::vector<std::string> tokens;
75
+ };
76
+
59
77
  struct Piece {
60
78
  bool is_special_token;
61
79
  std::string id; // For SpecialToken (e.g. "[CLS]"). For Sequence (e.g. "A").
62
- uint32_t type_id;
80
+ uint64_t type_id;
63
81
 
64
- static Piece Sequence(SequenceId id, uint32_t type_id) {
82
+ static Piece Sequence(SequenceId id, uint64_t type_id) {
65
83
  return {false, id == SequenceId::A ? "A" : "B", type_id};
66
84
  }
67
- static Piece SpecialToken(std::string id, uint32_t type_id) {
85
+ static Piece SpecialToken(std::string id, uint64_t type_id) {
68
86
  return {true, std::move(id), type_id};
69
87
  }
70
88
  };
71
89
 
72
90
  using Template = std::vector<Piece>;
91
+ // -- Config -------------------------------------------------------------------
73
92
 
74
- struct SpecialToken {
75
- std::string id;
76
- std::vector<uint32_t> ids;
77
- std::vector<std::string> tokens;
93
+ class PostProcessorConfig {
94
+ public:
95
+ using SpecialTokenMap = std::map<std::string, tokenizers::SpecialToken>;
96
+ using StringIdPair = std::pair<std::string, uint64_t>;
97
+
98
+ std::string type;
99
+
100
+ // TemplateProcessing
101
+ POST_PROCESSOR_CONFIG_MEMBER(Template, single)
102
+ POST_PROCESSOR_CONFIG_MEMBER(Template, pair)
103
+ POST_PROCESSOR_CONFIG_MEMBER(SpecialTokenMap, special_tokens)
104
+
105
+ // Bert / Roberta (unused params in no-op, but kept for parsing logic)
106
+ POST_PROCESSOR_CONFIG_MEMBER(StringIdPair, sep)
107
+ POST_PROCESSOR_CONFIG_MEMBER(StringIdPair, cls)
108
+ POST_PROCESSOR_CONFIG_MEMBER(bool, trim_offsets)
109
+ POST_PROCESSOR_CONFIG_MEMBER(bool, add_prefix_space)
110
+
111
+ // Sequence
112
+ using Configs = std::vector<PostProcessorConfig>;
113
+ POST_PROCESSOR_CONFIG_MEMBER(Configs, processors)
114
+
115
+ explicit PostProcessorConfig(std::string type = "");
116
+
117
+ PostProcessor::Ptr create() const;
118
+
119
+ PostProcessorConfig &parse_json(const nlohmann::json &json_config);
78
120
  };
79
121
 
80
122
  // -- TemplateProcessing -------------------------------------------------------
@@ -106,11 +148,9 @@ private:
106
148
  bool add_special_tokens) const;
107
149
  };
108
150
 
109
- // -- BertProcessing -----------------------------------------------------------
110
-
111
- class BertProcessing : public PostProcessor {
151
+ class Sequence : public PostProcessor {
112
152
  public:
113
- BertProcessing();
153
+ explicit Sequence(std::vector<PostProcessor::Ptr> processors);
114
154
 
115
155
  size_t added_tokens(bool is_pair) const override;
116
156
 
@@ -120,13 +160,17 @@ public:
120
160
  std::vector<uint64_t> process(const std::vector<uint64_t> &tokens_a,
121
161
  const std::vector<uint64_t> &tokens_b,
122
162
  bool add_special_tokens = true) const override;
123
- };
124
163
 
125
- // -- RobertaProcessing --------------------------------------------------------
164
+ private:
165
+ std::vector<PostProcessor::Ptr> processors_;
166
+ };
126
167
 
127
- class RobertaProcessing : public PostProcessor {
168
+ // -- BertProcessing -----------------------------------------------------------
169
+ // Used for BERT post-processing (adding special tokens)
170
+ class BertProcessing : public PostProcessor {
128
171
  public:
129
- RobertaProcessing();
172
+ BertProcessing(std::pair<std::string, uint64_t> sep,
173
+ std::pair<std::string, uint64_t> cls);
130
174
 
131
175
  size_t added_tokens(bool is_pair) const override;
132
176
 
@@ -136,13 +180,19 @@ public:
136
180
  std::vector<uint64_t> process(const std::vector<uint64_t> &tokens_a,
137
181
  const std::vector<uint64_t> &tokens_b,
138
182
  bool add_special_tokens = true) const override;
139
- };
140
183
 
141
- // -- Sequence -----------------------------------------------------------------
184
+ private:
185
+ std::pair<std::string, uint64_t> sep_;
186
+ std::pair<std::string, uint64_t> cls_;
187
+ };
142
188
 
143
- class Sequence : public PostProcessor {
189
+ // -- RobertaProcessing --------------------------------------------------------
190
+ // Used for RoBERTa post-processing
191
+ class RobertaProcessing : public PostProcessor {
144
192
  public:
145
- explicit Sequence(std::vector<PostProcessor::Ptr> processors);
193
+ RobertaProcessing(std::pair<std::string, uint64_t> sep,
194
+ std::pair<std::string, uint64_t> cls, bool trim_offsets,
195
+ bool add_prefix_space);
146
196
 
147
197
  size_t added_tokens(bool is_pair) const override;
148
198
 
@@ -154,34 +204,43 @@ public:
154
204
  bool add_special_tokens = true) const override;
155
205
 
156
206
  private:
157
- std::vector<PostProcessor::Ptr> processors_;
207
+ std::pair<std::string, uint64_t> sep_;
208
+ std::pair<std::string, uint64_t> cls_;
209
+ bool trim_offsets_;
210
+ bool add_prefix_space_;
158
211
  };
159
212
 
160
- // -- Config -------------------------------------------------------------------
161
-
162
- class PostProcessorConfig {
163
- public:
164
- std::string type;
165
-
166
- // TemplateProcessing
167
- Template single;
168
- Template pair;
169
- std::map<std::string, SpecialToken> special_tokens;
170
-
171
- // Bert / Roberta (unused params in no-op, but kept for parsing logic)
172
- std::pair<std::string, uint32_t> sep;
173
- std::pair<std::string, uint32_t> cls;
174
- bool trim_offsets = true;
175
- bool add_prefix_space = true;
213
+ // -- ByteLevel
214
+ // ----------------------------------------------------------------
215
+ // TODO: Implement ByteLevelProcessor
216
+ // This is a broader issue, as most of the processing is done on offsets.
217
+ Our current implementation doesn't support it and would require us to
218
+ introduce a complex Encoding type. Something similar to the original hf
219
+ implementation:
220
+ // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/tokenizer/encoding.rs
221
+ // so we could store the offsets from pretokenization step.
222
+ /*
223
+ class ByteLevel : public PostProcessor {
224
+ public:
225
+ ByteLevel(bool trim_offsets, bool add_prefix_space);
176
226
 
177
- // Sequence
178
- std::vector<PostProcessorConfig> processors;
227
+ size_t added_tokens(bool is_pair) const override;
179
228
 
180
- explicit PostProcessorConfig(std::string type = "");
229
+ std::vector<uint64_t> process(
230
+ const std::vector<uint64_t>& tokens,
231
+ bool add_special_tokens = true) const override;
181
232
 
182
- PostProcessor::Ptr create() const;
233
+ std::vector<uint64_t> process(
234
+ const std::vector<uint64_t>& tokens_a,
235
+ const std::vector<uint64_t>& tokens_b,
236
+ bool add_special_tokens = true) const override;
183
237
 
184
- PostProcessorConfig &parse_json(const nlohmann::json &json_config);
238
+ private:
239
+ bool trim_offsets_;
240
+ bool add_prefix_space_;
185
241
  };
242
+ */
186
243
 
187
- } // namespace tokenizers
244
+ // -- Sequence
245
+ // -----------------------------------------------------------------
246
+ } // namespace tokenizers
@@ -53,7 +53,7 @@ public:
53
53
  // -- Factory ------------------------------------------------------------------
54
54
 
55
55
  // Helper macro to standardize addition of config member fields
56
- #define CONFIG_MEMBER(type, name) \
56
+ #define PRETOKENIZER_CONFIG_MEMBER(type, name) \
57
57
  std::optional<type> name; \
58
58
  PreTokenizerConfig &set_##name(type arg) { \
59
59
  this->name = std::move(arg); \
@@ -92,37 +92,38 @@ public:
92
92
  /**
93
93
  * Used by: RegexPreTokenizer, ByteLevelPreTokenizer
94
94
  */
95
- CONFIG_MEMBER(std::string, pattern)
95
+ PRETOKENIZER_CONFIG_MEMBER(std::string, pattern)
96
96
 
97
97
  /**
98
98
  * Used by: DigitsPreTokenizer
99
99
  */
100
- CONFIG_MEMBER(bool, individual_digits)
100
+ PRETOKENIZER_CONFIG_MEMBER(bool, individual_digits)
101
101
 
102
102
  /**
103
103
  * Used by: ByteLevelPreTokenizer
104
104
  */
105
- CONFIG_MEMBER(bool, add_prefix_space)
105
+ PRETOKENIZER_CONFIG_MEMBER(bool, add_prefix_space)
106
106
 
107
107
  /**
108
108
  * Used by RegexPreTokenizer
109
109
  */
110
- CONFIG_MEMBER(bool, is_delimiter)
110
+ PRETOKENIZER_CONFIG_MEMBER(bool, is_delimiter)
111
111
 
112
112
  /**
113
113
  * Used by RegexPreTokenizer - Split behavior
114
114
  */
115
- CONFIG_MEMBER(std::string, behavior)
115
+ PRETOKENIZER_CONFIG_MEMBER(std::string, behavior)
116
116
 
117
117
  /**
118
118
  * Used by RegexPreTokenizer - Split invert flag
119
119
  */
120
- CONFIG_MEMBER(bool, invert)
120
+ PRETOKENIZER_CONFIG_MEMBER(bool, invert)
121
121
 
122
122
  /**
123
123
  * Used by: SequencePreTokenizer
124
124
  */
125
- CONFIG_MEMBER(std::vector<PreTokenizerConfig>, pretokenizers)
125
+ using Configs = std::vector<PreTokenizerConfig>;
126
+ PRETOKENIZER_CONFIG_MEMBER(Configs, pretokenizers)
126
127
 
127
128
  /*----------------*/
128
129
  /* Public methods */
@@ -259,6 +260,21 @@ public:
259
260
  private:
260
261
  const std::vector<PreTokenizer::Ptr> pre_tokenizers_;
261
262
 
262
- }; // end class ByteLevelPreTokenizer
263
+ }; // end class SequencePreTokenizer
264
+
265
+ // -- Bert ---------------------------------------------------------------------
266
+ // Used for BERT-style pre-tokenization (splitting on whitespace and
267
+ // punctuation) CITE:
268
+ // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/bert.rs
269
+
270
+ class BertPreTokenizer : public PreTokenizer {
271
+ public:
272
+ BertPreTokenizer() = default;
273
+
274
+ /** Perform BERT pre-tokenization */
275
+ std::vector<std::string>
276
+ pre_tokenize(const std::string &input) const override;
277
+
278
+ }; // end class BertPreTokenizer
263
279
 
264
280
  } // namespace tokenizers
@@ -55,6 +55,14 @@ public:
55
55
 
56
56
  // -- Factory ------------------------------------------------------------------
57
57
 
58
+ // Helper macro to standardize addition of config member fields
59
+ #define TOKEN_DECODER_CONFIG_MEMBER(type, name) \
60
+ std::optional<type> name; \
61
+ TokenDecoderConfig &set_##name(type arg) { \
62
+ this->name = std::move(arg); \
63
+ return *this; \
64
+ }
65
+
58
66
  /**
59
67
  * Factory and config class for creating a new TokenDecoder
60
68
  */
@@ -67,16 +75,20 @@ public:
67
75
  std::string type;
68
76
 
69
77
  // Parameters for Replace decoder
70
- std::string replace_pattern;
71
- std::string replace_content;
78
+ TOKEN_DECODER_CONFIG_MEMBER(std::string, replace_pattern)
79
+ TOKEN_DECODER_CONFIG_MEMBER(std::string, replace_content)
72
80
 
73
81
  // Parameters for Sequence decoder
74
- std::vector<nlohmann::json> sequence_decoders;
82
+ TOKEN_DECODER_CONFIG_MEMBER(std::vector<nlohmann::json>, sequence_decoders)
75
83
 
76
84
  // Parameters for Strip decoder
77
- std::string strip_content;
78
- size_t strip_start;
79
- size_t strip_stop;
85
+ TOKEN_DECODER_CONFIG_MEMBER(std::string, strip_content)
86
+ TOKEN_DECODER_CONFIG_MEMBER(size_t, strip_start)
87
+ TOKEN_DECODER_CONFIG_MEMBER(size_t, strip_stop)
88
+
89
+ // Parameters for WordPiece decoder
90
+ TOKEN_DECODER_CONFIG_MEMBER(std::string, wordpiece_prefix)
91
+ TOKEN_DECODER_CONFIG_MEMBER(bool, wordpiece_cleanup)
80
92
 
81
93
  /*----------------*/
82
94
  /* Public methods */
@@ -161,6 +173,21 @@ private:
161
173
  size_t stop_;
162
174
  }; // end class StripTokenDecoder
163
175
 
176
+ // -- WordPiece ----------------------------------------------------------------
177
+ // Used for WordPiece decoding
178
+
179
+ class WordPieceTokenDecoder : public TokenDecoder {
180
+ public:
181
+ explicit WordPieceTokenDecoder(std::string prefix = "##",
182
+ bool cleanup = true);
183
+ std::vector<std::string>
184
+ decode(const std::vector<std::string> &tokens) const override;
185
+
186
+ private:
187
+ std::string prefix_;
188
+ bool cleanup_;
189
+ }; // end class WordPieceTokenDecoder
190
+
164
191
  // -- Sequence -----------------------------------------------------------------
165
192
  // Applies a sequence of decoders in order
166
193
 
@@ -13,8 +13,8 @@
13
13
 
14
14
  #pragma once
15
15
 
16
- #include "error.h"
17
- #include "result.h"
16
+ #include <pytorch/tokenizers/error.h>
17
+ #include <pytorch/tokenizers/result.h>
18
18
  #include <string>
19
19
  #include <vector>
20
20
 
@@ -0,0 +1,92 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+ // @lint-ignore-every LICENSELINT
9
+
10
+ #pragma once
11
+
12
+ // Standard
13
+ #include <memory>
14
+ #include <optional>
15
+ #include <string>
16
+ #include <vector>
17
+
18
+ // Third Party
19
+ #include <nlohmann/json.hpp>
20
+
21
+ namespace tokenizers {
22
+
23
+ // -- Truncation ---------------------------------------------------------------
24
+
25
+ enum class TruncationStrategy {
26
+ LongestFirst,
27
+ OnlyFirst,
28
+ OnlySecond,
29
+ };
30
+
31
+ enum class TruncationDirection {
32
+ Left,
33
+ Right,
34
+ };
35
+
36
+ struct TruncationParams {
37
+ TruncationDirection direction = TruncationDirection::Right;
38
+ size_t max_length = 512;
39
+ TruncationStrategy strategy = TruncationStrategy::LongestFirst;
40
+ size_t stride = 0;
41
+ };
42
+
43
+ class Truncation {
44
+ public:
45
+ /** Shared pointer type */
46
+ typedef std::shared_ptr<Truncation> Ptr;
47
+
48
+ /**
49
+ * @param params: The truncation parameters
50
+ */
51
+ explicit Truncation(const TruncationParams &params);
52
+
53
+ /**
54
+ * Truncate the tokens according to the configuration.
55
+ *
56
+ * @param tokens The tokens to truncate.
57
+ * @param num_tokens_to_add The number of special tokens that will be added
58
+ * later. These are subtracted from max_length during truncation calculation.
59
+ */
60
+ std::vector<uint64_t> truncate(std::vector<uint64_t> tokens,
61
+ size_t num_tokens_to_add = 0) const;
62
+
63
+ /**
64
+ * Truncate a pair of sequences according to the configuration.
65
+ */
66
+ std::pair<std::vector<uint64_t>, std::vector<uint64_t>>
67
+ truncate_pair(std::vector<uint64_t> a, std::vector<uint64_t> b,
68
+ size_t num_tokens_to_add = 0) const;
69
+
70
+ private:
71
+ TruncationParams params_;
72
+ };
73
+
74
+ // -- Factory ------------------------------------------------------------------
75
+
76
+ class TruncationConfig {
77
+ public:
78
+ /**
79
+ * Construct the truncation instance from the member data
80
+ */
81
+ Truncation::Ptr create() const;
82
+
83
+ /**
84
+ * Populate from a json config file
85
+ */
86
+ TruncationConfig &parse_json(const nlohmann::json &json_config);
87
+
88
+ // Configuration members
89
+ TruncationParams params;
90
+ };
91
+
92
+ } // namespace tokenizers
@@ -0,0 +1,74 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+ // @lint-ignore-every LICENSELINT
9
+
10
+ #pragma once
11
+
12
+ #include <memory>
13
+ #include <optional>
14
+ #include <string>
15
+ #include <vector>
16
+
17
+ #include <pytorch/tokenizers/model.h>
18
+ #include <pytorch/tokenizers/regex.h>
19
+ #include <pytorch/tokenizers/result.h>
20
+ #include <pytorch/tokenizers/string_integer_map.h>
21
+
22
+ namespace tokenizers {
23
+
24
+ class WordPieceModel : public Model {
25
+ public:
26
+ explicit WordPieceModel(detail::TokenMap token_map,
27
+ detail::TokenMap special_token_map,
28
+ std::string unk_token,
29
+ std::string continuing_subword_prefix,
30
+ size_t max_input_chars_per_word,
31
+ std::optional<uint64_t> unk_token_id,
32
+ std::optional<uint64_t> bos_token_id,
33
+ std::optional<uint64_t> eos_token_id);
34
+
35
+ ~WordPieceModel() override = default;
36
+
37
+ Result<std::vector<uint64_t>>
38
+ tokenize(const std::string &piece) const override;
39
+
40
+ Result<std::string> id_to_piece(uint64_t token) const override;
41
+ Result<uint64_t> piece_to_id(const std::string &token) const override;
42
+
43
+ int32_t vocab_size() const override { return vocab_size_; }
44
+
45
+ bool is_special_token(uint64_t token) const override;
46
+
47
+ bool is_loaded() const override { return initialized_; }
48
+
49
+ std::pair<std::optional<std::string>, std::string>
50
+ split_with_allowed_special_token(const std::string &input,
51
+ size_t offset) const override;
52
+
53
+ uint64_t bos_token_id() const override { return bos_token_id_.value_or(0); }
54
+
55
+ uint64_t eos_token_id() const override { return eos_token_id_.value_or(0); }
56
+
57
+ private:
58
+ detail::TokenMap token_map_;
59
+ detail::TokenMap special_token_map_;
60
+ std::unique_ptr<IRegex> special_token_regex_;
61
+
62
+ std::string unk_token_;
63
+ std::string continuing_subword_prefix_;
64
+ size_t max_input_chars_per_word_;
65
+
66
+ std::optional<uint64_t> unk_token_id_;
67
+ std::optional<uint64_t> bos_token_id_;
68
+ std::optional<uint64_t> eos_token_id_;
69
+
70
+ bool initialized_ = false;
71
+ int32_t vocab_size_ = 0;
72
+ };
73
+
74
+ } // namespace tokenizers