tokenizers 0.2.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/Cargo.lock +33 -74
- data/README.md +4 -0
- data/ext/tokenizers/Cargo.toml +4 -2
- data/ext/tokenizers/src/decoders.rs +275 -6
- data/ext/tokenizers/src/encoding.rs +78 -3
- data/ext/tokenizers/src/error.rs +2 -2
- data/ext/tokenizers/src/lib.rs +88 -17
- data/ext/tokenizers/src/models.rs +372 -11
- data/ext/tokenizers/src/normalizers.rs +435 -7
- data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
- data/ext/tokenizers/src/processors.rs +210 -0
- data/ext/tokenizers/src/tokenizer.rs +448 -20
- data/ext/tokenizers/src/trainers.rs +749 -0
- data/ext/tokenizers/src/utils/mod.rs +5 -0
- data/ext/tokenizers/src/utils/normalization.rs +85 -0
- data/ext/tokenizers/src/utils/regex.rs +22 -0
- data/lib/tokenizers/char_bpe_tokenizer.rb +11 -8
- data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
- data/lib/tokenizers/decoders/ctc.rb +9 -0
- data/lib/tokenizers/decoders/metaspace.rb +9 -0
- data/lib/tokenizers/decoders/word_piece.rb +9 -0
- data/lib/tokenizers/encoding.rb +19 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/bpe.rb +9 -0
- data/lib/tokenizers/models/unigram.rb +9 -0
- data/lib/tokenizers/models/word_level.rb +13 -0
- data/lib/tokenizers/models/word_piece.rb +9 -0
- data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
- data/lib/tokenizers/normalizers/strip.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
- data/lib/tokenizers/processors/byte_level.rb +9 -0
- data/lib/tokenizers/processors/roberta_processing.rb +9 -0
- data/lib/tokenizers/processors/template_processing.rb +9 -0
- data/lib/tokenizers/tokenizer.rb +45 -0
- data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
- data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
- data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
- data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +49 -7
- metadata +32 -3
@@ -1,14 +1,478 @@
|
|
1
|
+
use std::sync::{Arc, RwLock};
|
2
|
+
|
3
|
+
use magnus::typed_data::DataTypeBuilder;
|
4
|
+
use magnus::{
|
5
|
+
function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object,
|
6
|
+
RArray, RClass, RModule, TypedData,
|
7
|
+
};
|
8
|
+
|
9
|
+
use serde::ser::SerializeStruct;
|
10
|
+
use serde::{Deserialize, Serialize, Serializer};
|
11
|
+
|
1
12
|
use tk::pre_tokenizers::bert::BertPreTokenizer;
|
13
|
+
use tk::pre_tokenizers::byte_level::ByteLevel;
|
14
|
+
use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
|
15
|
+
use tk::pre_tokenizers::digits::Digits;
|
16
|
+
use tk::pre_tokenizers::metaspace::Metaspace;
|
17
|
+
use tk::pre_tokenizers::punctuation::Punctuation;
|
18
|
+
use tk::pre_tokenizers::split::Split;
|
19
|
+
use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
|
20
|
+
use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
|
21
|
+
use tk::pre_tokenizers::PreTokenizerWrapper;
|
22
|
+
use tk::tokenizer::Offsets;
|
23
|
+
use tk::{PreTokenizedString, PreTokenizer};
|
24
|
+
|
25
|
+
use super::utils::*;
|
26
|
+
use super::{RbError, RbResult};
|
27
|
+
|
28
|
+
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
29
|
+
pub struct RbPreTokenizer {
|
30
|
+
#[serde(flatten)]
|
31
|
+
pub(crate) pretok: RbPreTokenizerTypeWrapper,
|
32
|
+
}
|
33
|
+
|
34
|
+
impl RbPreTokenizer {
|
35
|
+
fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
|
36
|
+
let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
|
37
|
+
|
38
|
+
self.pretok.pre_tokenize(&mut pretokenized).map_err(RbError::from)?;
|
39
|
+
|
40
|
+
Ok(pretokenized
|
41
|
+
.get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
|
42
|
+
.into_iter()
|
43
|
+
.map(|(s, o, _)| (s.to_owned(), o))
|
44
|
+
.collect())
|
45
|
+
}
|
46
|
+
}
|
2
47
|
|
3
|
-
|
4
|
-
|
5
|
-
|
48
|
+
/// Reads a field (or calls an accessor chain) on the concrete wrapped
/// pre-tokenizer. Callers guarantee the wrapper holds a `Single` of the
/// requested `$variant`; any other shape is a programming error, hence
/// the `unreachable!()` arms.
macro_rules! getter {
    ($obj: ident, $variant: ident, $($accessor: tt)+) => {{
        if let RbPreTokenizerTypeWrapper::Single(ref single) = &$obj.pretok {
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref pretok)) =
                *single.read().unwrap() {
                pretok.$($accessor)+
            } else {
                unreachable!()
            }
        } else {
            unreachable!()
        }
    }};
}

