react-native-executorch 0.7.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. package/common/rnexecutorch/TokenizerModule.cpp +3 -2
  2. package/common/rnexecutorch/TokenizerModule.h +1 -1
  3. package/package.json +2 -1
  4. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  5. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  6. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_model.h +84 -0
  7. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_tokenizer_base.h +6 -87
  8. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/hf_tokenizer.h +28 -176
  9. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/map_utils.h +174 -0
  10. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/model.h +151 -0
  11. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/normalizer.h +55 -1
  12. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/padding.h +112 -0
  13. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/post_processor.h +101 -42
  14. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/pre_tokenizer.h +25 -9
  15. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/token_decoder.h +33 -6
  16. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/tokenizer.h +2 -2
  17. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/truncation.h +92 -0
  18. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/wordpiece_model.h +74 -0
  19. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  20. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  21. package/common/rnexecutorch/tests/CMakeLists.txt +0 -253
  22. package/common/rnexecutorch/tests/README.md +0 -73
  23. package/common/rnexecutorch/tests/integration/BaseModelTest.cpp +0 -207
  24. package/common/rnexecutorch/tests/integration/BaseModelTests.h +0 -120
  25. package/common/rnexecutorch/tests/integration/ClassificationTest.cpp +0 -117
  26. package/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +0 -122
  27. package/common/rnexecutorch/tests/integration/ImageSegmentationTest.cpp +0 -152
  28. package/common/rnexecutorch/tests/integration/LLMTest.cpp +0 -155
  29. package/common/rnexecutorch/tests/integration/OCRTest.cpp +0 -128
  30. package/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +0 -135
  31. package/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp +0 -97
  32. package/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +0 -112
  33. package/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp +0 -164
  34. package/common/rnexecutorch/tests/integration/TextToImageTest.cpp +0 -149
  35. package/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp +0 -98
  36. package/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +0 -238
  37. package/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp +0 -99
  38. package/common/rnexecutorch/tests/integration/assets/test_audio_float.raw +0 -0
  39. package/common/rnexecutorch/tests/integration/assets/we_are_software_mansion.jpg +0 -0
  40. package/common/rnexecutorch/tests/integration/libs/libfbjni.so +0 -0
  41. package/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp +0 -45
  42. package/common/rnexecutorch/tests/integration/utils/TestUtils.h +0 -36
  43. package/common/rnexecutorch/tests/run_tests.sh +0 -333
  44. package/common/rnexecutorch/tests/unit/FileUtilsTest.cpp +0 -32
  45. package/common/rnexecutorch/tests/unit/LogTest.cpp +0 -529
  46. package/common/rnexecutorch/tests/unit/NumericalTest.cpp +0 -107
