tokenizers 0.5.2 → 0.5.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/Cargo.lock +154 -83
- data/ext/tokenizers/Cargo.toml +2 -2
- data/ext/tokenizers/src/decoders.rs +32 -14
- data/ext/tokenizers/src/error.rs +6 -1
- data/ext/tokenizers/src/lib.rs +47 -12
- data/ext/tokenizers/src/models.rs +75 -23
- data/ext/tokenizers/src/normalizers.rs +84 -24
- data/ext/tokenizers/src/pre_tokenizers.rs +121 -42
- data/ext/tokenizers/src/processors.rs +22 -10
- data/ext/tokenizers/src/tokenizer.rs +141 -39
- data/ext/tokenizers/src/trainers.rs +215 -56
- data/ext/tokenizers/src/utils/regex.rs +6 -4
- data/lib/tokenizers/added_token.rb +7 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +1 -0
- metadata +4 -7
data/ext/tokenizers/src/pre_tokenizers.rs

@@ -1,8 +1,8 @@
 use std::sync::{Arc, RwLock};

 use magnus::{
-    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
-    RArray, RClass, RModule, Ruby, TryConvert, TypedData,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
+    DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
 };

 use serde::ser::SerializeStruct;
@@ -22,7 +22,7 @@ use tk::tokenizer::Offsets;
 use tk::{PreTokenizedString, PreTokenizer};

 use super::utils::*;
-use super::{
+use super::{RbError, RbResult, PRE_TOKENIZERS};

 #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
 pub struct RbPreTokenizer {
@@ -34,7 +34,9 @@ impl RbPreTokenizer {
     fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
         let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);

-        self.pretok
+        self.pretok
+            .pre_tokenize(&mut pretokenized)
+            .map_err(RbError::from)?;

         Ok(pretokenized
             .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
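The hunk above shows the shape of pre_tokenize_str: wrap the input in a PreTokenizedString, run the pre-tokenizer over it, and read the splits back with character offsets. A minimal sketch of the same pattern written directly against the upstream tokenizers crate (which this extension imports as tk) follows; the free-standing helper name is illustrative, not part of the gem.

// Sketch only: mirrors RbPreTokenizer::pre_tokenize_str using the upstream crate.
use tokenizers::tokenizer::{Offsets, PreTokenizedString};
use tokenizers::{OffsetReferential, OffsetType, PreTokenizer, Result};

fn pre_tokenize_str(pretok: &impl PreTokenizer, s: String) -> Result<Vec<(String, Offsets)>> {
    // Wrap the input and let the pre-tokenizer split it in place.
    let mut pretokenized = PreTokenizedString::from(s);
    pretok.pre_tokenize(&mut pretokenized)?;

    // Read the splits back as (substring, character-offset) pairs.
    Ok(pretokenized
        .get_splits(OffsetReferential::Original, OffsetType::Char)
        .into_iter()
        .map(|(piece, offsets, _)| (piece.to_string(), offsets))
        .collect())
}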
@@ -195,11 +197,7 @@ impl RbDigits {
 pub struct RbMetaspace {}

 impl RbMetaspace {
-    fn new(
-        replacement: char,
-        prepend_scheme: String,
-        split: bool,
-    ) -> RbResult<RbPreTokenizer> {
+    fn new(replacement: char, prepend_scheme: String, split: bool) -> RbResult<RbPreTokenizer> {
         let prepend_scheme = from_string(prepend_scheme)?;
         Ok(Metaspace::new(replacement, prepend_scheme, split).into())
     }
@@ -216,8 +214,14 @@ impl RbPunctuation {
 pub struct RbSplit {}

 impl RbSplit {
-    pub fn new(
-…
+    pub fn new(
+        pattern: RbPattern,
+        behavior: RbSplitDelimiterBehavior,
+        invert: bool,
+    ) -> RbResult<RbPreTokenizer> {
+        Split::new(pattern, behavior.into(), invert)
+            .map(|v| v.into())
+            .map_err(RbError::from)
     }
 }

@@ -258,16 +262,18 @@ pub struct RbSequence {}
 impl RbSequence {
     fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
         let mut sequence = Vec::with_capacity(pre_tokenizers.len());
-        for n in pre_tokenizers
+        for n in pre_tokenizers {
             let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n)?;
             match &pretokenizer.pretok {
                 RbPreTokenizerTypeWrapper::Sequence(inner) => {
-                    sequence.extend(inner.iter().cloned())
+                    sequence.extend(inner.iter().cloned());
                 }
                 RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
-        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+            sequence,
+        )))
     }
 }

@@ -277,10 +283,13 @@ pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
         "never" => PrependScheme::Never,
         "always" => PrependScheme::Always,
         _ => {
-            return Err(Error::new(
-…
+            return Err(Error::new(
+                exception::arg_error(),
+                format!(
+                    "{} is an unknown variant, should be one of ['first', 'never', 'always']",
+                    string
+                ),
+            ));
         }
     };
     Ok(scheme)
@@ -381,7 +390,10 @@ impl PreTokenizer for RbPreTokenizerWrapper {
 unsafe impl TypedData for RbPreTokenizer {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("PreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -389,28 +401,41 @@ unsafe impl TypedData for RbPreTokenizer {
     }

     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType =
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
         &DATA_TYPE
     }

     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Sequence")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("BertPreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("ByteLevel")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("CharDelimiterSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -420,12 +445,18 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Metaspace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Punctuation")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -435,17 +466,26 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("UnicodeScripts")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Whitespace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("WhitespaceSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -472,7 +512,10 @@ unsafe impl TypedData for RbPreTokenizer {

 pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
-    pre_tokenizer.define_method(
+    pre_tokenizer.define_method(
+        "pre_tokenize_str",
+        method!(RbPreTokenizer::pre_tokenize_str, 1),
+    )?;

     let class = module.define_class("Sequence", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbSequence::new, 1))?;
@@ -483,27 +526,63 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let class = module.define_class("ByteLevel", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
     class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
-    class.define_method(
-…
+    class.define_method(
+        "add_prefix_space",
+        method!(RbPreTokenizer::byte_level_add_prefix_space, 0),
+    )?;
+    class.define_method(
+        "add_prefix_space=",
+        method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1),
+    )?;
+    class.define_method(
+        "use_regex",
+        method!(RbPreTokenizer::byte_level_use_regex, 0),
+    )?;
+    class.define_method(
+        "use_regex=",
+        method!(RbPreTokenizer::byte_level_set_use_regex, 1),
+    )?;

     let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
-    class.define_method(
-…
+    class.define_method(
+        "delimiter",
+        method!(RbPreTokenizer::char_delimiter_split_delimiter, 0),
+    )?;
+    class.define_method(
+        "delimiter=",
+        method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1),
+    )?;

     let class = module.define_class("Digits", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
-    class.define_method(
-…
+    class.define_method(
+        "individual_digits",
+        method!(RbPreTokenizer::digits_individual_digits, 0),
+    )?;
+    class.define_method(
+        "individual_digits=",
+        method!(RbPreTokenizer::digits_set_individual_digits, 1),
+    )?;

     let class = module.define_class("Metaspace", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
-    class.define_method(
-…
+    class.define_method(
+        "prepend_scheme",
+        method!(RbPreTokenizer::metaspace_prepend_scheme, 0),
+    )?;
+    class.define_method(
+        "prepend_scheme=",
+        method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1),
+    )?;
+    class.define_method(
+        "replacement",
+        method!(RbPreTokenizer::metaspace_replacement, 0),
+    )?;
+    class.define_method(
+        "replacement=",
+        method!(RbPreTokenizer::metaspace_set_replacement, 1),
+    )?;
     class.define_method("split", method!(RbPreTokenizer::metaspace_split, 0))?;
     class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;

data/ext/tokenizers/src/processors.rs

@@ -1,8 +1,8 @@
 use std::sync::Arc;

 use magnus::{
-    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
-    Ruby, TryConvert, TypedData, Value,
+    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
+    RClass, RModule, Ruby, TryConvert, TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::processors::bert::BertProcessing;
@@ -12,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
 use tk::processors::PostProcessorWrapper;
 use tk::{Encoding, PostProcessor};

-use super::{
+use super::{RbResult, PROCESSORS};

 #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
 pub struct RbPostProcessor {
@@ -106,7 +106,6 @@ impl RbByteLevel {
         }
         RbPostProcessor::new(Arc::new(byte_level.into()))
     }
-
 }

 pub struct RbRobertaProcessing {}
@@ -117,7 +116,7 @@ impl RbRobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) ->
+    ) -> RbPostProcessor {
         let proc = RobertaProcessing::new(sep, cls)
             .trim_offsets(trim_offsets)
             .add_prefix_space(add_prefix_space);
@@ -153,7 +152,10 @@ impl RbTemplateProcessing {
 unsafe impl TypedData for RbPostProcessor {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("PostProcessor")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -161,13 +163,17 @@ unsafe impl TypedData for RbPostProcessor {
     }

     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType =
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
         &DATA_TYPE
     }

     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("BertProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -177,12 +183,18 @@ unsafe impl TypedData for RbPostProcessor {
             class
         });
         static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("RobertaProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("TemplateProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
data/ext/tokenizers/src/tokenizer.rs

@@ -6,8 +6,8 @@ use std::str::FromStr;
 use magnus::prelude::*;
 use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
 use tk::tokenizer::{
-    Model, PaddingDirection, PaddingParams, PaddingStrategy,
-…
+    Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
+    TruncationParams, TruncationStrategy,
 };
 use tk::AddedToken;

@@ -22,9 +22,10 @@ use super::processors::RbPostProcessor;
 use super::trainers::RbTrainer;
 use super::{RbError, RbResult};

+#[magnus::wrap(class = "Tokenizers::AddedToken")]
 pub struct RbAddedToken {
     pub content: String,
-    pub
+    pub special: bool,
     pub single_word: Option<bool>,
     pub lstrip: Option<bool>,
     pub rstrip: Option<bool>,
@@ -32,10 +33,10 @@ pub struct RbAddedToken {
 }

 impl RbAddedToken {
-    pub fn from<S: Into<String>>(content: S,
+    pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
         Self {
             content: content.into(),
-…
+            special: special.unwrap_or(false),
             single_word: None,
             lstrip: None,
             rstrip: None,
@@ -44,7 +45,7 @@ impl RbAddedToken {
     }

     pub fn get_token(&self) -> tk::tokenizer::AddedToken {
-        let mut token = tk::AddedToken::from(&self.content, self.
+        let mut token = tk::AddedToken::from(&self.content, self.special);

         if let Some(sw) = self.single_word {
             token = token.single_word(sw);
@@ -71,11 +72,73 @@ impl From<tk::AddedToken> for RbAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-…
+            special: !token.normalized,
         }
     }
 }

+impl RbAddedToken {
+    pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
+        let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
+
+        let value: Value = kwargs.delete(Symbol::new("single_word"))?;
+        if !value.is_nil() {
+            token.single_word = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("lstrip"))?;
+        if !value.is_nil() {
+            token.lstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("rstrip"))?;
+        if !value.is_nil() {
+            token.rstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("normalized"))?;
+        if !value.is_nil() {
+            token.normalized = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("special"))?;
+        if !value.is_nil() {
+            token.special = TryConvert::try_convert(value)?;
+        }
+
+        if !kwargs.is_empty() {
+            // TODO improve message
+            return Err(Error::new(exception::arg_error(), "unknown keyword"));
+        }
+
+        Ok(token)
+    }
+
+    pub fn get_content(&self) -> String {
+        self.content.to_string()
+    }
+
+    pub fn get_rstrip(&self) -> bool {
+        self.get_token().rstrip
+    }
+
+    pub fn get_lstrip(&self) -> bool {
+        self.get_token().lstrip
+    }
+
+    pub fn get_single_word(&self) -> bool {
+        self.get_token().single_word
+    }
+
+    pub fn get_normalized(&self) -> bool {
+        self.get_token().normalized
+    }
+
+    pub fn get_special(&self) -> bool {
+        self.get_token().special
+    }
+}
+
 struct TextInputSequence<'s>(tk::InputSequence<'s>);

 impl<'s> TryConvert for TextInputSequence<'s> {
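The new RbAddedToken wrapper above ultimately builds an upstream tk::AddedToken via get_token. A small sketch of that upstream builder follows; the token content and flag values are illustrative, not taken from the gem.

// Sketch only: the upstream builder that RbAddedToken::get_token forwards to.
use tokenizers::AddedToken;

fn example_added_token() -> AddedToken {
    AddedToken::from("<user>", true) // second argument marks the token as special
        .single_word(true)
        .lstrip(false)
        .rstrip(false)
        .normalized(false)
}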
@@ -221,7 +284,10 @@ impl RbTokenizer {
     }

     pub fn to_str(&self, pretty: bool) -> RbResult<String> {
-        self.tokenizer
+        self.tokenizer
+            .borrow()
+            .to_string(pretty)
+            .map_err(RbError::from)
     }

     pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
@@ -320,7 +386,11 @@ impl RbTokenizer {
             .map_err(RbError::from)
     }

-    pub fn decode_batch(
+    pub fn decode_batch(
+        &self,
+        sequences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> RbResult<Vec<String>> {
         let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         self.tokenizer
             .borrow()
@@ -392,7 +462,12 @@ impl RbTokenizer {
         params.direction = match dir_str.as_str() {
             "left" => PaddingDirection::Left,
             "right" => PaddingDirection::Right,
-            _ =>
+            _ => {
+                return Err(Error::new(
+                    exception::arg_error(),
+                    "The direction value must be 'left' or 'right'",
+                ))
+            }
         }
     }

@@ -438,24 +513,27 @@ impl RbTokenizer {
     }

     pub fn padding(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer
-…
+        self.tokenizer
+            .borrow()
+            .get_padding()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
+
+                ret_hash.aset(
+                    "length",
+                    match params.strategy {
+                        tk::PaddingStrategy::BatchLongest => None,
+                        tk::PaddingStrategy::Fixed(size) => Some(size),
+                    },
+                )?;
+                ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
+                ret_hash.aset("pad_id", params.pad_id)?;
+                ret_hash.aset("pad_token", &*params.pad_token)?;
+                ret_hash.aset("pad_type_id", params.pad_type_id)?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
+
+                Ok(Some(ret_hash))
+            })
     }

     pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
@@ -476,7 +554,10 @@ impl RbTokenizer {
             "longest_first" => TruncationStrategy::LongestFirst,
             "only_first" => TruncationStrategy::OnlyFirst,
             "only_second" => TruncationStrategy::OnlySecond,
-            _ => return Err(Error::new(
+            _ => return Err(Error::new(
+                exception::arg_error(),
+                "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
+            )),
         }
     }

@@ -486,7 +567,12 @@ impl RbTokenizer {
         params.direction = match dir_str.as_str() {
             "left" => TruncationDirection::Left,
             "right" => TruncationDirection::Right,
-            _ =>
+            _ => {
+                return Err(Error::new(
+                    exception::arg_error(),
+                    "The direction value must be 'left' or 'right'",
+                ))
+            }
         }
     }

@@ -496,7 +582,10 @@ impl RbTokenizer {
         }

         if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
-            return Err(Error::new(
+            return Err(Error::new(
+                exception::arg_error(),
+                error_message.to_string(),
+            ));
         }

         Ok(())
@@ -510,16 +599,19 @@ impl RbTokenizer {
     }

     pub fn truncation(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer
-…
+        self.tokenizer
+            .borrow()
+            .get_truncation()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();

-…
+                ret_hash.aset("max_length", params.max_length)?;
+                ret_hash.aset("stride", params.stride)?;
+                ret_hash.aset("strategy", params.strategy.as_ref())?;
+                ret_hash.aset("direction", params.direction.as_ref())?;

-…
+                Ok(Some(ret_hash))
+            })
     }

     pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
@@ -536,4 +628,14 @@ impl RbTokenizer {
     pub fn vocab_size(&self, with_added_tokens: bool) -> usize {
         self.tokenizer.borrow().get_vocab_size(with_added_tokens)
     }
+
+    pub fn get_added_tokens_decoder(&self) -> RbResult<RHash> {
+        let sorted_map = RHash::new();
+
+        for (key, value) in self.tokenizer.borrow().get_added_tokens_decoder() {
+            sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
+        }
+
+        Ok(sorted_map)
+    }
 }
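The final hunk adds get_added_tokens_decoder, which copies the upstream id-to-AddedToken map into a Ruby Hash keyed by token id. A hedged sketch of consuming that same upstream call directly follows; the helper name is illustrative and the exact map type returned by the crate is assumed, not confirmed by this diff.

// Sketch only: collect the ids of special added tokens from the upstream map.
use tokenizers::Tokenizer;

fn special_added_ids(tokenizer: &Tokenizer) -> Vec<u32> {
    let mut ids: Vec<u32> = tokenizer
        .get_added_tokens_decoder()
        .into_iter()
        .filter(|(_, token)| token.special)
        .map(|(id, _)| id)
        .collect();
    ids.sort_unstable();
    ids
}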