react-native-executorch 0.7.0 → 0.7.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/common/rnexecutorch/TokenizerModule.cpp +3 -2
- package/common/rnexecutorch/TokenizerModule.h +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +8 -4
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -1
- package/package.json +4 -3
- package/src/modules/computer_vision/TextToImageModule.ts +9 -4
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_model.h +84 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_tokenizer_base.h +6 -87
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/hf_tokenizer.h +28 -176
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/map_utils.h +174 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/model.h +151 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/normalizer.h +55 -1
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/padding.h +112 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/post_processor.h +101 -42
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/pre_tokenizer.h +25 -9
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/token_decoder.h +33 -6
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/tokenizer.h +2 -2
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/truncation.h +92 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/wordpiece_model.h +74 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +0 -253
- package/common/rnexecutorch/tests/README.md +0 -73
- package/common/rnexecutorch/tests/integration/BaseModelTest.cpp +0 -207
- package/common/rnexecutorch/tests/integration/BaseModelTests.h +0 -120
- package/common/rnexecutorch/tests/integration/ClassificationTest.cpp +0 -117
- package/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +0 -122
- package/common/rnexecutorch/tests/integration/ImageSegmentationTest.cpp +0 -152
- package/common/rnexecutorch/tests/integration/LLMTest.cpp +0 -155
- package/common/rnexecutorch/tests/integration/OCRTest.cpp +0 -128
- package/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +0 -135
- package/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp +0 -97
- package/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +0 -112
- package/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp +0 -164
- package/common/rnexecutorch/tests/integration/TextToImageTest.cpp +0 -149
- package/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp +0 -98
- package/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +0 -238
- package/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp +0 -99
- package/common/rnexecutorch/tests/integration/assets/test_audio_float.raw +0 -0
- package/common/rnexecutorch/tests/integration/assets/we_are_software_mansion.jpg +0 -0
- package/common/rnexecutorch/tests/integration/libs/libfbjni.so +0 -0
- package/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp +0 -45
- package/common/rnexecutorch/tests/integration/utils/TestUtils.h +0 -36
- package/common/rnexecutorch/tests/run_tests.sh +0 -333
- package/common/rnexecutorch/tests/unit/FileUtilsTest.cpp +0 -32
- package/common/rnexecutorch/tests/unit/LogTest.cpp +0 -529
- package/common/rnexecutorch/tests/unit/NumericalTest.cpp +0 -107
|
@@ -38,6 +38,10 @@ public:
|
|
|
38
38
|
|
|
39
39
|
/**
|
|
40
40
|
* Process the token IDs (single sequence).
|
|
41
|
+
*
|
|
42
|
+
* NOTE: Unlike the Rust implementation which uses a single method
|
|
43
|
+
* taking Encoding and an Option<Encoding>, we use overloads here
|
|
44
|
+
* to explicitly handle single vs pair sequences while processing raw IDs.
|
|
41
45
|
*/
|
|
42
46
|
virtual std::vector<uint64_t>
|
|
43
47
|
process(const std::vector<uint64_t> &tokens,
|
|
@@ -54,27 +58,65 @@ public:
|
|
|
54
58
|
|
|
55
59
|
// -- Factory/Common Types -----------------------------------------------------
|
|
56
60
|
|
|
61
|
+
// Helper macro to standardize addition of config member fields
|
|
62
|
+
#define POST_PROCESSOR_CONFIG_MEMBER(type, name) \
|
|
63
|
+
std::optional<type> name; \
|
|
64
|
+
PostProcessorConfig &set_##name(type arg) { \
|
|
65
|
+
this->name = std::move(arg); \
|
|
66
|
+
return *this; \
|
|
67
|
+
}
|
|
68
|
+
|
|
57
69
|
enum class SequenceId { A, B };
|
|
58
70
|
|
|
71
|
+
struct SpecialToken {
|
|
72
|
+
std::string id;
|
|
73
|
+
std::vector<uint64_t> ids;
|
|
74
|
+
std::vector<std::string> tokens;
|
|
75
|
+
};
|
|
76
|
+
|
|
59
77
|
struct Piece {
|
|
60
78
|
bool is_special_token;
|
|
61
79
|
std::string id; // For SpecialToken (e.g. "[CLS]"). For Sequence (e.g. "A").
|
|
62
|
-
|
|
80
|
+
uint64_t type_id;
|
|
63
81
|
|
|
64
|
-
static Piece Sequence(SequenceId id,
|
|
82
|
+
static Piece Sequence(SequenceId id, uint64_t type_id) {
|
|
65
83
|
return {false, id == SequenceId::A ? "A" : "B", type_id};
|
|
66
84
|
}
|
|
67
|
-
static Piece SpecialToken(std::string id,
|
|
85
|
+
static Piece SpecialToken(std::string id, uint64_t type_id) {
|
|
68
86
|
return {true, std::move(id), type_id};
|
|
69
87
|
}
|
|
70
88
|
};
|
|
71
89
|
|
|
72
90
|
using Template = std::vector<Piece>;
|
|
91
|
+
// -- Config -------------------------------------------------------------------
|
|
73
92
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
std::
|
|
77
|
-
std::
|
|
93
|
+
class PostProcessorConfig {
|
|
94
|
+
public:
|
|
95
|
+
using SpecialTokenMap = std::map<std::string, tokenizers::SpecialToken>;
|
|
96
|
+
using StringIdPair = std::pair<std::string, uint64_t>;
|
|
97
|
+
|
|
98
|
+
std::string type;
|
|
99
|
+
|
|
100
|
+
// TemplateProcessing
|
|
101
|
+
POST_PROCESSOR_CONFIG_MEMBER(Template, single)
|
|
102
|
+
POST_PROCESSOR_CONFIG_MEMBER(Template, pair)
|
|
103
|
+
POST_PROCESSOR_CONFIG_MEMBER(SpecialTokenMap, special_tokens)
|
|
104
|
+
|
|
105
|
+
// Bert / Roberta (unused params in no-op, but kept for parsing logic)
|
|
106
|
+
POST_PROCESSOR_CONFIG_MEMBER(StringIdPair, sep)
|
|
107
|
+
POST_PROCESSOR_CONFIG_MEMBER(StringIdPair, cls)
|
|
108
|
+
POST_PROCESSOR_CONFIG_MEMBER(bool, trim_offsets)
|
|
109
|
+
POST_PROCESSOR_CONFIG_MEMBER(bool, add_prefix_space)
|
|
110
|
+
|
|
111
|
+
// Sequence
|
|
112
|
+
using Configs = std::vector<PostProcessorConfig>;
|
|
113
|
+
POST_PROCESSOR_CONFIG_MEMBER(Configs, processors)
|
|
114
|
+
|
|
115
|
+
explicit PostProcessorConfig(std::string type = "");
|
|
116
|
+
|
|
117
|
+
PostProcessor::Ptr create() const;
|
|
118
|
+
|
|
119
|
+
PostProcessorConfig &parse_json(const nlohmann::json &json_config);
|
|
78
120
|
};
|
|
79
121
|
|
|
80
122
|
// -- TemplateProcessing -------------------------------------------------------
|
|
@@ -106,11 +148,9 @@ private:
|
|
|
106
148
|
bool add_special_tokens) const;
|
|
107
149
|
};
|
|
108
150
|
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
class BertProcessing : public PostProcessor {
|
|
151
|
+
class Sequence : public PostProcessor {
|
|
112
152
|
public:
|
|
113
|
-
|
|
153
|
+
explicit Sequence(std::vector<PostProcessor::Ptr> processors);
|
|
114
154
|
|
|
115
155
|
size_t added_tokens(bool is_pair) const override;
|
|
116
156
|
|
|
@@ -120,13 +160,17 @@ public:
|
|
|
120
160
|
std::vector<uint64_t> process(const std::vector<uint64_t> &tokens_a,
|
|
121
161
|
const std::vector<uint64_t> &tokens_b,
|
|
122
162
|
bool add_special_tokens = true) const override;
|
|
123
|
-
};
|
|
124
163
|
|
|
125
|
-
|
|
164
|
+
private:
|
|
165
|
+
std::vector<PostProcessor::Ptr> processors_;
|
|
166
|
+
};
|
|
126
167
|
|
|
127
|
-
|
|
168
|
+
// -- BertProcessing -----------------------------------------------------------
|
|
169
|
+
// Used for BERT post-processing (adding special tokens)
|
|
170
|
+
class BertProcessing : public PostProcessor {
|
|
128
171
|
public:
|
|
129
|
-
|
|
172
|
+
BertProcessing(std::pair<std::string, uint64_t> sep,
|
|
173
|
+
std::pair<std::string, uint64_t> cls);
|
|
130
174
|
|
|
131
175
|
size_t added_tokens(bool is_pair) const override;
|
|
132
176
|
|
|
@@ -136,13 +180,19 @@ public:
|
|
|
136
180
|
std::vector<uint64_t> process(const std::vector<uint64_t> &tokens_a,
|
|
137
181
|
const std::vector<uint64_t> &tokens_b,
|
|
138
182
|
bool add_special_tokens = true) const override;
|
|
139
|
-
};
|
|
140
183
|
|
|
141
|
-
|
|
184
|
+
private:
|
|
185
|
+
std::pair<std::string, uint64_t> sep_;
|
|
186
|
+
std::pair<std::string, uint64_t> cls_;
|
|
187
|
+
};
|
|
142
188
|
|
|
143
|
-
|
|
189
|
+
// -- RobertaProcessing --------------------------------------------------------
|
|
190
|
+
// Used for RoBERTa post-processing
|
|
191
|
+
class RobertaProcessing : public PostProcessor {
|
|
144
192
|
public:
|
|
145
|
-
|
|
193
|
+
RobertaProcessing(std::pair<std::string, uint64_t> sep,
|
|
194
|
+
std::pair<std::string, uint64_t> cls, bool trim_offsets,
|
|
195
|
+
bool add_prefix_space);
|
|
146
196
|
|
|
147
197
|
size_t added_tokens(bool is_pair) const override;
|
|
148
198
|
|
|
@@ -154,34 +204,43 @@ public:
|
|
|
154
204
|
bool add_special_tokens = true) const override;
|
|
155
205
|
|
|
156
206
|
private:
|
|
157
|
-
std::
|
|
207
|
+
std::pair<std::string, uint64_t> sep_;
|
|
208
|
+
std::pair<std::string, uint64_t> cls_;
|
|
209
|
+
bool trim_offsets_;
|
|
210
|
+
bool add_prefix_space_;
|
|
158
211
|
};
|
|
159
212
|
|
|
160
|
-
// --
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
std::pair<std::string, uint32_t> cls;
|
|
174
|
-
bool trim_offsets = true;
|
|
175
|
-
bool add_prefix_space = true;
|
|
213
|
+
// -- ByteLevel
|
|
214
|
+
// ----------------------------------------------------------------
|
|
215
|
+
// TODO: Implement ByteLevelProcessor
|
|
216
|
+
// This is a broader issue, as most of the processing is done on offsets.
|
|
217
|
+
// Our current implementation doesn't support it and would require us to
|
|
218
|
+
// introduce a complex Encoding type. Something similar to the original hf
|
|
219
|
+
// implementation:
|
|
220
|
+
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/tokenizer/encoding.rs
|
|
221
|
+
// so we could store the offsets from pretokenization step.
|
|
222
|
+
/*
|
|
223
|
+
class ByteLevel : public PostProcessor {
|
|
224
|
+
public:
|
|
225
|
+
ByteLevel(bool trim_offsets, bool add_prefix_space);
|
|
176
226
|
|
|
177
|
-
|
|
178
|
-
std::vector<PostProcessorConfig> processors;
|
|
227
|
+
size_t added_tokens(bool is_pair) const override;
|
|
179
228
|
|
|
180
|
-
|
|
229
|
+
std::vector<uint64_t> process(
|
|
230
|
+
const std::vector<uint64_t>& tokens,
|
|
231
|
+
bool add_special_tokens = true) const override;
|
|
181
232
|
|
|
182
|
-
|
|
233
|
+
std::vector<uint64_t> process(
|
|
234
|
+
const std::vector<uint64_t>& tokens_a,
|
|
235
|
+
const std::vector<uint64_t>& tokens_b,
|
|
236
|
+
bool add_special_tokens = true) const override;
|
|
183
237
|
|
|
184
|
-
|
|
238
|
+
private:
|
|
239
|
+
bool trim_offsets_;
|
|
240
|
+
bool add_prefix_space_;
|
|
185
241
|
};
|
|
242
|
+
*/
|
|
186
243
|
|
|
187
|
-
|
|
244
|
+
// -- Sequence
|
|
245
|
+
// -----------------------------------------------------------------
|
|
246
|
+
} // namespace tokenizers
|
|
@@ -53,7 +53,7 @@ public:
|
|
|
53
53
|
// -- Factory ------------------------------------------------------------------
|
|
54
54
|
|
|
55
55
|
// Helper macro to standardize addition of config member fields
|
|
56
|
-
#define
|
|
56
|
+
#define PRETOKENIZER_CONFIG_MEMBER(type, name) \
|
|
57
57
|
std::optional<type> name; \
|
|
58
58
|
PreTokenizerConfig &set_##name(type arg) { \
|
|
59
59
|
this->name = std::move(arg); \
|
|
@@ -92,37 +92,38 @@ public:
|
|
|
92
92
|
/**
|
|
93
93
|
* Used by: RegexPreTokenizer, ByteLevelPreTokenizer
|
|
94
94
|
*/
|
|
95
|
-
|
|
95
|
+
PRETOKENIZER_CONFIG_MEMBER(std::string, pattern)
|
|
96
96
|
|
|
97
97
|
/**
|
|
98
98
|
* Used by: DigitsPreTokenizer
|
|
99
99
|
*/
|
|
100
|
-
|
|
100
|
+
PRETOKENIZER_CONFIG_MEMBER(bool, individual_digits)
|
|
101
101
|
|
|
102
102
|
/**
|
|
103
103
|
* Used by: ByteLevelPreTokenizer
|
|
104
104
|
*/
|
|
105
|
-
|
|
105
|
+
PRETOKENIZER_CONFIG_MEMBER(bool, add_prefix_space)
|
|
106
106
|
|
|
107
107
|
/**
|
|
108
108
|
* Used by RegexPreTokenizer
|
|
109
109
|
*/
|
|
110
|
-
|
|
110
|
+
PRETOKENIZER_CONFIG_MEMBER(bool, is_delimiter)
|
|
111
111
|
|
|
112
112
|
/**
|
|
113
113
|
* Used by RegexPreTokenizer - Split behavior
|
|
114
114
|
*/
|
|
115
|
-
|
|
115
|
+
PRETOKENIZER_CONFIG_MEMBER(std::string, behavior)
|
|
116
116
|
|
|
117
117
|
/**
|
|
118
118
|
* Used by RegexPreTokenizer - Split invert flag
|
|
119
119
|
*/
|
|
120
|
-
|
|
120
|
+
PRETOKENIZER_CONFIG_MEMBER(bool, invert)
|
|
121
121
|
|
|
122
122
|
/**
|
|
123
123
|
* Used by: SequencePreTokenizer
|
|
124
124
|
*/
|
|
125
|
-
|
|
125
|
+
using Configs = std::vector<PreTokenizerConfig>;
|
|
126
|
+
PRETOKENIZER_CONFIG_MEMBER(Configs, pretokenizers)
|
|
126
127
|
|
|
127
128
|
/*----------------*/
|
|
128
129
|
/* Public methods */
|
|
@@ -259,6 +260,21 @@ public:
|
|
|
259
260
|
private:
|
|
260
261
|
const std::vector<PreTokenizer::Ptr> pre_tokenizers_;
|
|
261
262
|
|
|
262
|
-
}; // end class
|
|
263
|
+
}; // end class SequencePreTokenizer
|
|
264
|
+
|
|
265
|
+
// -- Bert ---------------------------------------------------------------------
|
|
266
|
+
// Used for BERT-style pre-tokenization (splitting on whitespace and
|
|
267
|
+
// punctuation) CITE:
|
|
268
|
+
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/bert.rs
|
|
269
|
+
|
|
270
|
+
class BertPreTokenizer : public PreTokenizer {
|
|
271
|
+
public:
|
|
272
|
+
BertPreTokenizer() = default;
|
|
273
|
+
|
|
274
|
+
/** Perform BERT pre-tokenization */
|
|
275
|
+
std::vector<std::string>
|
|
276
|
+
pre_tokenize(const std::string &input) const override;
|
|
277
|
+
|
|
278
|
+
}; // end class BertPreTokenizer
|
|
263
279
|
|
|
264
280
|
} // namespace tokenizers
|
|
@@ -55,6 +55,14 @@ public:
|
|
|
55
55
|
|
|
56
56
|
// -- Factory ------------------------------------------------------------------
|
|
57
57
|
|
|
58
|
+
// Helper macro to standardize addition of config member fields
|
|
59
|
+
#define TOKEN_DECODER_CONFIG_MEMBER(type, name) \
|
|
60
|
+
std::optional<type> name; \
|
|
61
|
+
TokenDecoderConfig &set_##name(type arg) { \
|
|
62
|
+
this->name = std::move(arg); \
|
|
63
|
+
return *this; \
|
|
64
|
+
}
|
|
65
|
+
|
|
58
66
|
/**
|
|
59
67
|
* Factory and config class for creating a new TokenDecoder
|
|
60
68
|
*/
|
|
@@ -67,16 +75,20 @@ public:
|
|
|
67
75
|
std::string type;
|
|
68
76
|
|
|
69
77
|
// Parameters for Replace decoder
|
|
70
|
-
std::string replace_pattern
|
|
71
|
-
std::string replace_content
|
|
78
|
+
TOKEN_DECODER_CONFIG_MEMBER(std::string, replace_pattern)
|
|
79
|
+
TOKEN_DECODER_CONFIG_MEMBER(std::string, replace_content)
|
|
72
80
|
|
|
73
81
|
// Parameters for Sequence decoder
|
|
74
|
-
std::vector<nlohmann::json
|
|
82
|
+
TOKEN_DECODER_CONFIG_MEMBER(std::vector<nlohmann::json>, sequence_decoders)
|
|
75
83
|
|
|
76
84
|
// Parameters for Strip decoder
|
|
77
|
-
std::string strip_content
|
|
78
|
-
size_t strip_start
|
|
79
|
-
size_t strip_stop
|
|
85
|
+
TOKEN_DECODER_CONFIG_MEMBER(std::string, strip_content)
|
|
86
|
+
TOKEN_DECODER_CONFIG_MEMBER(size_t, strip_start)
|
|
87
|
+
TOKEN_DECODER_CONFIG_MEMBER(size_t, strip_stop)
|
|
88
|
+
|
|
89
|
+
// Parameters for WordPiece decoder
|
|
90
|
+
TOKEN_DECODER_CONFIG_MEMBER(std::string, wordpiece_prefix)
|
|
91
|
+
TOKEN_DECODER_CONFIG_MEMBER(bool, wordpiece_cleanup)
|
|
80
92
|
|
|
81
93
|
/*----------------*/
|
|
82
94
|
/* Public methods */
|
|
@@ -161,6 +173,21 @@ private:
|
|
|
161
173
|
size_t stop_;
|
|
162
174
|
}; // end class StripTokenDecoder
|
|
163
175
|
|
|
176
|
+
// -- WordPiece ----------------------------------------------------------------
|
|
177
|
+
// Used for WordPiece decoding
|
|
178
|
+
|
|
179
|
+
class WordPieceTokenDecoder : public TokenDecoder {
|
|
180
|
+
public:
|
|
181
|
+
explicit WordPieceTokenDecoder(std::string prefix = "##",
|
|
182
|
+
bool cleanup = true);
|
|
183
|
+
std::vector<std::string>
|
|
184
|
+
decode(const std::vector<std::string> &tokens) const override;
|
|
185
|
+
|
|
186
|
+
private:
|
|
187
|
+
std::string prefix_;
|
|
188
|
+
bool cleanup_;
|
|
189
|
+
}; // end class WordPieceTokenDecoder
|
|
190
|
+
|
|
164
191
|
// -- Sequence -----------------------------------------------------------------
|
|
165
192
|
// Applies a sequence of decoders in order
|
|
166
193
|
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
* All rights reserved.
|
|
4
|
+
*
|
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
*/
|
|
8
|
+
// @lint-ignore-every LICENSELINT
|
|
9
|
+
|
|
10
|
+
#pragma once
|
|
11
|
+
|
|
12
|
+
// Standard
|
|
13
|
+
#include <memory>
|
|
14
|
+
#include <optional>
|
|
15
|
+
#include <string>
|
|
16
|
+
#include <vector>
|
|
17
|
+
|
|
18
|
+
// Third Party
|
|
19
|
+
#include <nlohmann/json.hpp>
|
|
20
|
+
|
|
21
|
+
namespace tokenizers {
|
|
22
|
+
|
|
23
|
+
// -- Truncation ---------------------------------------------------------------
|
|
24
|
+
|
|
25
|
+
enum class TruncationStrategy {
|
|
26
|
+
LongestFirst,
|
|
27
|
+
OnlyFirst,
|
|
28
|
+
OnlySecond,
|
|
29
|
+
};
|
|
30
|
+
|
|
31
|
+
enum class TruncationDirection {
|
|
32
|
+
Left,
|
|
33
|
+
Right,
|
|
34
|
+
};
|
|
35
|
+
|
|
36
|
+
struct TruncationParams {
|
|
37
|
+
TruncationDirection direction = TruncationDirection::Right;
|
|
38
|
+
size_t max_length = 512;
|
|
39
|
+
TruncationStrategy strategy = TruncationStrategy::LongestFirst;
|
|
40
|
+
size_t stride = 0;
|
|
41
|
+
};
|
|
42
|
+
|
|
43
|
+
class Truncation {
|
|
44
|
+
public:
|
|
45
|
+
/** Shared pointer type */
|
|
46
|
+
typedef std::shared_ptr<Truncation> Ptr;
|
|
47
|
+
|
|
48
|
+
/**
|
|
49
|
+
* @param params: The truncation parameters
|
|
50
|
+
*/
|
|
51
|
+
explicit Truncation(const TruncationParams &params);
|
|
52
|
+
|
|
53
|
+
/**
|
|
54
|
+
* Truncate the tokens according to the configuration.
|
|
55
|
+
*
|
|
56
|
+
* @param tokens The tokens to truncate.
|
|
57
|
+
* @param num_tokens_to_add The number of special tokens that will be added
|
|
58
|
+
* later. These are subtracted from max_length during truncation calculation.
|
|
59
|
+
*/
|
|
60
|
+
std::vector<uint64_t> truncate(std::vector<uint64_t> tokens,
|
|
61
|
+
size_t num_tokens_to_add = 0) const;
|
|
62
|
+
|
|
63
|
+
/**
|
|
64
|
+
* Truncate a pair of sequences according to the configuration.
|
|
65
|
+
*/
|
|
66
|
+
std::pair<std::vector<uint64_t>, std::vector<uint64_t>>
|
|
67
|
+
truncate_pair(std::vector<uint64_t> a, std::vector<uint64_t> b,
|
|
68
|
+
size_t num_tokens_to_add = 0) const;
|
|
69
|
+
|
|
70
|
+
private:
|
|
71
|
+
TruncationParams params_;
|
|
72
|
+
};
|
|
73
|
+
|
|
74
|
+
// -- Factory ------------------------------------------------------------------
|
|
75
|
+
|
|
76
|
+
class TruncationConfig {
|
|
77
|
+
public:
|
|
78
|
+
/**
|
|
79
|
+
* Construct the truncation instance from the member data
|
|
80
|
+
*/
|
|
81
|
+
Truncation::Ptr create() const;
|
|
82
|
+
|
|
83
|
+
/**
|
|
84
|
+
* Populate from a json config file
|
|
85
|
+
*/
|
|
86
|
+
TruncationConfig &parse_json(const nlohmann::json &json_config);
|
|
87
|
+
|
|
88
|
+
// Configuration members
|
|
89
|
+
TruncationParams params;
|
|
90
|
+
};
|
|
91
|
+
|
|
92
|
+
} // namespace tokenizers
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
* All rights reserved.
|
|
4
|
+
*
|
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
*/
|
|
8
|
+
// @lint-ignore-every LICENSELINT
|
|
9
|
+
|
|
10
|
+
#pragma once
|
|
11
|
+
|
|
12
|
+
#include <memory>
|
|
13
|
+
#include <optional>
|
|
14
|
+
#include <string>
|
|
15
|
+
#include <vector>
|
|
16
|
+
|
|
17
|
+
#include <pytorch/tokenizers/model.h>
|
|
18
|
+
#include <pytorch/tokenizers/regex.h>
|
|
19
|
+
#include <pytorch/tokenizers/result.h>
|
|
20
|
+
#include <pytorch/tokenizers/string_integer_map.h>
|
|
21
|
+
|
|
22
|
+
namespace tokenizers {
|
|
23
|
+
|
|
24
|
+
class WordPieceModel : public Model {
|
|
25
|
+
public:
|
|
26
|
+
explicit WordPieceModel(detail::TokenMap token_map,
|
|
27
|
+
detail::TokenMap special_token_map,
|
|
28
|
+
std::string unk_token,
|
|
29
|
+
std::string continuing_subword_prefix,
|
|
30
|
+
size_t max_input_chars_per_word,
|
|
31
|
+
std::optional<uint64_t> unk_token_id,
|
|
32
|
+
std::optional<uint64_t> bos_token_id,
|
|
33
|
+
std::optional<uint64_t> eos_token_id);
|
|
34
|
+
|
|
35
|
+
~WordPieceModel() override = default;
|
|
36
|
+
|
|
37
|
+
Result<std::vector<uint64_t>>
|
|
38
|
+
tokenize(const std::string &piece) const override;
|
|
39
|
+
|
|
40
|
+
Result<std::string> id_to_piece(uint64_t token) const override;
|
|
41
|
+
Result<uint64_t> piece_to_id(const std::string &token) const override;
|
|
42
|
+
|
|
43
|
+
int32_t vocab_size() const override { return vocab_size_; }
|
|
44
|
+
|
|
45
|
+
bool is_special_token(uint64_t token) const override;
|
|
46
|
+
|
|
47
|
+
bool is_loaded() const override { return initialized_; }
|
|
48
|
+
|
|
49
|
+
std::pair<std::optional<std::string>, std::string>
|
|
50
|
+
split_with_allowed_special_token(const std::string &input,
|
|
51
|
+
size_t offset) const override;
|
|
52
|
+
|
|
53
|
+
uint64_t bos_token_id() const override { return bos_token_id_.value_or(0); }
|
|
54
|
+
|
|
55
|
+
uint64_t eos_token_id() const override { return eos_token_id_.value_or(0); }
|
|
56
|
+
|
|
57
|
+
private:
|
|
58
|
+
detail::TokenMap token_map_;
|
|
59
|
+
detail::TokenMap special_token_map_;
|
|
60
|
+
std::unique_ptr<IRegex> special_token_regex_;
|
|
61
|
+
|
|
62
|
+
std::string unk_token_;
|
|
63
|
+
std::string continuing_subword_prefix_;
|
|
64
|
+
size_t max_input_chars_per_word_;
|
|
65
|
+
|
|
66
|
+
std::optional<uint64_t> unk_token_id_;
|
|
67
|
+
std::optional<uint64_t> bos_token_id_;
|
|
68
|
+
std::optional<uint64_t> eos_token_id_;
|
|
69
|
+
|
|
70
|
+
bool initialized_ = false;
|
|
71
|
+
int32_t vocab_size_ = 0;
|
|
72
|
+
};
|
|
73
|
+
|
|
74
|
+
} // namespace tokenizers
|