@@ -0,0 +1,151 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+ // @lint-ignore-every LICENSELINT
9
+
10
+ #pragma once
11
+
12
+ #include <memory>
13
+ #include <string>
14
+ #include <vector>
15
+
16
+ #include <nlohmann/json.hpp>
17
+ #include <pytorch/tokenizers/map_utils.h>
18
+ #include <pytorch/tokenizers/result.h>
19
+ #include <pytorch/tokenizers/string_integer_map.h>
20
+
21
+ namespace tokenizers {
22
+
23
+ // -- Base ---------------------------------------------------------------------
24
+
25
+ /**
26
+ * Abstract base class for tokenization models.
27
+ *
28
+ * A Model corresponds to the core logic that converts a piece of text (usually
29
+ * resulting from the pre-tokenization step) into a sequence of token IDs, and
30
+ * vice-versa.
31
+ *
32
+ * It encapsulates the vocabulary and the algorithm (e.g., BPE, WordPiece,
33
+ * Unigram).
34
+ */
35
+ class Model {
36
+ public:
37
+ using Ptr = std::shared_ptr<Model>;
38
+
39
+ virtual ~Model() = default;
40
+
41
+ /**
42
+ * Tokenizes a string piece into a sequence of token IDs.
43
+ *
44
+ * @param piece The input string to tokenize.
45
+ * @return A Result containing the vector of token IDs.
46
+ */
47
+ virtual Result<std::vector<uint64_t>>
48
+ tokenize(const std::string &piece) const = 0;
49
+
50
+ /**
51
+ * Converts a token ID to its string representation.
52
+ *
53
+ * @param token The token ID.
54
+ * @return A Result containing the string representation of the token.
55
+ */
56
+ virtual Result<std::string> id_to_piece(uint64_t token) const = 0;
57
+
58
+ /**
59
+ * Converts a string representation to its token ID.
60
+ *
61
+ * @param piece The string representation of the token.
62
+ * @return A Result containing the token ID.
63
+ */
64
+ virtual Result<uint64_t> piece_to_id(const std::string &piece) const = 0;
65
+
66
+ /**
67
+ * Returns the size of the vocabulary.
68
+ *
69
+ * @return The number of tokens in the vocabulary.
70
+ */
71
+ virtual int32_t vocab_size() const = 0;
72
+
73
+ /**
74
+ * Returns whether the token is a special token.
75
+ *
76
+ * @param token The token ID.
77
+ * @return True if the token is a special token, false otherwise.
78
+ */
79
+ virtual bool is_special_token(uint64_t token) const = 0;
80
+
81
+ /**
82
+ * Returns whether the model is loaded.
83
+ *
84
+ * @return True if the model is loaded, false otherwise.
85
+ */
86
+ virtual bool is_loaded() const = 0;
87
+
88
+ /**
89
+ * Helper to split input text into a special token and the preceding regular
90
+ * text.
91
+ *
92
+ * @param input The input string.
93
+ * @param offset The starting offset.
94
+ * @return A pair of (matched special token string, preceding regular text).
95
+ */
96
+ virtual std::pair<std::optional<std::string>, std::string>
97
+ split_with_allowed_special_token(const std::string &input,
98
+ size_t offset) const = 0;
99
+
100
+ virtual uint64_t bos_token_id() const = 0;
101
+ virtual uint64_t eos_token_id() const = 0;
102
+ };
103
+
104
+ // -- Factory ------------------------------------------------------------------
105
+
106
+ // Helper macro to standardize addition of config member fields
107
+ #define MODEL_CONFIG_MEMBER(type, name) \
108
+ std::optional<type> name; \
109
+ ModelConfig &set_##name(type arg) { \
110
+ this->name = std::move(arg); \
111
+ return *this; \
112
+ }
113
+
114
+ /**
115
+ * Factory and config class for creating a new Model
116
+ */
117
+ class ModelConfig {
118
+ public:
119
+ std::string type;
120
+
121
+ // Data for BPEModel
122
+ using TokenPairs = std::vector<std::pair<std::string, uint64_t>>;
123
+ MODEL_CONFIG_MEMBER(TokenPairs, token_pairs)
124
+ MODEL_CONFIG_MEMBER(TokenPairs, special_token_pairs)
125
+
126
+ MODEL_CONFIG_MEMBER(std::vector<std::string>, merges)
127
+ MODEL_CONFIG_MEMBER(bool, byte_fallback)
128
+ MODEL_CONFIG_MEMBER(std::string, unk_token)
129
+ MODEL_CONFIG_MEMBER(std::string, bos_token)
130
+ MODEL_CONFIG_MEMBER(std::string, eos_token)
131
+ MODEL_CONFIG_MEMBER(std::string, continuing_subword_prefix)
132
+ MODEL_CONFIG_MEMBER(size_t, max_input_chars_per_word)
133
+
134
+ // Paths for extra config files (HuggingFace specific)
135
+ MODEL_CONFIG_MEMBER(std::string, model_config_path)
136
+ MODEL_CONFIG_MEMBER(std::string, special_tokens_map_path)
137
+
138
+ ModelConfig() = default;
139
+
140
+ /**
141
+ * Populate from a json config file (the root tokenizer.json)
142
+ */
143
+ ModelConfig &parse_json(const nlohmann::json &json_config);
144
+
145
+ /**
146
+ * Construct the model instance from the member data
147
+ */
148
+ Model::Ptr create() const;
149
+ };
150
+
151
+ } // namespace tokenizers
@@ -101,13 +101,22 @@ public:
101
101
  /**
102
102
  * Used by: SequenceNormalizer
103
103
  */
104
- NORMALIZER_CONFIG_MEMBER(std::vector<NormalizerConfig>, normalizers)
104
+ using Configs = std::vector<NormalizerConfig>;
105
+ NORMALIZER_CONFIG_MEMBER(Configs, normalizers)
105
106
 
106
107
  /**
107
108
  * Used by: PrependNormalizer
108
109
  */
109
110
  NORMALIZER_CONFIG_MEMBER(std::string, prepend)
110
111
 
112
+ /**
113
+ * Used by: BertNormalizer
114
+ */
115
+ NORMALIZER_CONFIG_MEMBER(bool, clean_text)
116
+ NORMALIZER_CONFIG_MEMBER(bool, handle_chinese_chars)
117
+ NORMALIZER_CONFIG_MEMBER(bool, lowercase)
118
+ NORMALIZER_CONFIG_MEMBER(bool, strip_accents)
119
+
111
120
  /*----------------*/
112
121
  /* Public methods */
113
122
  /*----------------*/
@@ -210,4 +219,49 @@ public:
210
219
 
211
220
  }; // end class NFCNormalizer
212
221
 
222
+ // -- Lowercase ----------------------------------------------------------------
223
+ // Used for lowercasing the input
224
+ // CITE:
225
+ // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/utils.rs
226
+
227
+ class LowercaseNormalizer : public Normalizer {
228
+ public:
229
+ /** Default constructor */
230
+ explicit LowercaseNormalizer() = default;
231
+
232
+ /** Lowercase the input */
233
+ std::string normalize(const std::string &input) const override;
234
+
235
+ }; // end class LowercaseNormalizer
236
+
237
+ // -- Bert ---------------------------------------------------------------------
238
+ // Used for BERT-style normalization (cleaning, lowercasing, accent removal)
239
+ // CITE:
240
+ // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/bert.rs
241
+
242
+ class BertNormalizer : public Normalizer {
243
+ public:
244
+ /**
245
+ * @param clean_text: Whether to clean the text (remove control chars, etc.)
246
+ * @param handle_chinese_chars: Whether to put spaces around Chinese
247
+ * characters
248
+ * @param lowercase: Whether to lowercase the input
249
+ * @param strip_accents: Whether to strip accents (optional, usually follows
250
+ * lowercase)
251
+ */
252
+ explicit BertNormalizer(bool clean_text, bool handle_chinese_chars,
253
+ bool lowercase, std::optional<bool> strip_accents)
254
+ : clean_text_(clean_text), handle_chinese_chars_(handle_chinese_chars),
255
+ lowercase_(lowercase), strip_accents_(strip_accents) {}
256
+
257
+ /** Perform BERT normalization steps */
258
+ std::string normalize(const std::string &input) const override;
259
+
260
+ protected:
261
+ const bool clean_text_;
262
+ const bool handle_chinese_chars_;
263
+ const bool lowercase_;
264
+ const std::optional<bool> strip_accents_;
265
+ };
266
+
213
267
  } // namespace tokenizers
@@ -0,0 +1,112 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+ // @lint-ignore-every LICENSELINT
9
+
10
+ #pragma once
11
+
12
+ // Standard
13
+ #include <memory>
14
+ #include <optional>
15
+ #include <string>
16
+ #include <vector>
17
+
18
+ // Third Party
19
+ #include <nlohmann/json.hpp>
20
+
21
+ namespace tokenizers {
22
+
23
+ // -- Padding ------------------------------------------------------------------
24
+
25
+ enum class PaddingDirection {
26
+ Left,
27
+ Right,
28
+ };
29
+
30
+ enum class PaddingStrategy {
31
+ BatchLongest,
32
+ Fixed,
33
+ };
34
+
35
+ struct PaddingParams {
36
+ PaddingStrategy strategy = PaddingStrategy::BatchLongest;
37
+ PaddingDirection direction = PaddingDirection::Right;
38
+ std::optional<size_t> fixed_size;
39
+ std::optional<size_t> pad_to_multiple_of;
40
+ uint32_t pad_id = 0;
41
+ uint32_t pad_type_id = 0;
42
+ std::string pad_token = "[PAD]";
43
+ };
44
+
45
+ class Padding {
46
+ public:
47
+ /** Shared pointer type */
48
+ typedef std::shared_ptr<Padding> Ptr;
49
+
50
+ /**
51
+ * @param params: The padding parameters
52
+ */
53
+ explicit Padding(const PaddingParams &params);
54
+
55
+ /**
56
+ * Pad the tokens according to the configuration
57
+ */
58
+ std::vector<uint64_t> pad(std::vector<uint64_t> tokens) const;
59
+
60
+ /**
61
+ * Generate attention mask for the padded tokens.
62
+ * 1 for real tokens, 0 for padded tokens.
63
+ */
64
+ std::vector<uint32_t> generate_mask(const std::vector<uint64_t> &tokens,
65
+ size_t padded_size) const;
66
+
67
+ private:
68
+ PaddingParams params_;
69
+ };
70
+
71
+ // -- Factory ------------------------------------------------------------------
72
+
73
+ // Helper macro to standardize addition of config member fields
74
+ #define PADDING_CONFIG_MEMBER(type, name) \
75
+ PaddingConfig &set_##name(type arg) { \
76
+ this->params.name = std::move(arg); \
77
+ return *this; \
78
+ }
79
+
80
+ class PaddingConfig {
81
+ public:
82
+ explicit PaddingConfig(std::string strategy = "");
83
+
84
+ /**
85
+ * Construct the padding instance from the member data
86
+ */
87
+ Padding::Ptr create() const;
88
+
89
+ /**
90
+ * Populate from a json config file
91
+ */
92
+ PaddingConfig &parse_json(const nlohmann::json &json_config);
93
+
94
+ // Configuration members
95
+ PaddingParams params;
96
+
97
+ PADDING_CONFIG_MEMBER(PaddingStrategy, strategy)
98
+ PADDING_CONFIG_MEMBER(PaddingDirection, direction)
99
+
100
+ PaddingConfig &set_fixed_size(std::optional<size_t> arg) {
101
+ this->params.fixed_size = std::move(arg);
102
+ this->params.strategy = PaddingStrategy::Fixed;
103
+ return *this;
104
+ }
105
+
106
+ PADDING_CONFIG_MEMBER(std::optional<size_t>, pad_to_multiple_of)
107
+ PADDING_CONFIG_MEMBER(uint32_t, pad_id)
108
+ PADDING_CONFIG_MEMBER(uint32_t, pad_type_id)
109
+ PADDING_CONFIG_MEMBER(std::string, pad_token)
110
+ };
111
+
112
+ } // namespace tokenizers
@@ -38,6 +38,10 @@ public:
38
38
 
39
39
  /**
40
40
  * Process the token IDs (single sequence).
41
+ *
42
+ * NOTE: Unlike the Rust implementation which uses a single method
43
+ * taking Encoding and an Option<Encoding>, we use overloads here
44
+ * to explicitly handle single vs pair sequences while processing raw IDs.
41
45
  */
42
46
  virtual std::vector<uint64_t>
43
47
  process(const std::vector<uint64_t> &tokens,
@@ -54,27 +58,65 @@ public:
54
58
 
55
59
  // -- Factory/Common Types -----------------------------------------------------
56
60
 
61
+ // Helper macro to standardize addition of config member fields
62
+ #define POST_PROCESSOR_CONFIG_MEMBER(type, name) \
63
+ std::optional<type> name; \
64
+ PostProcessorConfig &set_##name(type arg) { \
65
+ this->name = std::move(arg); \
66
+ return *this; \
67
+ }
68
+
57
69
  enum class SequenceId { A, B };
58
70
 
71
+ struct SpecialToken {
72
+ std::string id;
73
+ std::vector<uint64_t> ids;
74
+ std::vector<std::string> tokens;
75
+ };
76
+
59
77
  struct Piece {
60
78
  bool is_special_token;
61
79
  std::string id; // For SpecialToken (e.g. "[CLS]"). For Sequence (e.g. "A").
62
- uint32_t type_id;
80
+ uint64_t type_id;
63
81
 
64
- static Piece Sequence(SequenceId id, uint32_t type_id) {
82
+ static Piece Sequence(SequenceId id, uint64_t type_id) {
65
83
  return {false, id == SequenceId::A ? "A" : "B", type_id};
66
84
  }
67
- static Piece SpecialToken(std::string id, uint32_t type_id) {
85
+ static Piece SpecialToken(std::string id, uint64_t type_id) {
68
86
  return {true, std::move(id), type_id};
69
87
  }
70
88
  };
71
89
 
72
90
  using Template = std::vector<Piece>;
91
+ // -- Config -------------------------------------------------------------------
73
92
 
74
- struct SpecialToken {
75
- std::string id;
76
- std::vector<uint32_t> ids;
77
- std::vector<std::string> tokens;
93
+ class PostProcessorConfig {
94
+ public:
95
+ using SpecialTokenMap = std::map<std::string, tokenizers::SpecialToken>;
96
+ using StringIdPair = std::pair<std::string, uint64_t>;
97
+
98
+ std::string type;
99
+
100
+ // TemplateProcessing
101
+ POST_PROCESSOR_CONFIG_MEMBER(Template, single)
102
+ POST_PROCESSOR_CONFIG_MEMBER(Template, pair)
103
+ POST_PROCESSOR_CONFIG_MEMBER(SpecialTokenMap, special_tokens)
104
+
105
+ // Bert / Roberta (unused params in no-op, but kept for parsing logic)
106
+ POST_PROCESSOR_CONFIG_MEMBER(StringIdPair, sep)
107
+ POST_PROCESSOR_CONFIG_MEMBER(StringIdPair, cls)
108
+ POST_PROCESSOR_CONFIG_MEMBER(bool, trim_offsets)
109
+ POST_PROCESSOR_CONFIG_MEMBER(bool, add_prefix_space)
110
+
111
+ // Sequence
112
+ using Configs = std::vector<PostProcessorConfig>;
113
+ POST_PROCESSOR_CONFIG_MEMBER(Configs, processors)
114
+
115
+ explicit PostProcessorConfig(std::string type = "");
116
+
117
+ PostProcessor::Ptr create() const;
118
+
119
+ PostProcessorConfig &parse_json(const nlohmann::json &json_config);
78
120
  };
79
121
 
80
122
  // -- TemplateProcessing -------------------------------------------------------
@@ -106,11 +148,9 @@ private:
106
148
  bool add_special_tokens) const;
107
149
  };
108
150
 
109
- // -- BertProcessing -----------------------------------------------------------
110
-
111
- class BertProcessing : public PostProcessor {
151
+ class Sequence : public PostProcessor {
112
152
  public:
113
- BertProcessing();
153
+ explicit Sequence(std::vector<PostProcessor::Ptr> processors);
114
154
 
115
155
  size_t added_tokens(bool is_pair) const override;
116
156
 
@@ -120,13 +160,17 @@ public:
120
160
  std::vector<uint64_t> process(const std::vector<uint64_t> &tokens_a,
121
161
  const std::vector<uint64_t> &tokens_b,
122
162
  bool add_special_tokens = true) const override;
123
- };
124
163
 
125
- // -- RobertaProcessing --------------------------------------------------------
164
+ private:
165
+ std::vector<PostProcessor::Ptr> processors_;
166
+ };
126
167
 
127
- class RobertaProcessing : public PostProcessor {
168
+ // -- BertProcessing -----------------------------------------------------------
169
+ // Used for BERT post-processing (adding special tokens)
170
+ class BertProcessing : public PostProcessor {
128
171
  public:
129
- RobertaProcessing();
172
+ BertProcessing(std::pair<std::string, uint64_t> sep,
173
+ std::pair<std::string, uint64_t> cls);
130
174
 
131
175
  size_t added_tokens(bool is_pair) const override;
132
176
 
@@ -136,13 +180,19 @@ public:
136
180
  std::vector<uint64_t> process(const std::vector<uint64_t> &tokens_a,
137
181
  const std::vector<uint64_t> &tokens_b,
138
182
  bool add_special_tokens = true) const override;
139
- };
140
183
 
141
- // -- Sequence -----------------------------------------------------------------
184
+ private:
185
+ std::pair<std::string, uint64_t> sep_;
186
+ std::pair<std::string, uint64_t> cls_;
187
+ };
142
188
 
143
- class Sequence : public PostProcessor {
189
+ // -- RobertaProcessing --------------------------------------------------------
190
+ // Used for RoBERTa post-processing
191
+ class RobertaProcessing : public PostProcessor {
144
192
  public:
145
- explicit Sequence(std::vector<PostProcessor::Ptr> processors);
193
+ RobertaProcessing(std::pair<std::string, uint64_t> sep,
194
+ std::pair<std::string, uint64_t> cls, bool trim_offsets,
195
+ bool add_prefix_space);
146
196
 
147
197
  size_t added_tokens(bool is_pair) const override;
148
198
 
@@ -154,34 +204,43 @@ public:
154
204
  bool add_special_tokens = true) const override;
155
205
 
156
206
  private:
157
- std::vector<PostProcessor::Ptr> processors_;
207
+ std::pair<std::string, uint64_t> sep_;
208
+ std::pair<std::string, uint64_t> cls_;
209
+ bool trim_offsets_;
210
+ bool add_prefix_space_;
158
211
  };
159
212
 
160
- // -- Config -------------------------------------------------------------------
161
-
162
- class PostProcessorConfig {
163
- public:
164
- std::string type;
165
-
166
- // TemplateProcessing
167
- Template single;
168
- Template pair;
169
- std::map<std::string, SpecialToken> special_tokens;
170
-
171
- // Bert / Roberta (unused params in no-op, but kept for parsing logic)
172
- std::pair<std::string, uint32_t> sep;
173
- std::pair<std::string, uint32_t> cls;
174
- bool trim_offsets = true;
175
- bool add_prefix_space = true;
213
+ // -- ByteLevel
214
+ // ----------------------------------------------------------------
215
+ // TODO: Implement ByteLevelProcessor
216
+ // This is a broader issue, as most of the processing is done on offsets.
217
+ // Our current implementation doesn't support it and would require us to
218
+ // introduce a complex Encoding type. Something similar to the original hf
219
+ // implementation:
220
+ // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/tokenizer/encoding.rs
221
+ // so we could store the offsets from pretokenization step.
222
+ /*
223
+ class ByteLevel : public PostProcessor {
224
+ public:
225
+ ByteLevel(bool trim_offsets, bool add_prefix_space);
176
226
 
177
- // Sequence
178
- std::vector<PostProcessorConfig> processors;
227
+ size_t added_tokens(bool is_pair) const override;
179
228
 
180
- explicit PostProcessorConfig(std::string type = "");
229
+ std::vector<uint64_t> process(
230
+ const std::vector<uint64_t>& tokens,
231
+ bool add_special_tokens = true) const override;
181
232
 
182
- PostProcessor::Ptr create() const;
233
+ std::vector<uint64_t> process(
234
+ const std::vector<uint64_t>& tokens_a,
235
+ const std::vector<uint64_t>& tokens_b,
236
+ bool add_special_tokens = true) const override;
183
237
 
184
- PostProcessorConfig &parse_json(const nlohmann::json &json_config);
238
+ private:
239
+ bool trim_offsets_;
240
+ bool add_prefix_space_;
185
241
  };
242
+ */
186
243
 
187
- } // namespace tokenizers
244
+ // -- Sequence
245
+ // -----------------------------------------------------------------
246
+ } // namespace tokenizers
@@ -53,7 +53,7 @@ public:
53
53
  // -- Factory ------------------------------------------------------------------
54
54
 
55
55
  // Helper macro to standardize addition of config member fields
56
- #define CONFIG_MEMBER(type, name) \
56
+ #define PRETOKENIZER_CONFIG_MEMBER(type, name) \
57
57
  std::optional<type> name; \
58
58
  PreTokenizerConfig &set_##name(type arg) { \
59
59
  this->name = std::move(arg); \
@@ -92,37 +92,38 @@ public:
92
92
  /**
93
93
  * Used by: RegexPreTokenizer, ByteLevelPreTokenizer
94
94
  */
95
- CONFIG_MEMBER(std::string, pattern)
95
+ PRETOKENIZER_CONFIG_MEMBER(std::string, pattern)
96
96
 
97
97
  /**
98
98
  * Used by: DigitsPreTokenizer
99
99
  */
100
- CONFIG_MEMBER(bool, individual_digits)
100
+ PRETOKENIZER_CONFIG_MEMBER(bool, individual_digits)
101
101
 
102
102
  /**
103
103
  * Used by: ByteLevelPreTokenizer
104
104
  */
105
- CONFIG_MEMBER(bool, add_prefix_space)
105
+ PRETOKENIZER_CONFIG_MEMBER(bool, add_prefix_space)
106
106
 
107
107
  /**
108
108
  * Used by RegexPreTokenizer
109
109
  */
110
- CONFIG_MEMBER(bool, is_delimiter)
110
+ PRETOKENIZER_CONFIG_MEMBER(bool, is_delimiter)
111
111
 
112
112
  /**
113
113
  * Used by RegexPreTokenizer - Split behavior
114
114
  */
115
- CONFIG_MEMBER(std::string, behavior)
115
+ PRETOKENIZER_CONFIG_MEMBER(std::string, behavior)
116
116
 
117
117
  /**
118
118
  * Used by RegexPreTokenizer - Split invert flag
119
119
  */
