tokenizers 0.5.3 → 0.5.4
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Cargo.lock +154 -83
- data/ext/tokenizers/Cargo.toml +2 -2
- data/ext/tokenizers/src/decoders.rs +32 -14
- data/ext/tokenizers/src/error.rs +6 -1
- data/ext/tokenizers/src/lib.rs +37 -12
- data/ext/tokenizers/src/models.rs +75 -23
- data/ext/tokenizers/src/normalizers.rs +84 -24
- data/ext/tokenizers/src/pre_tokenizers.rs +121 -42
- data/ext/tokenizers/src/processors.rs +22 -10
- data/ext/tokenizers/src/tokenizer.rs +63 -34
- data/ext/tokenizers/src/trainers.rs +215 -56
- data/ext/tokenizers/src/utils/regex.rs +6 -4
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/version.rb +1 -1
- metadata +3 -7
data/ext/tokenizers/src/pre_tokenizers.rs

```diff
@@ -1,8 +1,8 @@
 use std::sync::{Arc, RwLock};
 
 use magnus::{
-    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
-    RArray, RClass, RModule, Ruby, TryConvert, TypedData,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
+    DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
 };
 
 use serde::ser::SerializeStruct;
@@ -22,7 +22,7 @@ use tk::tokenizer::Offsets;
 use tk::{PreTokenizedString, PreTokenizer};
 
 use super::utils::*;
-use super::{
+use super::{RbError, RbResult, PRE_TOKENIZERS};
 
 #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
 pub struct RbPreTokenizer {
@@ -34,7 +34,9 @@ impl RbPreTokenizer {
     fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
         let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
 
-        self.pretok
+        self.pretok
+            .pre_tokenize(&mut pretokenized)
+            .map_err(RbError::from)?;
 
         Ok(pretokenized
             .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
@@ -195,11 +197,7 @@ impl RbDigits {
 pub struct RbMetaspace {}
 
 impl RbMetaspace {
-    fn new(
-        replacement: char,
-        prepend_scheme: String,
-        split: bool,
-    ) -> RbResult<RbPreTokenizer> {
+    fn new(replacement: char, prepend_scheme: String, split: bool) -> RbResult<RbPreTokenizer> {
         let prepend_scheme = from_string(prepend_scheme)?;
         Ok(Metaspace::new(replacement, prepend_scheme, split).into())
     }
@@ -216,8 +214,14 @@ impl RbPunctuation {
 pub struct RbSplit {}
 
 impl RbSplit {
-    pub fn new(
-
+    pub fn new(
+        pattern: RbPattern,
+        behavior: RbSplitDelimiterBehavior,
+        invert: bool,
+    ) -> RbResult<RbPreTokenizer> {
+        Split::new(pattern, behavior.into(), invert)
+            .map(|v| v.into())
+            .map_err(RbError::from)
     }
 }
 
@@ -258,16 +262,18 @@ pub struct RbSequence {}
 impl RbSequence {
     fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
         let mut sequence = Vec::with_capacity(pre_tokenizers.len());
-        for n in pre_tokenizers
+        for n in pre_tokenizers {
             let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n)?;
             match &pretokenizer.pretok {
                 RbPreTokenizerTypeWrapper::Sequence(inner) => {
-                    sequence.extend(inner.iter().cloned())
+                    sequence.extend(inner.iter().cloned());
                 }
                 RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
-        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+            sequence,
+        )))
     }
 }
 
@@ -277,10 +283,13 @@ pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
         "never" => PrependScheme::Never,
         "always" => PrependScheme::Always,
         _ => {
-            return Err(Error::new(
-
-
-
+            return Err(Error::new(
+                exception::arg_error(),
+                format!(
+                    "{} is an unknown variant, should be one of ['first', 'never', 'always']",
+                    string
+                ),
+            ));
         }
     };
     Ok(scheme)
@@ -381,7 +390,10 @@ impl PreTokenizer for RbPreTokenizerWrapper {
 unsafe impl TypedData for RbPreTokenizer {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("PreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -389,28 +401,41 @@ unsafe impl TypedData for RbPreTokenizer {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType =
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Sequence")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("BertPreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("ByteLevel")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("CharDelimiterSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -420,12 +445,18 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Metaspace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Punctuation")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -435,17 +466,26 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("UnicodeScripts")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Whitespace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("WhitespaceSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -472,7 +512,10 @@ unsafe impl TypedData for RbPreTokenizer {
 
 pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
-    pre_tokenizer.define_method(
+    pre_tokenizer.define_method(
+        "pre_tokenize_str",
+        method!(RbPreTokenizer::pre_tokenize_str, 1),
+    )?;
 
     let class = module.define_class("Sequence", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbSequence::new, 1))?;
@@ -483,27 +526,63 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let class = module.define_class("ByteLevel", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
     class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
-    class.define_method(
-
-
-
+    class.define_method(
+        "add_prefix_space",
+        method!(RbPreTokenizer::byte_level_add_prefix_space, 0),
+    )?;
+    class.define_method(
+        "add_prefix_space=",
+        method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1),
+    )?;
+    class.define_method(
+        "use_regex",
+        method!(RbPreTokenizer::byte_level_use_regex, 0),
+    )?;
+    class.define_method(
+        "use_regex=",
+        method!(RbPreTokenizer::byte_level_set_use_regex, 1),
+    )?;
 
     let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
-    class.define_method(
-
+    class.define_method(
+        "delimiter",
+        method!(RbPreTokenizer::char_delimiter_split_delimiter, 0),
+    )?;
+    class.define_method(
+        "delimiter=",
+        method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1),
+    )?;
 
     let class = module.define_class("Digits", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
-    class.define_method(
-
+    class.define_method(
+        "individual_digits",
+        method!(RbPreTokenizer::digits_individual_digits, 0),
+    )?;
+    class.define_method(
+        "individual_digits=",
+        method!(RbPreTokenizer::digits_set_individual_digits, 1),
+    )?;
 
     let class = module.define_class("Metaspace", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
-    class.define_method(
-
-
-
+    class.define_method(
+        "prepend_scheme",
+        method!(RbPreTokenizer::metaspace_prepend_scheme, 0),
+    )?;
+    class.define_method(
+        "prepend_scheme=",
+        method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1),
+    )?;
+    class.define_method(
+        "replacement",
+        method!(RbPreTokenizer::metaspace_replacement, 0),
+    )?;
+    class.define_method(
+        "replacement=",
+        method!(RbPreTokenizer::metaspace_set_replacement, 1),
+    )?;
    class.define_method("split", method!(RbPreTokenizer::metaspace_split, 0))?;
    class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;
 
```
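Nearly all of the pre_tokenizers.rs churn above is mechanical rewrapping of the magnus binding calls rather than behavior changes. For readers unfamiliar with magnus, here is a minimal, self-contained sketch of the `define_class`/`define_method` pattern these hunks reformat; the `Example::Counter` type and its methods are hypothetical stand-ins, not part of the gem:

```rust
use magnus::{function, method, prelude::*, Error, Ruby};

// Hypothetical wrapped type; only the binding calls mirror the diff.
#[magnus::wrap(class = "Example::Counter")]
struct Counter(u64);

impl Counter {
    fn new(start: u64) -> Self {
        Self(start)
    }

    fn value(&self) -> u64 {
        self.0
    }
}

#[magnus::init]
fn init(ruby: &Ruby) -> Result<(), Error> {
    let module = ruby.define_module("Example")?;
    let class = module.define_class("Counter", ruby.class_object())?;
    // function! binds a free function (used here for the singleton `new`);
    // method! binds against the wrapped receiver. The trailing number is
    // the arity the method exposes to Ruby.
    class.define_singleton_method("new", function!(Counter::new, 1))?;
    class.define_method("value", method!(Counter::value, 0))?;
    Ok(())
}
```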
data/ext/tokenizers/src/processors.rs

```diff
@@ -1,8 +1,8 @@
 use std::sync::Arc;
 
 use magnus::{
-    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
-    Ruby, TryConvert, TypedData, Value,
+    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
+    RClass, RModule, Ruby, TryConvert, TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::processors::bert::BertProcessing;
@@ -12,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
 use tk::processors::PostProcessorWrapper;
 use tk::{Encoding, PostProcessor};
 
-use super::{
+use super::{RbResult, PROCESSORS};
 
 #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
 pub struct RbPostProcessor {
@@ -106,7 +106,6 @@ impl RbByteLevel {
         }
         RbPostProcessor::new(Arc::new(byte_level.into()))
     }
-
 }
 
 pub struct RbRobertaProcessing {}
@@ -117,7 +116,7 @@ impl RbRobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) ->
+    ) -> RbPostProcessor {
         let proc = RobertaProcessing::new(sep, cls)
             .trim_offsets(trim_offsets)
             .add_prefix_space(add_prefix_space);
@@ -153,7 +152,10 @@ impl RbTemplateProcessing {
 unsafe impl TypedData for RbPostProcessor {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("PostProcessor")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -161,13 +163,17 @@ unsafe impl TypedData for RbPostProcessor {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType =
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("BertProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -177,12 +183,18 @@ unsafe impl TypedData for RbPostProcessor {
             class
         });
         static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("RobertaProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("TemplateProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
```
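The `TypedData` hunks in both files reformat the same lookup idiom: each wrapped Rust type resolves its Ruby class once through a module-level `Lazy` handle and caches it for the life of the VM. Below is a sketch of that idiom under assumed names (`EXAMPLE` and a Ruby-defined `PostProcessor` constant); the gem's real handles are `PRE_TOKENIZERS`, `PROCESSORS`, and friends from `lib.rs`:

```rust
use magnus::{value::Lazy, Class, Module, RClass, RModule, Ruby};

// Hypothetical module handle, resolved on first use.
static EXAMPLE: Lazy<RModule> = Lazy::new(|ruby| ruby.define_module("Example").unwrap());

fn post_processor_class(ruby: &Ruby) -> RClass {
    static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
        // Assumes the Ruby side has already defined Example::PostProcessor.
        let class: RClass = ruby
            .get_inner(&EXAMPLE)
            .const_get("PostProcessor")
            .unwrap();
        // Instances are only created from Rust, so drop Ruby's default allocator.
        class.undef_default_alloc_func();
        class
    });
    ruby.get_inner(&CLASS)
}
```

`undef_default_alloc_func` is what keeps `PostProcessor.new` from allocating an empty shell on the Ruby side; construction stays with the Rust bindings.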
data/ext/tokenizers/src/tokenizer.rs

```diff
@@ -6,8 +6,8 @@ use std::str::FromStr;
 use magnus::prelude::*;
 use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
 use tk::tokenizer::{
-    Model, PaddingDirection, PaddingParams, PaddingStrategy,
-
+    Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
+    TruncationParams, TruncationStrategy,
 };
 use tk::AddedToken;
 
@@ -284,7 +284,10 @@ impl RbTokenizer {
     }
 
     pub fn to_str(&self, pretty: bool) -> RbResult<String> {
-        self.tokenizer
+        self.tokenizer
+            .borrow()
+            .to_string(pretty)
+            .map_err(RbError::from)
     }
 
     pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
@@ -383,7 +386,11 @@ impl RbTokenizer {
            .map_err(RbError::from)
     }
 
-    pub fn decode_batch(
+    pub fn decode_batch(
+        &self,
+        sequences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> RbResult<Vec<String>> {
         let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         self.tokenizer
             .borrow()
@@ -455,7 +462,12 @@ impl RbTokenizer {
         params.direction = match dir_str.as_str() {
             "left" => PaddingDirection::Left,
             "right" => PaddingDirection::Right,
-            _ =>
+            _ => {
+                return Err(Error::new(
+                    exception::arg_error(),
+                    "The direction value must be 'left' or 'right'",
+                ))
+            }
         }
     }
 
@@ -501,24 +513,27 @@ impl RbTokenizer {
     }
 
     pub fn padding(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self.tokenizer
+            .borrow()
+            .get_padding()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
+
+                ret_hash.aset(
+                    "length",
+                    match params.strategy {
+                        tk::PaddingStrategy::BatchLongest => None,
+                        tk::PaddingStrategy::Fixed(size) => Some(size),
+                    },
+                )?;
+                ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
+                ret_hash.aset("pad_id", params.pad_id)?;
+                ret_hash.aset("pad_token", &*params.pad_token)?;
+                ret_hash.aset("pad_type_id", params.pad_type_id)?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
+
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
@@ -539,7 +554,10 @@ impl RbTokenizer {
             "longest_first" => TruncationStrategy::LongestFirst,
             "only_first" => TruncationStrategy::OnlyFirst,
             "only_second" => TruncationStrategy::OnlySecond,
-            _ => return Err(Error::new(
+            _ => return Err(Error::new(
+                exception::arg_error(),
+                "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
+            )),
         }
     }
 
@@ -549,7 +567,12 @@ impl RbTokenizer {
         params.direction = match dir_str.as_str() {
             "left" => TruncationDirection::Left,
             "right" => TruncationDirection::Right,
-            _ =>
+            _ => {
+                return Err(Error::new(
+                    exception::arg_error(),
+                    "The direction value must be 'left' or 'right'",
+                ))
+            }
         }
     }
 
@@ -559,7 +582,10 @@ impl RbTokenizer {
     }
 
     if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
-        return Err(Error::new(
+        return Err(Error::new(
+            exception::arg_error(),
+            error_message.to_string(),
+        ));
     }
 
     Ok(())
@@ -573,16 +599,19 @@ impl RbTokenizer {
     }
 
     pub fn truncation(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer
-
+        self.tokenizer
+            .borrow()
+            .get_truncation()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
 
-
-
-
-
+                ret_hash.aset("max_length", params.max_length)?;
+                ret_hash.aset("stride", params.stride)?;
+                ret_hash.aset("strategy", params.strategy.as_ref())?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
 
-
-
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
```
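Two idioms recur through the tokenizer.rs hunks: assembling an `Option<RHash>` return for `padding`/`truncation`, and rejecting bad string options with `exception::arg_error()`. A condensed sketch of both, with a hypothetical `Padding` struct standing in for tk's `PaddingParams`:

```rust
use magnus::{exception, Error, RHash};

// Stand-in for the real parameter struct; field names mirror the hash keys above.
struct Padding {
    pad_id: u32,
    direction: String,
}

// `map_or(Ok(None), ...)` keeps the "nothing configured" case as Ruby nil.
fn padding_hash(params: Option<&Padding>) -> Result<Option<RHash>, Error> {
    params.map_or(Ok(None), |params| {
        let ret_hash = RHash::new();
        ret_hash.aset("pad_id", params.pad_id)?;
        ret_hash.aset("direction", params.direction.as_str())?;
        Ok(Some(ret_hash))
    })
}

// Returning Err from a magnus-bound method raises the exception in Ruby,
// so this surfaces as an ArgumentError with the given message.
fn check_direction(dir: &str) -> Result<(), Error> {
    match dir {
        "left" | "right" => Ok(()),
        _ => Err(Error::new(
            exception::arg_error(),
            "The direction value must be 'left' or 'right'",
        )),
    }
}
```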