/// Writes through to the concrete wrapped pre-tokenizer under the write
/// lock. The plain form assigns a public field; the `@`-prefixed form
/// invokes a setter method instead. Unlike `getter!`, a shape mismatch is
/// silently ignored rather than treated as unreachable.
macro_rules! setter {
    ($obj: ident, $variant: ident, $field: ident, $value: expr) => {{
        if let RbPreTokenizerTypeWrapper::Single(ref single) = &$obj.pretok {
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) =
                *single.write().unwrap()
            {
                pretok.$field = $value;
            }
        }
    }};
    ($obj: ident, $variant: ident, @$method: ident, $value: expr) => {{
        if let RbPreTokenizerTypeWrapper::Single(ref single) = &$obj.pretok {
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) =
                *single.write().unwrap()
            {
                pretok.$method($value);
            }
        }
    }};
}
|
83
|
+
|
84
|
+
impl RbPreTokenizer {
    /// Builds a wrapper from an already-constructed type wrapper.
    #[allow(dead_code)]
    pub(crate) fn new(pretok: RbPreTokenizerTypeWrapper) -> Self {
        RbPreTokenizer { pretok }
    }

    // --- ByteLevel accessors (Ruby: add_prefix_space / use_regex) ---
    // Each pair is bound on the matching Ruby subclass in
    // `pre_tokenizers()` below, so the getter! unreachable!() arms can
    // only fire if a method is bound on the wrong class.

    fn byte_level_add_prefix_space(&self) -> bool {
        getter!(self, ByteLevel, add_prefix_space)
    }

    fn byte_level_set_add_prefix_space(&self, add_prefix_space: bool) {
        setter!(self, ByteLevel, add_prefix_space, add_prefix_space);
    }

    fn byte_level_use_regex(&self) -> bool {
        getter!(self, ByteLevel, use_regex)
    }

    fn byte_level_set_use_regex(&self, use_regex: bool) {
        setter!(self, ByteLevel, use_regex, use_regex);
    }

    // --- CharDelimiterSplit accessors ---
    // The delimiter is stored as a char but surfaced to Ruby as a String.

    fn char_delimiter_split_delimiter(&self) -> String {
        getter!(self, Delimiter, delimiter.to_string())
    }

    fn char_delimiter_split_set_delimiter(&self, delimiter: char) {
        setter!(self, Delimiter, delimiter, delimiter);
    }

    // --- Digits accessors ---

    fn digits_individual_digits(&self) -> bool {
        getter!(self, Digits, individual_digits)
    }

    fn digits_set_individual_digits(&self, individual_digits: bool) {
        setter!(self, Digits, individual_digits, individual_digits);
    }

    // --- Metaspace accessors ---

    fn metaspace_add_prefix_space(&self) -> bool {
        getter!(self, Metaspace, add_prefix_space)
    }

