tokenizers 0.5.3 → 0.5.4

This diff shows the published contents of the two package versions as they appear in their public registry, and is provided for informational purposes only.
@@ -1,8 +1,8 @@
 use std::sync::{Arc, RwLock};
 
 use magnus::{
-    data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
-    RArray, RClass, RModule, Ruby, TryConvert, TypedData,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
+    DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
 };
 
 use serde::ser::SerializeStruct;
@@ -22,7 +22,7 @@ use tk::tokenizer::Offsets;
 use tk::{PreTokenizedString, PreTokenizer};
 
 use super::utils::*;
-use super::{PRE_TOKENIZERS, RbError, RbResult};
+use super::{RbError, RbResult, PRE_TOKENIZERS};
 
 #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
 pub struct RbPreTokenizer {
@@ -34,7 +34,9 @@ impl RbPreTokenizer {
     fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
         let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
 
-        self.pretok.pre_tokenize(&mut pretokenized).map_err(RbError::from)?;
+        self.pretok
+            .pre_tokenize(&mut pretokenized)
+            .map_err(RbError::from)?;
 
         Ok(pretokenized
             .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
@@ -195,11 +197,7 @@ impl RbDigits {
 pub struct RbMetaspace {}
 
 impl RbMetaspace {
-    fn new(
-        replacement: char,
-        prepend_scheme: String,
-        split: bool,
-    ) -> RbResult<RbPreTokenizer> {
+    fn new(replacement: char, prepend_scheme: String, split: bool) -> RbResult<RbPreTokenizer> {
         let prepend_scheme = from_string(prepend_scheme)?;
         Ok(Metaspace::new(replacement, prepend_scheme, split).into())
     }
@@ -216,8 +214,14 @@ impl RbPunctuation {
 pub struct RbSplit {}
 
 impl RbSplit {
-    pub fn new(pattern: RbPattern, behavior: RbSplitDelimiterBehavior, invert: bool) -> RbResult<RbPreTokenizer> {
-        Split::new(pattern, behavior.into(), invert).map(|v| v.into()).map_err(RbError::from)
+    pub fn new(
+        pattern: RbPattern,
+        behavior: RbSplitDelimiterBehavior,
+        invert: bool,
+    ) -> RbResult<RbPreTokenizer> {
+        Split::new(pattern, behavior.into(), invert)
+            .map(|v| v.into())
+            .map_err(RbError::from)
     }
 }
 
@@ -258,16 +262,18 @@ pub struct RbSequence {}
 impl RbSequence {
     fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
         let mut sequence = Vec::with_capacity(pre_tokenizers.len());
-        for n in pre_tokenizers.into_iter() {
+        for n in pre_tokenizers {
             let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n)?;
             match &pretokenizer.pretok {
                 RbPreTokenizerTypeWrapper::Sequence(inner) => {
-                    sequence.extend(inner.iter().cloned())
+                    sequence.extend(inner.iter().cloned());
                 }
                 RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
-        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(sequence)))
+        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+            sequence,
+        )))
     }
 }
 
@@ -277,10 +283,13 @@ pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
         "never" => PrependScheme::Never,
         "always" => PrependScheme::Always,
         _ => {
-            return Err(Error::new(exception::arg_error(), format!(
-                "{} is an unknown variant, should be one of ['first', 'never', 'always']",
-                string
-            )));
+            return Err(Error::new(
+                exception::arg_error(),
+                format!(
+                    "{} is an unknown variant, should be one of ['first', 'never', 'always']",
+                    string
+                ),
+            ));
         }
     };
     Ok(scheme)
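Note: from_string maps the Ruby-facing strings onto the upstream crate's PrependScheme enum and raises ArgumentError for anything else. A standalone sketch of the same mapping against the tokenizers crate, with the magnus::Error swapped for a plain String since no Ruby VM is involved (parse_prepend_scheme is our name, not the gem's):

    use tokenizers::pre_tokenizers::metaspace::PrependScheme;

    fn parse_prepend_scheme(s: &str) -> Result<PrependScheme, String> {
        match s {
            "first" => Ok(PrependScheme::First),
            "never" => Ok(PrependScheme::Never),
            "always" => Ok(PrependScheme::Always),
            // Same message the binding raises as an ArgumentError.
            _ => Err(format!(
                "{} is an unknown variant, should be one of ['first', 'never', 'always']",
                s
            )),
        }
    }

    fn main() {
        assert!(matches!(parse_prepend_scheme("always"), Ok(PrependScheme::Always)));
        assert!(parse_prepend_scheme("sometimes").is_err());
    }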
@@ -381,7 +390,10 @@ impl PreTokenizer for RbPreTokenizerWrapper {
 unsafe impl TypedData for RbPreTokenizer {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("PreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -389,28 +401,41 @@ unsafe impl TypedData for RbPreTokenizer {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Sequence")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("BertPreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("ByteLevel")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("CharDelimiterSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -420,12 +445,18 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Metaspace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Punctuation")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -435,17 +466,26 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("UnicodeScripts")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Whitespace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("WhitespaceSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -472,7 +512,10 @@ unsafe impl TypedData for RbPreTokenizer {
 
 pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
-    pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
+    pre_tokenizer.define_method(
+        "pre_tokenize_str",
+        method!(RbPreTokenizer::pre_tokenize_str, 1),
+    )?;
 
     let class = module.define_class("Sequence", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbSequence::new, 1))?;
@@ -483,27 +526,63 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let class = module.define_class("ByteLevel", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
     class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
-    class.define_method("add_prefix_space", method!(RbPreTokenizer::byte_level_add_prefix_space, 0))?;
-    class.define_method("add_prefix_space=", method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1))?;
-    class.define_method("use_regex", method!(RbPreTokenizer::byte_level_use_regex, 0))?;
-    class.define_method("use_regex=", method!(RbPreTokenizer::byte_level_set_use_regex, 1))?;
+    class.define_method(
+        "add_prefix_space",
+        method!(RbPreTokenizer::byte_level_add_prefix_space, 0),
+    )?;
+    class.define_method(
+        "add_prefix_space=",
+        method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1),
+    )?;
+    class.define_method(
+        "use_regex",
+        method!(RbPreTokenizer::byte_level_use_regex, 0),
+    )?;
+    class.define_method(
+        "use_regex=",
+        method!(RbPreTokenizer::byte_level_set_use_regex, 1),
+    )?;
 
     let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
-    class.define_method("delimiter", method!(RbPreTokenizer::char_delimiter_split_delimiter, 0))?;
-    class.define_method("delimiter=", method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1))?;
+    class.define_method(
+        "delimiter",
+        method!(RbPreTokenizer::char_delimiter_split_delimiter, 0),
+    )?;
+    class.define_method(
+        "delimiter=",
+        method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1),
+    )?;
 
     let class = module.define_class("Digits", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
-    class.define_method("individual_digits", method!(RbPreTokenizer::digits_individual_digits, 0))?;
-    class.define_method("individual_digits=", method!(RbPreTokenizer::digits_set_individual_digits, 1))?;
+    class.define_method(
+        "individual_digits",
+        method!(RbPreTokenizer::digits_individual_digits, 0),
+    )?;
+    class.define_method(
+        "individual_digits=",
+        method!(RbPreTokenizer::digits_set_individual_digits, 1),
+    )?;
 
     let class = module.define_class("Metaspace", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
-    class.define_method("prepend_scheme", method!(RbPreTokenizer::metaspace_prepend_scheme, 0))?;
-    class.define_method("prepend_scheme=", method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1))?;
-    class.define_method("replacement", method!(RbPreTokenizer::metaspace_replacement, 0))?;
-    class.define_method("replacement=", method!(RbPreTokenizer::metaspace_set_replacement, 1))?;
+    class.define_method(
+        "prepend_scheme",
+        method!(RbPreTokenizer::metaspace_prepend_scheme, 0),
+    )?;
+    class.define_method(
+        "prepend_scheme=",
+        method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1),
+    )?;
+    class.define_method(
+        "replacement",
+        method!(RbPreTokenizer::metaspace_replacement, 0),
+    )?;
+    class.define_method(
+        "replacement=",
+        method!(RbPreTokenizer::metaspace_set_replacement, 1),
+    )?;
     class.define_method("split", method!(RbPreTokenizer::metaspace_split, 0))?;
     class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;
 
@@ -1,8 +1,8 @@
 use std::sync::Arc;
 
 use magnus::{
-    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
-    Ruby, TryConvert, TypedData, Value,
+    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
+    RClass, RModule, Ruby, TryConvert, TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::processors::bert::BertProcessing;
@@ -12,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
 use tk::processors::PostProcessorWrapper;
 use tk::{Encoding, PostProcessor};
 
-use super::{PROCESSORS, RbResult};
+use super::{RbResult, PROCESSORS};
 
 #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
 pub struct RbPostProcessor {
@@ -106,7 +106,6 @@ impl RbByteLevel {
         }
         RbPostProcessor::new(Arc::new(byte_level.into()))
     }
-
 }
 
 pub struct RbRobertaProcessing {}
@@ -117,7 +116,7 @@ impl RbRobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) -> RbPostProcessor {
+    ) -> RbPostProcessor {
         let proc = RobertaProcessing::new(sep, cls)
             .trim_offsets(trim_offsets)
            .add_prefix_space(add_prefix_space);
@@ -153,7 +152,10 @@ impl RbTemplateProcessing {
 unsafe impl TypedData for RbPostProcessor {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("PostProcessor")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -161,13 +163,17 @@ unsafe impl TypedData for RbPostProcessor {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("BertProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -177,12 +183,18 @@ unsafe impl TypedData for RbPostProcessor {
             class
         });
         static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("RobertaProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("TemplateProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -6,8 +6,8 @@ use std::str::FromStr;
 use magnus::prelude::*;
 use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
 use tk::tokenizer::{
-    Model, PaddingDirection, PaddingParams, PaddingStrategy,
-    TruncationDirection, TruncationParams, TruncationStrategy, TokenizerImpl
+    Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
+    TruncationParams, TruncationStrategy,
 };
 use tk::AddedToken;
 
@@ -284,7 +284,10 @@ impl RbTokenizer {
     }
 
     pub fn to_str(&self, pretty: bool) -> RbResult<String> {
-        self.tokenizer.borrow().to_string(pretty).map_err(RbError::from)
+        self.tokenizer
+            .borrow()
+            .to_string(pretty)
+            .map_err(RbError::from)
     }
 
     pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
@@ -383,7 +386,11 @@ impl RbTokenizer {
             .map_err(RbError::from)
     }
 
-    pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
+    pub fn decode_batch(
+        &self,
+        sequences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> RbResult<Vec<String>> {
         let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         self.tokenizer
             .borrow()
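Note: the slices line above is the usual reborrowing trick: the underlying decode_batch wants id slices, so each owned Vec<u32> is viewed as a &[u32] without copying any ids. A tiny self-contained illustration (the id values are arbitrary):

    fn main() {
        let sequences: Vec<Vec<u32>> = vec![vec![101, 2023, 102], vec![101, 102]];

        // Each element is a fat pointer into `sequences`; no ids are copied.
        let slices: Vec<&[u32]> = sequences.iter().map(|v| &v[..]).collect();

        assert_eq!(slices.len(), sequences.len());
        assert_eq!(slices[0], [101, 2023, 102]);
    }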
@@ -455,7 +462,12 @@ impl RbTokenizer {
             params.direction = match dir_str.as_str() {
                 "left" => PaddingDirection::Left,
                 "right" => PaddingDirection::Right,
-                _ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
+                _ => {
+                    return Err(Error::new(
+                        exception::arg_error(),
+                        "The direction value must be 'left' or 'right'",
+                    ))
+                }
             }
         }
 
@@ -501,24 +513,27 @@ impl RbTokenizer {
     }
 
     pub fn padding(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer.borrow().get_padding().map_or(Ok(None), |params| {
-            let ret_hash = RHash::new();
-
-            ret_hash.aset(
-                "length",
-                match params.strategy {
-                    tk::PaddingStrategy::BatchLongest => None,
-                    tk::PaddingStrategy::Fixed(size) => Some(size),
-                },
-            )?;
-            ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
-            ret_hash.aset("pad_id", params.pad_id)?;
-            ret_hash.aset("pad_token", &*params.pad_token)?;
-            ret_hash.aset("pad_type_id", params.pad_type_id)?;
-            ret_hash.aset("direction", params.direction.as_ref())?;
-
-            Ok(Some(ret_hash))
-        })
+        self.tokenizer
+            .borrow()
+            .get_padding()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
+
+                ret_hash.aset(
+                    "length",
+                    match params.strategy {
+                        tk::PaddingStrategy::BatchLongest => None,
+                        tk::PaddingStrategy::Fixed(size) => Some(size),
+                    },
+                )?;
+                ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
+                ret_hash.aset("pad_id", params.pad_id)?;
+                ret_hash.aset("pad_token", &*params.pad_token)?;
+                ret_hash.aset("pad_type_id", params.pad_type_id)?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
+
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
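Note: padding surfaces the crate's PaddingParams as a Ruby hash; the one non-obvious field is "length", which is nil for batch-longest padding and the fixed size otherwise. A standalone sketch of that mapping using the upstream PaddingStrategy enum (length_entry is our helper name, not part of either library):

    use tokenizers::PaddingStrategy;

    // Mirrors the "length" entry built in `padding` above.
    fn length_entry(strategy: &PaddingStrategy) -> Option<usize> {
        match strategy {
            PaddingStrategy::BatchLongest => None,
            PaddingStrategy::Fixed(size) => Some(*size),
        }
    }

    fn main() {
        assert_eq!(length_entry(&PaddingStrategy::BatchLongest), None);
        assert_eq!(length_entry(&PaddingStrategy::Fixed(512)), Some(512));
    }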
@@ -539,7 +554,10 @@ impl RbTokenizer {
                 "longest_first" => TruncationStrategy::LongestFirst,
                 "only_first" => TruncationStrategy::OnlyFirst,
                 "only_second" => TruncationStrategy::OnlySecond,
-                _ => return Err(Error::new(exception::arg_error(), "The strategy value must be 'longest_first', 'only_first', or 'only_second'")),
+                _ => return Err(Error::new(
+                    exception::arg_error(),
+                    "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
+                )),
             }
         }
 
@@ -549,7 +567,12 @@ impl RbTokenizer {
             params.direction = match dir_str.as_str() {
                 "left" => TruncationDirection::Left,
                 "right" => TruncationDirection::Right,
-                _ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
+                _ => {
+                    return Err(Error::new(
+                        exception::arg_error(),
+                        "The direction value must be 'left' or 'right'",
+                    ))
+                }
             }
         }
 
@@ -559,7 +582,10 @@ impl RbTokenizer {
         }
 
         if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
-            return Err(Error::new(exception::arg_error(), error_message.to_string()));
+            return Err(Error::new(
+                exception::arg_error(),
+                error_message.to_string(),
+            ));
         }
 
         Ok(())
@@ -573,16 +599,19 @@ impl RbTokenizer {
     }
 
     pub fn truncation(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer.borrow().get_truncation().map_or(Ok(None), |params| {
-            let ret_hash = RHash::new();
+        self.tokenizer
+            .borrow()
+            .get_truncation()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
 
-            ret_hash.aset("max_length", params.max_length)?;
-            ret_hash.aset("stride", params.stride)?;
-            ret_hash.aset("strategy", params.strategy.as_ref())?;
-            ret_hash.aset("direction", params.direction.as_ref())?;
+                ret_hash.aset("max_length", params.max_length)?;
+                ret_hash.aset("stride", params.stride)?;
+                ret_hash.aset("strategy", params.strategy.as_ref())?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
 
-            Ok(Some(ret_hash))
-        })
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {