tokenizers 0.5.2 → 0.5.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/Cargo.lock +154 -83
- data/ext/tokenizers/Cargo.toml +2 -2
- data/ext/tokenizers/src/decoders.rs +32 -14
- data/ext/tokenizers/src/error.rs +6 -1
- data/ext/tokenizers/src/lib.rs +47 -12
- data/ext/tokenizers/src/models.rs +75 -23
- data/ext/tokenizers/src/normalizers.rs +84 -24
- data/ext/tokenizers/src/pre_tokenizers.rs +121 -42
- data/ext/tokenizers/src/processors.rs +22 -10
- data/ext/tokenizers/src/tokenizer.rs +141 -39
- data/ext/tokenizers/src/trainers.rs +215 -56
- data/ext/tokenizers/src/utils/regex.rs +6 -4
- data/lib/tokenizers/added_token.rb +7 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +1 -0
- metadata +4 -7
data/ext/tokenizers/src/pre_tokenizers.rs

@@ -1,8 +1,8 @@
 use std::sync::{Arc, RwLock};
 
 use magnus::{
-    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
-    RArray, RClass, RModule, Ruby, TryConvert, TypedData,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
+    DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
 };
 
 use serde::ser::SerializeStruct;
@@ -22,7 +22,7 @@ use tk::tokenizer::Offsets;
 use tk::{PreTokenizedString, PreTokenizer};
 
 use super::utils::*;
-use super::{
+use super::{RbError, RbResult, PRE_TOKENIZERS};
 
 #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
 pub struct RbPreTokenizer {
@@ -34,7 +34,9 @@ impl RbPreTokenizer {
     fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
         let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
 
-        self.pretok
+        self.pretok
+            .pre_tokenize(&mut pretokenized)
+            .map_err(RbError::from)?;
 
         Ok(pretokenized
             .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
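
The reworked `pre_tokenize_str` binding propagates pre-tokenizer errors through `RbError` instead of panicking, and still returns each split with its character offsets. A hypothetical Ruby session against one of the pre-tokenizer classes registered below (`Whitespace` here; the `[piece, [start, stop]]` shape follows the binding's `Vec<(String, Offsets)>` return type):

```ruby
require "tokenizers"

pre_tok = Tokenizers::PreTokenizers::Whitespace.new
# Each element pairs a split piece with its character offsets
# into the original string.
pre_tok.pre_tokenize_str("Hello world")
# => [["Hello", [0, 5]], ["world", [6, 11]]]
```
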
@@ -195,11 +197,7 @@ impl RbDigits {
 pub struct RbMetaspace {}
 
 impl RbMetaspace {
-    fn new(
-        replacement: char,
-        prepend_scheme: String,
-        split: bool,
-    ) -> RbResult<RbPreTokenizer> {
+    fn new(replacement: char, prepend_scheme: String, split: bool) -> RbResult<RbPreTokenizer> {
         let prepend_scheme = from_string(prepend_scheme)?;
         Ok(Metaspace::new(replacement, prepend_scheme, split).into())
     }
@@ -216,8 +214,14 @@ impl RbPunctuation {
 pub struct RbSplit {}
 
 impl RbSplit {
-    pub fn new(
-
+    pub fn new(
+        pattern: RbPattern,
+        behavior: RbSplitDelimiterBehavior,
+        invert: bool,
+    ) -> RbResult<RbPreTokenizer> {
+        Split::new(pattern, behavior.into(), invert)
+            .map(|v| v.into())
+            .map_err(RbError::from)
     }
 }
 
@@ -258,16 +262,18 @@ pub struct RbSequence {}
 impl RbSequence {
     fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
         let mut sequence = Vec::with_capacity(pre_tokenizers.len());
-        for n in pre_tokenizers
+        for n in pre_tokenizers {
            let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n)?;
            match &pretokenizer.pretok {
                RbPreTokenizerTypeWrapper::Sequence(inner) => {
-                    sequence.extend(inner.iter().cloned())
+                    sequence.extend(inner.iter().cloned());
                }
                RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
            }
        }
-        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+            sequence,
+        )))
     }
 }
 
@@ -277,10 +283,13 @@ pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
         "never" => PrependScheme::Never,
         "always" => PrependScheme::Always,
         _ => {
-            return Err(Error::new(
-
-
-
+            return Err(Error::new(
+                exception::arg_error(),
+                format!(
+                    "{} is an unknown variant, should be one of ['first', 'never', 'always']",
+                    string
+                ),
+            ));
        }
    };
    Ok(scheme)
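
`from_string` validates the `prepend_scheme` string that `RbMetaspace::new` forwards to it. Assuming the Ruby-side `Metaspace` constructor (defined in the gem's Ruby layer, not shown in this diff) passes the keyword straight through, an invalid value would surface like this:

```ruby
# Hypothetical: the prepend_scheme: keyword on the Ruby constructor
# is an assumption; the error text comes from from_string above.
Tokenizers::PreTokenizers::Metaspace.new(prepend_scheme: "sometimes")
# raises ArgumentError:
#   sometimes is an unknown variant, should be one of ['first', 'never', 'always']
```
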
@@ -381,7 +390,10 @@ impl PreTokenizer for RbPreTokenizerWrapper {
 unsafe impl TypedData for RbPreTokenizer {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("PreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -389,28 +401,41 @@ unsafe impl TypedData for RbPreTokenizer {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType =
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Sequence")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("BertPreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("ByteLevel")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("CharDelimiterSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -420,12 +445,18 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Metaspace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Punctuation")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -435,17 +466,26 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("UnicodeScripts")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Whitespace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("WhitespaceSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -472,7 +512,10 @@ unsafe impl TypedData for RbPreTokenizer {
 
 pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
-    pre_tokenizer.define_method(
+    pre_tokenizer.define_method(
+        "pre_tokenize_str",
+        method!(RbPreTokenizer::pre_tokenize_str, 1),
+    )?;
 
     let class = module.define_class("Sequence", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbSequence::new, 1))?;
@@ -483,27 +526,63 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let class = module.define_class("ByteLevel", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
     class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
-    class.define_method(
-
-
-
+    class.define_method(
+        "add_prefix_space",
+        method!(RbPreTokenizer::byte_level_add_prefix_space, 0),
+    )?;
+    class.define_method(
+        "add_prefix_space=",
+        method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1),
+    )?;
+    class.define_method(
+        "use_regex",
+        method!(RbPreTokenizer::byte_level_use_regex, 0),
+    )?;
+    class.define_method(
+        "use_regex=",
+        method!(RbPreTokenizer::byte_level_set_use_regex, 1),
+    )?;
 
     let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
-    class.define_method(
-
+    class.define_method(
+        "delimiter",
+        method!(RbPreTokenizer::char_delimiter_split_delimiter, 0),
+    )?;
+    class.define_method(
+        "delimiter=",
+        method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1),
+    )?;
 
     let class = module.define_class("Digits", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
-    class.define_method(
-
+    class.define_method(
+        "individual_digits",
+        method!(RbPreTokenizer::digits_individual_digits, 0),
+    )?;
+    class.define_method(
+        "individual_digits=",
+        method!(RbPreTokenizer::digits_set_individual_digits, 1),
+    )?;
 
     let class = module.define_class("Metaspace", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
-    class.define_method(
-
-
-
+    class.define_method(
+        "prepend_scheme",
+        method!(RbPreTokenizer::metaspace_prepend_scheme, 0),
+    )?;
+    class.define_method(
+        "prepend_scheme=",
+        method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1),
+    )?;
+    class.define_method(
+        "replacement",
+        method!(RbPreTokenizer::metaspace_replacement, 0),
+    )?;
+    class.define_method(
+        "replacement=",
+        method!(RbPreTokenizer::metaspace_set_replacement, 1),
+    )?;
     class.define_method("split", method!(RbPreTokenizer::metaspace_split, 0))?;
     class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;
 
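
Each reflowed `define_method` pair registers a getter and a setter backed by the corresponding `RbPreTokenizer` accessor. A sketch of the resulting Ruby surface; the constructor defaults shown are assumptions, not visible in this diff:

```ruby
digits = Tokenizers::PreTokenizers::Digits.new
digits.individual_digits         # => false (assumed default)
digits.individual_digits = true  # backed by digits_set_individual_digits

metaspace = Tokenizers::PreTokenizers::Metaspace.new
metaspace.replacement            # => "▁" (assumed default)
metaspace.prepend_scheme         # => "always" (assumed default)
metaspace.split = false          # backed by metaspace_set_split
```
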
data/ext/tokenizers/src/processors.rs

@@ -1,8 +1,8 @@
 use std::sync::Arc;
 
 use magnus::{
-    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
-    Ruby, TryConvert, TypedData, Value,
+    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
+    RClass, RModule, Ruby, TryConvert, TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::processors::bert::BertProcessing;
@@ -12,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
 use tk::processors::PostProcessorWrapper;
 use tk::{Encoding, PostProcessor};
 
-use super::{
+use super::{RbResult, PROCESSORS};
 
 #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
 pub struct RbPostProcessor {
@@ -106,7 +106,6 @@ impl RbByteLevel {
         }
         RbPostProcessor::new(Arc::new(byte_level.into()))
     }
-
 }
 
 pub struct RbRobertaProcessing {}
@@ -117,7 +116,7 @@ impl RbRobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) ->
+    ) -> RbPostProcessor {
         let proc = RobertaProcessing::new(sep, cls)
             .trim_offsets(trim_offsets)
             .add_prefix_space(add_prefix_space);
@@ -153,7 +152,10 @@ impl RbTemplateProcessing {
 unsafe impl TypedData for RbPostProcessor {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("PostProcessor")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -161,13 +163,17 @@ unsafe impl TypedData for RbPostProcessor {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType =
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("BertProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -177,12 +183,18 @@ unsafe impl TypedData for RbPostProcessor {
             class
         });
         static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("RobertaProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("TemplateProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
data/ext/tokenizers/src/tokenizer.rs

@@ -6,8 +6,8 @@ use std::str::FromStr;
 use magnus::prelude::*;
 use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
 use tk::tokenizer::{
-    Model, PaddingDirection, PaddingParams, PaddingStrategy,
-
+    Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
+    TruncationParams, TruncationStrategy,
 };
 use tk::AddedToken;
 
@@ -22,9 +22,10 @@ use super::processors::RbPostProcessor;
 use super::trainers::RbTrainer;
 use super::{RbError, RbResult};
 
+#[magnus::wrap(class = "Tokenizers::AddedToken")]
 pub struct RbAddedToken {
     pub content: String,
-    pub
+    pub special: bool,
     pub single_word: Option<bool>,
     pub lstrip: Option<bool>,
     pub rstrip: Option<bool>,
@@ -32,10 +33,10 @@ pub struct RbAddedToken {
 }
 
 impl RbAddedToken {
-    pub fn from<S: Into<String>>(content: S,
+    pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
         Self {
             content: content.into(),
-
+            special: special.unwrap_or(false),
             single_word: None,
             lstrip: None,
             rstrip: None,
@@ -44,7 +45,7 @@ impl RbAddedToken {
     }
 
     pub fn get_token(&self) -> tk::tokenizer::AddedToken {
-        let mut token = tk::AddedToken::from(&self.content, self.
+        let mut token = tk::AddedToken::from(&self.content, self.special);
 
         if let Some(sw) = self.single_word {
             token = token.single_word(sw);
@@ -71,11 +72,73 @@ impl From<tk::AddedToken> for RbAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-
+            special: !token.normalized,
         }
     }
 }
 
+impl RbAddedToken {
+    pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
+        let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
+
+        let value: Value = kwargs.delete(Symbol::new("single_word"))?;
+        if !value.is_nil() {
+            token.single_word = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("lstrip"))?;
+        if !value.is_nil() {
+            token.lstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("rstrip"))?;
+        if !value.is_nil() {
+            token.rstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("normalized"))?;
+        if !value.is_nil() {
+            token.normalized = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("special"))?;
+        if !value.is_nil() {
+            token.special = TryConvert::try_convert(value)?;
+        }
+
+        if !kwargs.is_empty() {
+            // TODO improve message
+            return Err(Error::new(exception::arg_error(), "unknown keyword"));
+        }
+
+        Ok(token)
+    }
+
+    pub fn get_content(&self) -> String {
+        self.content.to_string()
+    }
+
+    pub fn get_rstrip(&self) -> bool {
+        self.get_token().rstrip
+    }
+
+    pub fn get_lstrip(&self) -> bool {
+        self.get_token().lstrip
+    }
+
+    pub fn get_single_word(&self) -> bool {
+        self.get_token().single_word
+    }
+
+    pub fn get_normalized(&self) -> bool {
+        self.get_token().normalized
+    }
+
+    pub fn get_special(&self) -> bool {
+        self.get_token().special
+    }
+}
+
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 
 impl<'s> TryConvert for TextInputSequence<'s> {
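
With the `#[magnus::wrap(class = "Tokenizers::AddedToken")]` attribute and the kwargs-driven `new`, added tokens become first-class Ruby objects. A sketch of the resulting API, assuming the Ruby class maps the `get_*` bindings to plain attribute readers (that mapping presumably lives in the new `data/lib/tokenizers/added_token.rb`, not shown here):

```ruby
token = Tokenizers::AddedToken.new("<custom>", special: true, single_word: true)
token.content      # => "<custom>"
token.special      # => true
token.single_word  # => true
token.lstrip       # => false (default from RbAddedToken::from)

# Unrecognized keywords are rejected by the kwargs check above:
Tokenizers::AddedToken.new("<custom>", specail: true)
# raises ArgumentError: unknown keyword
```
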
@@ -221,7 +284,10 @@ impl RbTokenizer {
     }
 
     pub fn to_str(&self, pretty: bool) -> RbResult<String> {
-        self.tokenizer
+        self.tokenizer
+            .borrow()
+            .to_string(pretty)
+            .map_err(RbError::from)
     }
 
     pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
@@ -320,7 +386,11 @@ impl RbTokenizer {
             .map_err(RbError::from)
     }
 
-    pub fn decode_batch(
+    pub fn decode_batch(
+        &self,
+        sequences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> RbResult<Vec<String>> {
         let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         self.tokenizer
             .borrow()
|
|
392
462
|
params.direction = match dir_str.as_str() {
|
393
463
|
"left" => PaddingDirection::Left,
|
394
464
|
"right" => PaddingDirection::Right,
|
395
|
-
_ =>
|
465
|
+
_ => {
|
466
|
+
return Err(Error::new(
|
467
|
+
exception::arg_error(),
|
468
|
+
"The direction value must be 'left' or 'right'",
|
469
|
+
))
|
470
|
+
}
|
396
471
|
}
|
397
472
|
}
|
398
473
|
|
@@ -438,24 +513,27 @@ impl RbTokenizer {
|
|
438
513
|
}
|
439
514
|
|
440
515
|
pub fn padding(&self) -> RbResult<Option<RHash>> {
|
441
|
-
self.tokenizer
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
516
|
+
self.tokenizer
|
517
|
+
.borrow()
|
518
|
+
.get_padding()
|
519
|
+
.map_or(Ok(None), |params| {
|
520
|
+
let ret_hash = RHash::new();
|
521
|
+
|
522
|
+
ret_hash.aset(
|
523
|
+
"length",
|
524
|
+
match params.strategy {
|
525
|
+
tk::PaddingStrategy::BatchLongest => None,
|
526
|
+
tk::PaddingStrategy::Fixed(size) => Some(size),
|
527
|
+
},
|
528
|
+
)?;
|
529
|
+
ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
|
530
|
+
ret_hash.aset("pad_id", params.pad_id)?;
|
531
|
+
ret_hash.aset("pad_token", &*params.pad_token)?;
|
532
|
+
ret_hash.aset("pad_type_id", params.pad_type_id)?;
|
533
|
+
ret_hash.aset("direction", params.direction.as_ref())?;
|
534
|
+
|
535
|
+
Ok(Some(ret_hash))
|
536
|
+
})
|
459
537
|
}
|
460
538
|
|
461
539
|
pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
|
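
The rebuilt `padding` reader turns the active `PaddingParams` into a Ruby hash, with `"length"` set to `nil` when the strategy is `BatchLongest`, or returns `nil` when no padding is configured. The keys come from the `aset` calls above; the values here are illustrative:

```ruby
tokenizer.enable_padding(length: 128)
tokenizer.padding
# => {"length" => 128, "pad_to_multiple_of" => nil, "pad_id" => 0,
#     "pad_token" => "[PAD]", "pad_type_id" => 0, "direction" => "right"}
```
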
@@ -476,7 +554,10 @@ impl RbTokenizer {
             "longest_first" => TruncationStrategy::LongestFirst,
             "only_first" => TruncationStrategy::OnlyFirst,
             "only_second" => TruncationStrategy::OnlySecond,
-            _ => return Err(Error::new(
+            _ => return Err(Error::new(
+                exception::arg_error(),
+                "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
+            )),
         }
     }
 
@@ -486,7 +567,12 @@ impl RbTokenizer {
         params.direction = match dir_str.as_str() {
             "left" => TruncationDirection::Left,
             "right" => TruncationDirection::Right,
-            _ =>
+            _ => {
+                return Err(Error::new(
+                    exception::arg_error(),
+                    "The direction value must be 'left' or 'right'",
+                ))
+            }
         }
     }
 
@@ -496,7 +582,10 @@ impl RbTokenizer {
         }
 
         if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
-            return Err(Error::new(
+            return Err(Error::new(
+                exception::arg_error(),
+                error_message.to_string(),
+            ));
         }
 
         Ok(())
@@ -510,16 +599,19 @@ impl RbTokenizer {
     }
 
     pub fn truncation(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer
-
+        self.tokenizer
+            .borrow()
+            .get_truncation()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
 
-
-
-
-
+                ret_hash.aset("max_length", params.max_length)?;
+                ret_hash.aset("stride", params.stride)?;
+                ret_hash.aset("strategy", params.strategy.as_ref())?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
 
-
-
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
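
`truncation` gets the same treatment: `get_truncation` is mapped into a hash, or `nil` when truncation is disabled. The `enable_truncation(max_length, kwargs)` signature above suggests usage like this, with illustrative values:

```ruby
tokenizer.enable_truncation(512, strategy: "only_second")
tokenizer.truncation
# => {"max_length" => 512, "stride" => 0,
#     "strategy" => "only_second", "direction" => "right"}
```
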
@@ -536,4 +628,14 @@ impl RbTokenizer {
     pub fn vocab_size(&self, with_added_tokens: bool) -> usize {
         self.tokenizer.borrow().get_vocab_size(with_added_tokens)
     }
+
+    pub fn get_added_tokens_decoder(&self) -> RbResult<RHash> {
+        let sorted_map = RHash::new();
+
+        for (key, value) in self.tokenizer.borrow().get_added_tokens_decoder() {
+            sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
+        }
+
+        Ok(sorted_map)
+    }
 }
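
The new `get_added_tokens_decoder` binding walks the tokenizer's added-vocabulary map and converts each entry into the `RbAddedToken` wrapper, keyed by token id. A hypothetical call; the Ruby-side method name is an assumption, and the output shape follows the `u32 => RbAddedToken` conversion:

```ruby
tokenizer.added_tokens_decoder
# => {0 => #<Tokenizers::AddedToken content="[PAD]" special=true>,
#     100 => #<Tokenizers::AddedToken content="[UNK]" special=true>, ...}
```
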