    fn metaspace_set_add_prefix_space(&self, add_prefix_space: bool) {
        setter!(self, Metaspace, add_prefix_space, add_prefix_space);
    }

    fn metaspace_replacement(&self) -> String {
        getter!(self, Metaspace, get_replacement().to_string())
    }

    // Uses the @method form of setter! because Metaspace keeps the
    // replacement private and exposes a set_replacement() method.
    fn metaspace_set_replacement(&self, replacement: char) {
        setter!(self, Metaspace, @set_replacement, replacement);
    }
}
|
138
|
+
|
139
|
+
impl PreTokenizer for RbPreTokenizer {
|
140
|
+
fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
|
141
|
+
self.pretok.pre_tokenize(normalized)
|
142
|
+
}
|
143
|
+
}
|
144
|
+
|
145
|
+
pub struct RbByteLevel {}
|
146
|
+
|
147
|
+
impl RbByteLevel {
|
148
|
+
pub fn new(add_prefix_space: bool, use_regex: bool) -> RbPreTokenizer {
|
149
|
+
ByteLevel::default()
|
150
|
+
.add_prefix_space(add_prefix_space)
|
151
|
+
.use_regex(use_regex)
|
152
|
+
.into()
|
153
|
+
}
|
154
|
+
|
155
|
+
fn alphabet() -> Vec<String> {
|
156
|
+
ByteLevel::alphabet()
|
157
|
+
.into_iter()
|
158
|
+
.map(|c| c.to_string())
|
159
|
+
.collect()
|
160
|
+
}
|
161
|
+
}
|
162
|
+
|
163
|
+
pub struct RbCharDelimiterSplit {}
|
164
|
+
|
165
|
+
impl RbCharDelimiterSplit {
|
166
|
+
pub fn new(delimiter: char) -> RbPreTokenizer {
|
167
|
+
CharDelimiterSplit::new(delimiter).into()
|
168
|
+
}
|
169
|
+
}
|
170
|
+
|
171
|
+
pub struct RbDigits {}
|
172
|
+
|
173
|
+
impl RbDigits {
|
174
|
+
fn new(individual_digits: bool) -> RbPreTokenizer {
|
175
|
+
Digits::new(individual_digits).into()
|
176
|
+
}
|
177
|
+
}
|
178
|
+
|
179
|
+
pub struct RbMetaspace {}
|
180
|
+
|
181
|
+
impl RbMetaspace {
|
182
|
+
fn new(
|
183
|
+
replacement: char,
|
184
|
+
add_prefix_space: bool,
|
185
|
+
) -> RbPreTokenizer {
|
186
|
+
Metaspace::new(replacement, add_prefix_space).into()
|
187
|
+
}
|
188
|
+
}
|
189
|
+
|
190
|
+
pub struct RbPunctuation {}
|
191
|
+
|
192
|
+
impl RbPunctuation {
|
193
|
+
pub fn new(behavior: RbSplitDelimiterBehavior) -> RbResult<RbPreTokenizer> {
|
194
|
+
Ok(Punctuation::new(behavior.into()).into())
|
195
|
+
}
|
196
|
+
}
|
197
|
+
|
198
|
+
pub struct RbSplit {}
|
199
|
+
|
200
|
+
impl RbSplit {
|
201
|
+
pub fn new(pattern: RbPattern, behavior: RbSplitDelimiterBehavior, invert: bool) -> RbResult<RbPreTokenizer> {
|
202
|
+
Split::new(pattern, behavior.into(), invert).map(|v| v.into()).map_err(RbError::from)
|
203
|
+
}
|
204
|
+
}
|
205
|
+
|
206
|
+
pub struct RbUnicodeScripts {}
|
207
|
+
|
208
|
+
impl RbUnicodeScripts {
|
209
|
+
pub fn new() -> RbPreTokenizer {
|
210
|
+
UnicodeScripts::new().into()
|
211
|
+
}
|
212
|
+
}
|
213
|
+
|
214
|
+
pub struct RbWhitespace {}
|
215
|
+
|
216
|
+
impl RbWhitespace {
|
217
|
+
pub fn new() -> RbPreTokenizer {
|
218
|
+
Whitespace::default().into()
|
219
|
+
}
|
220
|
+
}
|
221
|
+
|
222
|
+
pub struct RbWhitespaceSplit {}
|
223
|
+
|
224
|
+
impl RbWhitespaceSplit {
|
225
|
+
pub fn new() -> RbPreTokenizer {
|
226
|
+
WhitespaceSplit.into()
|
227
|
+
}
|
228
|
+
}
|
229
|
+
|
230
|
+
pub struct RbBertPreTokenizer {}
|
231
|
+
|
8
232
|
impl RbBertPreTokenizer {
|
9
|
-
pub fn new() ->
|
10
|
-
|
11
|
-
|
233
|
+
pub fn new() -> RbPreTokenizer {
|
234
|
+
BertPreTokenizer.into()
|
235
|
+
}
|
236
|
+
}
|
237
|
+
|
238
|
+
pub struct RbSequence {}
|
239
|
+
|
240
|
+
impl RbSequence {
|
241
|
+
fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
|
242
|
+
let mut sequence = Vec::with_capacity(pre_tokenizers.len());
|
243
|
+
for n in pre_tokenizers.each() {
|
244
|
+
let pretokenizer: &RbPreTokenizer = n?.try_convert()?;
|
245
|
+
match &pretokenizer.pretok {
|
246
|
+
RbPreTokenizerTypeWrapper::Sequence(inner) => {
|
247
|
+
sequence.extend(inner.iter().cloned())
|
248
|
+
}
|
249
|
+
RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
|
250
|
+
}
|
251
|
+
}
|
252
|
+
Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(sequence)))
|
253
|
+
}
|
254
|
+
}
|
255
|
+
|
256
|
+
/// A pre-tokenizer held by the Ruby bindings: currently always one of the
/// upstream `tk` variants; the commented-out `Custom` variant is a
/// placeholder for Ruby-defined pre-tokenizers.
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum RbPreTokenizerWrapper {
    // Custom(CustomPreTokenizer),
    Wrapped(PreTokenizerWrapper),
}