120
- CONFIG_MEMBER(bool, invert)
120
+ PRETOKENIZER_CONFIG_MEMBER(bool, invert)
121
121
 
122
122
  /**
123
123
  * Used by: SequencePreTokenizer
124
124
  */
125
- CONFIG_MEMBER(std::vector<PreTokenizerConfig>, pretokenizers)
125
+ using Configs = std::vector<PreTokenizerConfig>;
126
+ PRETOKENIZER_CONFIG_MEMBER(Configs, pretokenizers)
126
127
 
127
128
  /*----------------*/
128
129
  /* Public methods */
@@ -259,6 +260,21 @@ public:
259
260
  private:
260
261
  const std::vector<PreTokenizer::Ptr> pre_tokenizers_;
261
262
 
262
- }; // end class ByteLevelPreTokenizer
263
+ }; // end class SequencePreTokenizer
264
+
265
+ // -- Bert ---------------------------------------------------------------------
266
+ // Used for BERT-style pre-tokenization (splitting on whitespace and
267
+ // punctuation) CITE:
268
+ // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/bert.rs
269
+
270
+ class BertPreTokenizer : public PreTokenizer {
271
+ public:
272
+ BertPreTokenizer() = default;
273
+
274
+ /** Perform BERT pre-tokenization */
275
+ std::vector<std::string>
276
+ pre_tokenize(const std::string &input) const override;
277
+
278
+ }; // end class BertPreTokenizer
263
279
 
264
280
  } // namespace tokenizers
@@ -55,6 +55,14 @@ public:
55
55
 
56
56
  // -- Factory ------------------------------------------------------------------
57
57
 
58
+ // Helper macro to standardize addition of config member fields
59
+ #define TOKEN_DECODER_CONFIG_MEMBER(type, name) \
60
+ std::optional<type> name; \
61
+ TokenDecoderConfig &set_##name(type arg) { \
62
+ this->name = std::move(arg); \
63
+ return *this; \
64
+ }
65
+
58
66
  /**
59
67
  * Factory and config class for creating a new TokenDecoder
60
68
  */
@@ -67,16 +75,20 @@ public:
67
75
  std::string type;
68
76
 
69
77
  // Parameters for Replace decoder
70
- std::string replace_pattern;
71
- std::string replace_content;
78
+ TOKEN_DECODER_CONFIG_MEMBER(std::string, replace_pattern)
79
+ TOKEN_DECODER_CONFIG_MEMBER(std::string, replace_content)
72
80
 
73
81
  // Parameters for Sequence decoder
74
- std::vector<nlohmann::json> sequence_decoders;
82
+ TOKEN_DECODER_CONFIG_MEMBER(std::vector<nlohmann::json>, sequence_decoders)
75
83
 
76
84
  // Parameters for Strip decoder
77
- std::string strip_content;
78
- size_t strip_start;
79
- size_t strip_stop;
85
+ TOKEN_DECODER_CONFIG_MEMBER(std::string, strip_content)
86
+ TOKEN_DECODER_CONFIG_MEMBER(size_t, strip_start)
87
+ TOKEN_DECODER_CONFIG_MEMBER(size_t, strip_stop)
88
+
89
+ // Parameters for WordPiece decoder
90
+ TOKEN_DECODER_CONFIG_MEMBER(std::string, wordpiece_prefix)
91
+ TOKEN_DECODER_CONFIG_MEMBER(bool, wordpiece_cleanup)
80
92
 
81
93
  /*----------------*/
82
94
  /* Public methods */
@@ -161,6 +173,21 @@ private:
161
173
  size_t stop_;
162
174
  }; // end class StripTokenDecoder
163
175
 
176
+ // -- WordPiece ----------------------------------------------------------------
177
+ // Used for WordPiece decoding
178
+
179
+ class WordPieceTokenDecoder : public TokenDecoder {
180
+ public:
181
+ explicit WordPieceTokenDecoder(std::string prefix = "##",
182
+ bool cleanup = true);
183
+ std::vector<std::string>
184
+ decode(const std::vector<std::string> &tokens) const override;
185
+
186
+ private:
187
+ std::string prefix_;
188
+ bool cleanup_;
189
+ }; // end class WordPieceTokenDecoder
190
+
164
191
  // -- Sequence -----------------------------------------------------------------
165
192
  // Applies a sequence of decoders in order
166
193