impl Serialize for RbPreTokenizerWrapper {
    // Hand-written so serialization delegates directly to the inner value,
    // producing exactly the upstream `tokenizers` JSON with no extra layer.
    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
    where
        S: Serializer,
    {
        match self {
            RbPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
            // RbPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
        }
    }
}
|
274
|
+
|
275
|
+
/// Either a single pre-tokenizer or a flattened sequence of them. Each
/// element is behind `Arc<RwLock<..>>` so a sequence can share handles
/// with the individual Ruby objects it was built from (mutations through
/// the setters stay visible in both places).
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum RbPreTokenizerTypeWrapper {
    Sequence(Vec<Arc<RwLock<RbPreTokenizerWrapper>>>),
    Single(Arc<RwLock<RbPreTokenizerWrapper>>),
}

impl Serialize for RbPreTokenizerTypeWrapper {
    // A sequence serializes as {"type": "Sequence", "pretokenizers": [...]}
    // to match the upstream format; a single value serializes as itself.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            RbPreTokenizerTypeWrapper::Sequence(seq) => {
                let mut ser = serializer.serialize_struct("Sequence", 2)?;
                ser.serialize_field("type", "Sequence")?;
                ser.serialize_field("pretokenizers", seq)?;
                ser.end()
            }
            RbPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer),
        }
    }
}
|
298
|
+
|
299
|
+
impl<I> From<I> for RbPreTokenizerWrapper
|
300
|
+
where
|
301
|
+
I: Into<PreTokenizerWrapper>,
|
302
|
+
{
|
303
|
+
fn from(pretok: I) -> Self {
|
304
|
+
RbPreTokenizerWrapper::Wrapped(pretok.into())
|
305
|
+
}
|
306
|
+
}
|
307
|
+
|
308
|
+
impl<I> From<I> for RbPreTokenizerTypeWrapper
|
309
|
+
where
|
310
|
+
I: Into<RbPreTokenizerWrapper>,
|
311
|
+
{
|
312
|
+
fn from(pretok: I) -> Self {
|
313
|
+
RbPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into())))
|
314
|
+
}
|
315
|
+
}
|
316
|
+
|
317
|
+
impl<I> From<I> for RbPreTokenizer
|
318
|
+
where
|
319
|
+
I: Into<PreTokenizerWrapper>,
|
320
|
+
{
|
321
|
+
fn from(pretok: I) -> Self {
|
322
|
+
RbPreTokenizer {
|
323
|
+
pretok: pretok.into().into(),
|
324
|
+
}
|
325
|
+
}
|
326
|
+
}
|
327
|
+
|
328
|
+
impl PreTokenizer for RbPreTokenizerTypeWrapper {
|
329
|
+
fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
|
330
|
+
match self {
|
331
|
+
RbPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok),
|
332
|
+
RbPreTokenizerTypeWrapper::Sequence(inner) => inner
|
333
|
+
.iter()
|
334
|
+
.try_for_each(|n| n.read().unwrap().pre_tokenize(pretok)),
|
12
335
|
}
|
13
336
|
}
|
14
337
|
}
|
338
|
+
|
339
|
+
impl PreTokenizer for RbPreTokenizerWrapper {
|
340
|
+
fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
|
341
|
+
match self {
|
342
|
+
RbPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok),
|
343
|
+
// RbPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok),
|
344
|
+
}
|
345
|
+
}
|
346
|
+
}
|
347
|
+
|
348
|
+
// SAFETY: mirrors the magnus-derived TypedData contract; `class_for`
// returns a subclass of `class()` so Ruby objects are tagged with the
// concrete pre-tokenizer class while sharing one DataType.
unsafe impl TypedData for RbPreTokenizer {
    fn class() -> RClass {
        // Looked up once and cached; allocation is undefined because
        // instances are only created from Rust, never via `allocate`.
        *memoize!(RClass: {
            let class: RClass = crate::pre_tokenizers().const_get("PreTokenizer").unwrap();
            class.undef_alloc_func();
            class
        })
    }

    fn data_type() -> &'static DataType {
        memoize!(DataType: DataTypeBuilder::<RbPreTokenizer>::new("Tokenizers::PreTokenizers::PreTokenizer").build())
    }

    // Maps the concrete wrapped variant to its Ruby subclass so e.g. a
    // ByteLevel round-trips as Tokenizers::PreTokenizers::ByteLevel.
    // Each arm has its own memoize! call site (one cached RClass per
    // variant).
    fn class_for(value: &Self) -> RClass {
        match &value.pretok {
            RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
                let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
                class.undef_alloc_func();
                class
            }),
            RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
                RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
                    PreTokenizerWrapper::BertPreTokenizer(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("BertPreTokenizer").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::ByteLevel(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("ByteLevel").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
                        let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
                        class.undef_alloc_func();
                        class
                    }),
                    // Variants added upstream (e.g. Sequence inside tk)
                    // are not yet mapped to a Ruby class.
                    _ => todo!(),
                },
            },
        }
    }
}
|
426
|
+
|
427
|
+
/// Registers the Tokenizers::PreTokenizers class hierarchy.
///
/// Classes whose Ruby wrapper does extra argument handling bind `_new`
/// (the Ruby layer defines `new`); the rest bind `new` directly.
pub fn pre_tokenizers(module: &RModule) -> RbResult<()> {
    let pre_tokenizer = module.define_class("PreTokenizer", Default::default())?;
    pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;

    let class = module.define_class("Sequence", pre_tokenizer)?;
    class.define_singleton_method("new", function!(RbSequence::new, 1))?;

    let class = module.define_class("BertPreTokenizer", pre_tokenizer)?;
    class.define_singleton_method("new", function!(RbBertPreTokenizer::new, 0))?;

    let class = module.define_class("ByteLevel", pre_tokenizer)?;
    class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
    class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
    class.define_method("add_prefix_space", method!(RbPreTokenizer::byte_level_add_prefix_space, 0))?;
    class.define_method("add_prefix_space=", method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1))?;
    class.define_method("use_regex", method!(RbPreTokenizer::byte_level_use_regex, 0))?;
    class.define_method("use_regex=", method!(RbPreTokenizer::byte_level_set_use_regex, 1))?;

    let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
    class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
    class.define_method("delimiter", method!(RbPreTokenizer::char_delimiter_split_delimiter, 0))?;
    class.define_method("delimiter=", method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1))?;

    let class = module.define_class("Digits", pre_tokenizer)?;
    class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
    class.define_method("individual_digits", method!(RbPreTokenizer::digits_individual_digits, 0))?;
    class.define_method("individual_digits=", method!(RbPreTokenizer::digits_set_individual_digits, 1))?;

    let class = module.define_class("Metaspace", pre_tokenizer)?;
    class.define_singleton_method("_new", function!(RbMetaspace::new, 2))?;
    class.define_method("add_prefix_space", method!(RbPreTokenizer::metaspace_add_prefix_space, 0))?;
    class.define_method("add_prefix_space=", method!(RbPreTokenizer::metaspace_set_add_prefix_space, 1))?;
    class.define_method("replacement", method!(RbPreTokenizer::metaspace_replacement, 0))?;
    class.define_method("replacement=", method!(RbPreTokenizer::metaspace_set_replacement, 1))?;

    let class = module.define_class("Punctuation", pre_tokenizer)?;
    class.define_singleton_method("_new", function!(RbPunctuation::new, 1))?;

    let class = module.define_class("Split", pre_tokenizer)?;
    class.define_singleton_method("_new", function!(RbSplit::new, 3))?;

    let class = module.define_class("UnicodeScripts", pre_tokenizer)?;
    class.define_singleton_method("new", function!(RbUnicodeScripts::new, 0))?;

    let class = module.define_class("Whitespace", pre_tokenizer)?;
    class.define_singleton_method("new", function!(RbWhitespace::new, 0))?;

    let class = module.define_class("WhitespaceSplit", pre_tokenizer)?;
    class.define_singleton_method("new", function!(RbWhitespaceSplit::new, 0))?;

    Ok(())
}
|
@@ -0,0 +1,210 @@
|
|
1
|
+
use std::sync::Arc;

use magnus::typed_data::DataTypeBuilder;
use magnus::{
    exception, function, memoize, Class, DataType, DataTypeFunctions, Error, Module, Object,
    RClass, RModule, TryConvert, TypedData, Value,
};
use serde::{Deserialize, Serialize};
use tk::processors::bert::BertProcessing;
use tk::processors::byte_level::ByteLevel;
use tk::processors::roberta::RobertaProcessing;
use tk::processors::template::{SpecialToken, Template};
use tk::processors::PostProcessorWrapper;
use tk::{Encoding, PostProcessor};

use super::RbResult;
|
17
|
+
|
18
|
+
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
19
|
+
pub struct RbPostProcessor {
|
20
|
+
#[serde(flatten)]
|
21
|
+
pub processor: Arc<PostProcessorWrapper>,
|
22
|
+
}
|
23
|
+
|
24
|
+
impl RbPostProcessor {
|
25
|
+
pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
|
26
|
+
RbPostProcessor { processor }
|
27
|
+
}
|
28
|
+
}
|
29
|
+
|
30
|
+
impl PostProcessor for RbPostProcessor {
|
31
|
+
fn added_tokens(&self, is_pair: bool) -> usize {
|
32
|
+
self.processor.added_tokens(is_pair)
|
33
|
+
}
|
34
|
+
|
35
|
+
fn process_encodings(
|
36
|
+
&self,
|
37
|
+
encodings: Vec<Encoding>,
|
38
|
+
add_special_tokens: bool,
|
39
|
+
) -> tk::Result<Vec<Encoding>> {
|
40
|
+
self.processor
|
41
|
+
.process_encodings(encodings, add_special_tokens)
|
42
|
+
}
|
43
|
+
}
|
44
|
+
|
45
|
+
#[derive(Clone, Debug)]
|
46
|
+
pub struct RbSpecialToken(SpecialToken);
|
47
|
+
|
48
|
+
impl From<RbSpecialToken> for SpecialToken {
|
49
|
+
fn from(v: RbSpecialToken) -> Self {
|
50
|
+
v.0
|
51
|
+
}
|
52
|
+
}
|
53
|
+
|
54
|
+
impl TryConvert for RbSpecialToken {
|
55
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
56
|
+
if let Ok(v) = ob.try_convert::<(String, u32)>() {
|
57
|
+
Ok(Self(v.into()))
|
58
|
+
} else if let Ok(v) = ob.try_convert::<(u32, String)>() {
|
59
|
+
Ok(Self(v.into()))
|
60
|
+
} else {
|
61
|
+
todo!()
|
62
|
+
}
|
63
|
+
}
|
64
|
+
}
|
65
|
+
|
66
|
+
#[derive(Clone, Debug)]
|
67
|
+
pub struct RbTemplate(Template);
|
68
|
+
|
69
|
+
impl From<RbTemplate> for Template {
|
70
|
+
fn from(v: RbTemplate) -> Self {
|
71
|
+
v.0
|
72
|
+
}
|
73
|
+
}
|
74
|
+
|
75
|
+
impl TryConvert for RbTemplate {
|
76
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
77
|
+
if let Ok(s) = ob.try_convert::<String>() {
|
78
|
+
Ok(Self(
|
79
|
+
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
80
|
+
))
|
81
|
+
} else if let Ok(s) = ob.try_convert::<Vec<String>>() {
|
82
|
+
Ok(Self(
|
83
|
+
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
84
|
+
))
|
85
|
+
} else {
|
86
|
+
todo!()
|
87
|
+
}
|
88
|
+
}
|
89
|
+
}
|
90
|
+
|
91
|
+
pub struct RbBertProcessing {}
|
92
|
+
|
93
|
+
impl RbBertProcessing {
|
94
|
+
pub fn new(sep: (String, u32), cls: (String, u32)) -> RbPostProcessor {
|
95
|
+
RbPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into()))
|
96
|
+
}
|
97
|
+
}
|
98
|
+
|
99
|
+
pub struct RbByteLevel {}
|
100
|
+
|
101
|
+
impl RbByteLevel {
|
102
|
+
pub fn new(trim_offsets: Option<bool>) -> RbPostProcessor {
|
103
|
+
let mut byte_level = ByteLevel::default();
|
104
|
+
|
105
|
+
if let Some(to) = trim_offsets {
|
106
|
+
byte_level = byte_level.trim_offsets(to);
|
107
|
+
}
|
108
|
+
RbPostProcessor::new(Arc::new(byte_level.into()))
|
109
|
+
}
|
110
|
+
|
111
|
+
}
|
112
|
+
|
113
|
+
pub struct RbRobertaProcessing {}
|
114
|
+
|
115
|
+
impl RbRobertaProcessing {
|
116
|
+
fn new(
|
117
|
+
sep: (String, u32),
|
118
|
+
cls: (String, u32),
|
119
|
+
trim_offsets: bool,
|
120
|
+
add_prefix_space: bool,
|
121
|
+
) -> RbPostProcessor {
|
122
|
+
let proc = RobertaProcessing::new(sep, cls)
|
123
|
+
.trim_offsets(trim_offsets)
|
124
|
+
.add_prefix_space(add_prefix_space);
|
125
|
+
RbPostProcessor::new(Arc::new(proc.into()))
|
126
|
+
}
|
127
|
+
}
|
128
|
+
|
129
|
+
pub struct RbTemplateProcessing {}
|
130
|
+
|
131
|
+
impl RbTemplateProcessing {
|
132
|
+
pub fn new(
|
133
|
+
single: Option<RbTemplate>,
|
134
|
+
pair: Option<RbTemplate>,
|
135
|
+
special_tokens: Option<Vec<(String, u32)>>,
|
136
|
+
) -> RbResult<RbPostProcessor> {
|
137
|
+
let mut builder = tk::processors::template::TemplateProcessing::builder();
|
138
|
+
|
139
|
+
if let Some(seq) = single {
|
140
|
+
builder.single(seq.into());
|
141
|
+
}
|
142
|
+
if let Some(seq) = pair {
|
143
|
+
builder.pair(seq.into());
|
144
|
+
}
|
145
|
+
if let Some(sp) = special_tokens {
|
146
|
+
builder.special_tokens(sp);
|
147
|
+
}
|
148
|
+
let processor = builder.build().unwrap(); //.map_err(RbError::from)?;
|
149
|
+
|
150
|
+
Ok(RbPostProcessor::new(Arc::new(processor.into())))
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
// SAFETY: mirrors the magnus-derived TypedData contract; `class_for`
// returns a subclass of `class()` so Ruby objects are tagged with the
// concrete processor class while sharing one DataType.
unsafe impl TypedData for RbPostProcessor {
    fn class() -> RClass {
        // Cached base class; allocation is undefined because instances
        // are only created from Rust, never via `allocate`.
        *memoize!(RClass: {
            let class: RClass = crate::processors().const_get("PostProcessor").unwrap();
            class.undef_alloc_func();
            class
        })
    }

    fn data_type() -> &'static DataType {
        memoize!(DataType: DataTypeBuilder::<RbPostProcessor>::new("Tokenizers::Processors::PostProcessor").build())
    }

    // Maps the wrapped variant to its Ruby subclass; one cached RClass
    // per variant (each memoize! call site has its own static).
    fn class_for(value: &Self) -> RClass {
        match *value.processor {
            PostProcessorWrapper::Bert(_) => *memoize!(RClass: {
                let class: RClass = crate::processors().const_get("BertProcessing").unwrap();
                class.undef_alloc_func();
                class
            }),
            PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
                let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
                class.undef_alloc_func();
                class
            }),
            PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
                let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
                class.undef_alloc_func();
                class
            }),
            PostProcessorWrapper::Template(_) => *memoize!(RClass: {
                let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
                class.undef_alloc_func();
                class
            }),
            // Variants added upstream are not yet mapped to a Ruby class.
            _ => todo!(),
        }
    }
}
|
193
|
+
|
194
|
+
/// Registers the Tokenizers::Processors class hierarchy.
///
/// Classes whose Ruby wrapper does extra argument handling bind `_new`
/// (the Ruby layer defines `new`); BertProcessing binds `new` directly.
pub fn processors(module: &RModule) -> RbResult<()> {
    let post_processor = module.define_class("PostProcessor", Default::default())?;

    let class = module.define_class("BertProcessing", post_processor)?;
    class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;

    let class = module.define_class("ByteLevel", post_processor)?;
    class.define_singleton_method("_new", function!(RbByteLevel::new, 1))?;

    let class = module.define_class("RobertaProcessing", post_processor)?;
    class.define_singleton_method("_new", function!(RbRobertaProcessing::new, 4))?;

    let class = module.define_class("TemplateProcessing", post_processor)?;
    class.define_singleton_method("_new", function!(RbTemplateProcessing::new, 3))?;

    Ok(())
}
|