tokenizers 0.5.2 → 0.5.4

@@ -1,8 +1,8 @@
 use std::sync::{Arc, RwLock};
 
 use magnus::{
-    data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
-    RArray, RClass, RModule, Ruby, TryConvert, TypedData,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
+    DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
 };
 
 use serde::ser::SerializeStruct;
@@ -22,7 +22,7 @@ use tk::tokenizer::Offsets;
 use tk::{PreTokenizedString, PreTokenizer};
 
 use super::utils::*;
-use super::{PRE_TOKENIZERS, RbError, RbResult};
+use super::{RbError, RbResult, PRE_TOKENIZERS};
 
 #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
 pub struct RbPreTokenizer {
@@ -34,7 +34,9 @@ impl RbPreTokenizer {
     fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
         let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
 
-        self.pretok.pre_tokenize(&mut pretokenized).map_err(RbError::from)?;
+        self.pretok
+            .pre_tokenize(&mut pretokenized)
+            .map_err(RbError::from)?;
 
         Ok(pretokenized
             .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
@@ -195,11 +197,7 @@ impl RbDigits {
 pub struct RbMetaspace {}
 
 impl RbMetaspace {
-    fn new(
-        replacement: char,
-        prepend_scheme: String,
-        split: bool,
-    ) -> RbResult<RbPreTokenizer> {
+    fn new(replacement: char, prepend_scheme: String, split: bool) -> RbResult<RbPreTokenizer> {
         let prepend_scheme = from_string(prepend_scheme)?;
         Ok(Metaspace::new(replacement, prepend_scheme, split).into())
     }
@@ -216,8 +214,14 @@ impl RbPunctuation {
 pub struct RbSplit {}
 
 impl RbSplit {
-    pub fn new(pattern: RbPattern, behavior: RbSplitDelimiterBehavior, invert: bool) -> RbResult<RbPreTokenizer> {
-        Split::new(pattern, behavior.into(), invert).map(|v| v.into()).map_err(RbError::from)
+    pub fn new(
+        pattern: RbPattern,
+        behavior: RbSplitDelimiterBehavior,
+        invert: bool,
+    ) -> RbResult<RbPreTokenizer> {
+        Split::new(pattern, behavior.into(), invert)
+            .map(|v| v.into())
+            .map_err(RbError::from)
     }
 }
 
@@ -258,16 +262,18 @@ pub struct RbSequence {}
 impl RbSequence {
     fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
         let mut sequence = Vec::with_capacity(pre_tokenizers.len());
-        for n in pre_tokenizers.into_iter() {
+        for n in pre_tokenizers {
             let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n)?;
             match &pretokenizer.pretok {
                 RbPreTokenizerTypeWrapper::Sequence(inner) => {
-                    sequence.extend(inner.iter().cloned())
+                    sequence.extend(inner.iter().cloned());
                 }
                 RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
-        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(sequence)))
+        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+            sequence,
+        )))
     }
 }
 
@@ -277,10 +283,13 @@ pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
         "never" => PrependScheme::Never,
         "always" => PrependScheme::Always,
         _ => {
-            return Err(Error::new(exception::arg_error(), format!(
-                "{} is an unknown variant, should be one of ['first', 'never', 'always']",
-                string
-            )));
+            return Err(Error::new(
+                exception::arg_error(),
+                format!(
+                    "{} is an unknown variant, should be one of ['first', 'never', 'always']",
+                    string
+                ),
+            ));
        }
    };
    Ok(scheme)
@@ -381,7 +390,10 @@ impl PreTokenizer for RbPreTokenizerWrapper {
 unsafe impl TypedData for RbPreTokenizer {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("PreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -389,28 +401,41 @@ unsafe impl TypedData for RbPreTokenizer {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Sequence")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("BertPreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("ByteLevel")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("CharDelimiterSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -420,12 +445,18 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Metaspace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Punctuation")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -435,17 +466,26 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("UnicodeScripts")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Whitespace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("WhitespaceSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -472,7 +512,10 @@ unsafe impl TypedData for RbPreTokenizer {
 
 pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
-    pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
+    pre_tokenizer.define_method(
+        "pre_tokenize_str",
+        method!(RbPreTokenizer::pre_tokenize_str, 1),
+    )?;
 
     let class = module.define_class("Sequence", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbSequence::new, 1))?;
@@ -483,27 +526,63 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let class = module.define_class("ByteLevel", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
     class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
-    class.define_method("add_prefix_space", method!(RbPreTokenizer::byte_level_add_prefix_space, 0))?;
-    class.define_method("add_prefix_space=", method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1))?;
-    class.define_method("use_regex", method!(RbPreTokenizer::byte_level_use_regex, 0))?;
-    class.define_method("use_regex=", method!(RbPreTokenizer::byte_level_set_use_regex, 1))?;
+    class.define_method(
+        "add_prefix_space",
+        method!(RbPreTokenizer::byte_level_add_prefix_space, 0),
+    )?;
+    class.define_method(
+        "add_prefix_space=",
+        method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1),
+    )?;
+    class.define_method(
+        "use_regex",
+        method!(RbPreTokenizer::byte_level_use_regex, 0),
+    )?;
+    class.define_method(
+        "use_regex=",
+        method!(RbPreTokenizer::byte_level_set_use_regex, 1),
+    )?;
 
     let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
-    class.define_method("delimiter", method!(RbPreTokenizer::char_delimiter_split_delimiter, 0))?;
-    class.define_method("delimiter=", method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1))?;
+    class.define_method(
+        "delimiter",
+        method!(RbPreTokenizer::char_delimiter_split_delimiter, 0),
+    )?;
+    class.define_method(
+        "delimiter=",
+        method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1),
+    )?;
 
     let class = module.define_class("Digits", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
-    class.define_method("individual_digits", method!(RbPreTokenizer::digits_individual_digits, 0))?;
-    class.define_method("individual_digits=", method!(RbPreTokenizer::digits_set_individual_digits, 1))?;
+    class.define_method(
+        "individual_digits",
+        method!(RbPreTokenizer::digits_individual_digits, 0),
+    )?;
+    class.define_method(
+        "individual_digits=",
+        method!(RbPreTokenizer::digits_set_individual_digits, 1),
+    )?;
 
     let class = module.define_class("Metaspace", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
-    class.define_method("prepend_scheme", method!(RbPreTokenizer::metaspace_prepend_scheme, 0))?;
-    class.define_method("prepend_scheme=", method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1))?;
-    class.define_method("replacement", method!(RbPreTokenizer::metaspace_replacement, 0))?;
-    class.define_method("replacement=", method!(RbPreTokenizer::metaspace_set_replacement, 1))?;
+    class.define_method(
+        "prepend_scheme",
+        method!(RbPreTokenizer::metaspace_prepend_scheme, 0),
+    )?;
+    class.define_method(
+        "prepend_scheme=",
+        method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1),
+    )?;
+    class.define_method(
+        "replacement",
+        method!(RbPreTokenizer::metaspace_replacement, 0),
+    )?;
+    class.define_method(
+        "replacement=",
+        method!(RbPreTokenizer::metaspace_set_replacement, 1),
+    )?;
     class.define_method("split", method!(RbPreTokenizer::metaspace_split, 0))?;
    class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;
 
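Everything above is a rustfmt-style reflow of the pre-tokenizer bindings; no behavior changes. The logic most worth pausing on sits in the RbSequence::new context: nested Sequence pre-tokenizers are spliced into one flat list rather than wrapped. A minimal, dependency-free sketch of that flattening idea (the Wrapper enum is hypothetical, standing in for RbPreTokenizerTypeWrapper):

#[derive(Clone, Debug)]
enum Wrapper {
    Single(&'static str),
    Sequence(Vec<Wrapper>),
}

fn flatten(items: Vec<Wrapper>) -> Vec<Wrapper> {
    let mut out = Vec::with_capacity(items.len());
    for item in items {
        match item {
            // Splice an inner Sequence's elements instead of nesting it.
            Wrapper::Sequence(inner) => out.extend(inner),
            single => out.push(single),
        }
    }
    out
}

fn main() {
    let nested = vec![
        Wrapper::Single("Whitespace"),
        Wrapper::Sequence(vec![Wrapper::Single("Digits"), Wrapper::Single("Punctuation")]),
    ];
    // Prints three Single entries; no Sequence wrapper survives.
    println!("{:?}", flatten(nested));
}

One level of flattening suffices because any Sequence built through this constructor is itself already flat. The next hunks give the post-processor bindings the same formatting pass.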
@@ -1,8 +1,8 @@
 use std::sync::Arc;
 
 use magnus::{
-    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
-    Ruby, TryConvert, TypedData, Value,
+    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
+    RClass, RModule, Ruby, TryConvert, TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::processors::bert::BertProcessing;
@@ -12,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
 use tk::processors::PostProcessorWrapper;
 use tk::{Encoding, PostProcessor};
 
-use super::{PROCESSORS, RbResult};
+use super::{RbResult, PROCESSORS};
 
 #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
 pub struct RbPostProcessor {
@@ -106,7 +106,6 @@ impl RbByteLevel {
         }
         RbPostProcessor::new(Arc::new(byte_level.into()))
     }
-
 }
 
 pub struct RbRobertaProcessing {}
@@ -117,7 +116,7 @@ impl RbRobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) -> RbPostProcessor {
+    ) -> RbPostProcessor {
         let proc = RobertaProcessing::new(sep, cls)
             .trim_offsets(trim_offsets)
             .add_prefix_space(add_prefix_space);
@@ -153,7 +152,10 @@ impl RbTemplateProcessing {
 unsafe impl TypedData for RbPostProcessor {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("PostProcessor")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -161,13 +163,17 @@ unsafe impl TypedData for RbPostProcessor {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("BertProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -177,12 +183,18 @@ unsafe impl TypedData for RbPostProcessor {
             class
         });
         static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("RobertaProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("TemplateProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
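The post-processor bindings above get the same mechanical reflow; the recurring shape in both files is the lazily-cached Ruby class lookup inside class(), data_type(), and class_for(). A std-only analog of that caching (lookup_class is a hypothetical stand-in for the get_inner/const_get chain; OnceLock plays roughly the role magnus's value::Lazy plays against the Ruby VM):

use std::sync::OnceLock;

// Hypothetical stand-in for ruby.get_inner(&PROCESSORS).const_get(name).
fn lookup_class(name: &str) -> String {
    format!("Tokenizers::Processors::{name}")
}

fn post_processor_class() -> &'static str {
    // Resolved on the first call, cached for every call after that.
    static CLASS: OnceLock<String> = OnceLock::new();
    CLASS.get_or_init(|| lookup_class("PostProcessor"))
}

fn main() {
    assert_eq!(post_processor_class(), "Tokenizers::Processors::PostProcessor");
}

The payoff is that each constant lookup is paid once, not on every wrapped value. The last file, the tokenizer itself, carries the release's substantive change.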
@@ -6,8 +6,8 @@ use std::str::FromStr;
 use magnus::prelude::*;
 use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
 use tk::tokenizer::{
-    Model, PaddingDirection, PaddingParams, PaddingStrategy,
-    TruncationDirection, TruncationParams, TruncationStrategy, TokenizerImpl
+    Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
+    TruncationParams, TruncationStrategy,
 };
 use tk::AddedToken;
 
@@ -22,9 +22,10 @@ use super::processors::RbPostProcessor;
 use super::trainers::RbTrainer;
 use super::{RbError, RbResult};
 
+#[magnus::wrap(class = "Tokenizers::AddedToken")]
 pub struct RbAddedToken {
     pub content: String,
-    pub is_special_token: bool,
+    pub special: bool,
     pub single_word: Option<bool>,
     pub lstrip: Option<bool>,
     pub rstrip: Option<bool>,
@@ -32,10 +33,10 @@ pub struct RbAddedToken {
 }
 
 impl RbAddedToken {
-    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+    pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
         Self {
             content: content.into(),
-            is_special_token: is_special_token.unwrap_or(false),
+            special: special.unwrap_or(false),
             single_word: None,
             lstrip: None,
             rstrip: None,
@@ -44,7 +45,7 @@ impl RbAddedToken {
     }
 
     pub fn get_token(&self) -> tk::tokenizer::AddedToken {
-        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+        let mut token = tk::AddedToken::from(&self.content, self.special);
 
         if let Some(sw) = self.single_word {
             token = token.single_word(sw);
@@ -71,11 +72,73 @@ impl From<tk::AddedToken> for RbAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-            is_special_token: !token.normalized,
+            special: !token.normalized,
         }
     }
 }
 
+impl RbAddedToken {
+    pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
+        let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
+
+        let value: Value = kwargs.delete(Symbol::new("single_word"))?;
+        if !value.is_nil() {
+            token.single_word = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("lstrip"))?;
+        if !value.is_nil() {
+            token.lstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("rstrip"))?;
+        if !value.is_nil() {
+            token.rstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("normalized"))?;
+        if !value.is_nil() {
+            token.normalized = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("special"))?;
+        if !value.is_nil() {
+            token.special = TryConvert::try_convert(value)?;
+        }
+
+        if !kwargs.is_empty() {
+            // TODO improve message
+            return Err(Error::new(exception::arg_error(), "unknown keyword"));
+        }
+
+        Ok(token)
+    }
+
+    pub fn get_content(&self) -> String {
+        self.content.to_string()
+    }
+
+    pub fn get_rstrip(&self) -> bool {
+        self.get_token().rstrip
+    }
+
+    pub fn get_lstrip(&self) -> bool {
+        self.get_token().lstrip
+    }
+
+    pub fn get_single_word(&self) -> bool {
+        self.get_token().single_word
+    }
+
+    pub fn get_normalized(&self) -> bool {
+        self.get_token().normalized
+    }
+
+    pub fn get_special(&self) -> bool {
+        self.get_token().special
+    }
+}
+
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 
 impl<'s> TryConvert for TextInputSequence<'s> {
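Here is the substantive change of the release: RbAddedToken is exposed to Ruby as Tokenizers::AddedToken via #[magnus::wrap], the is_special_token field is renamed to special to match the upstream crate, and a kwargs-driven constructor plus getters are added. The collected fields ultimately feed tk::AddedToken's builder, roughly like this (a sketch against the Hugging Face tokenizers crate, which these bindings alias as tk; the dependency line is assumed):

// Assumes `tokenizers` in Cargo.toml; the bindings refer to it as `tk`.
use tokenizers as tk;

fn main() {
    // Content and the special flag come first; the remaining knobs are
    // builder methods, mirroring RbAddedToken::get_token's Option fields.
    let token = tk::AddedToken::from("<custom>", true)
        .single_word(true)
        .lstrip(false)
        .rstrip(false)
        .normalized(false);
    println!("{} (special: {})", token.content, token.special);
}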
@@ -221,7 +284,10 @@ impl RbTokenizer {
     }
 
     pub fn to_str(&self, pretty: bool) -> RbResult<String> {
-        self.tokenizer.borrow().to_string(pretty).map_err(RbError::from)
+        self.tokenizer
+            .borrow()
+            .to_string(pretty)
+            .map_err(RbError::from)
     }
 
     pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
@@ -320,7 +386,11 @@ impl RbTokenizer {
             .map_err(RbError::from)
     }
 
-    pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
+    pub fn decode_batch(
+        &self,
+        sequences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> RbResult<Vec<String>> {
         let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         self.tokenizer
             .borrow()
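decode_batch only changes shape above, but the &v[..] line it wraps is the interesting part: it reborrows a Vec<Vec<u32>> as slices without copying any token ids, which is the slice-of-slices form the inner decode_batch expects. In isolation:

fn main() {
    let sequences: Vec<Vec<u32>> = vec![vec![101, 2023, 102], vec![101, 102]];
    // Borrow each inner Vec as a &[u32]; no ids are copied.
    let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
    assert_eq!(slices.len(), 2);
    println!("{slices:?}");
}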
@@ -392,7 +462,12 @@ impl RbTokenizer {
             params.direction = match dir_str.as_str() {
                 "left" => PaddingDirection::Left,
                 "right" => PaddingDirection::Right,
-                _ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
+                _ => {
+                    return Err(Error::new(
+                        exception::arg_error(),
+                        "The direction value must be 'left' or 'right'",
+                    ))
+                }
             }
         }
 
@@ -438,24 +513,27 @@ impl RbTokenizer {
     }
 
     pub fn padding(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer.borrow().get_padding().map_or(Ok(None), |params| {
-            let ret_hash = RHash::new();
-
-            ret_hash.aset(
-                "length",
-                match params.strategy {
-                    tk::PaddingStrategy::BatchLongest => None,
-                    tk::PaddingStrategy::Fixed(size) => Some(size),
-                },
-            )?;
-            ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
-            ret_hash.aset("pad_id", params.pad_id)?;
-            ret_hash.aset("pad_token", &*params.pad_token)?;
-            ret_hash.aset("pad_type_id", params.pad_type_id)?;
-            ret_hash.aset("direction", params.direction.as_ref())?;
-
-            Ok(Some(ret_hash))
-        })
+        self.tokenizer
+            .borrow()
+            .get_padding()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
+
+                ret_hash.aset(
+                    "length",
+                    match params.strategy {
+                        tk::PaddingStrategy::BatchLongest => None,
+                        tk::PaddingStrategy::Fixed(size) => Some(size),
+                    },
+                )?;
+                ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
+                ret_hash.aset("pad_id", params.pad_id)?;
+                ret_hash.aset("pad_token", &*params.pad_token)?;
+                ret_hash.aset("pad_type_id", params.pad_type_id)?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
+
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
@@ -476,7 +554,10 @@ impl RbTokenizer {
                 "longest_first" => TruncationStrategy::LongestFirst,
                 "only_first" => TruncationStrategy::OnlyFirst,
                 "only_second" => TruncationStrategy::OnlySecond,
-                _ => return Err(Error::new(exception::arg_error(), "The strategy value must be 'longest_first', 'only_first', or 'only_second'")),
+                _ => return Err(Error::new(
+                    exception::arg_error(),
+                    "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
+                )),
             }
         }
 
@@ -486,7 +567,12 @@ impl RbTokenizer {
             params.direction = match dir_str.as_str() {
                 "left" => TruncationDirection::Left,
                 "right" => TruncationDirection::Right,
-                _ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
+                _ => {
+                    return Err(Error::new(
+                        exception::arg_error(),
+                        "The direction value must be 'left' or 'right'",
+                    ))
+                }
             }
         }
 
@@ -496,7 +582,10 @@ impl RbTokenizer {
         }
 
         if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
-            return Err(Error::new(exception::arg_error(), error_message.to_string()));
+            return Err(Error::new(
+                exception::arg_error(),
+                error_message.to_string(),
+            ));
         }
 
         Ok(())
@@ -510,16 +599,19 @@ impl RbTokenizer {
     }
 
     pub fn truncation(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer.borrow().get_truncation().map_or(Ok(None), |params| {
-            let ret_hash = RHash::new();
+        self.tokenizer
+            .borrow()
+            .get_truncation()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
 
-            ret_hash.aset("max_length", params.max_length)?;
-            ret_hash.aset("stride", params.stride)?;
-            ret_hash.aset("strategy", params.strategy.as_ref())?;
-            ret_hash.aset("direction", params.direction.as_ref())?;
+                ret_hash.aset("max_length", params.max_length)?;
+                ret_hash.aset("stride", params.stride)?;
+                ret_hash.aset("strategy", params.strategy.as_ref())?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
 
-            Ok(Some(ret_hash))
-        })
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
@@ -536,4 +628,14 @@ impl RbTokenizer {
     pub fn vocab_size(&self, with_added_tokens: bool) -> usize {
         self.tokenizer.borrow().get_vocab_size(with_added_tokens)
     }
+
+    pub fn get_added_tokens_decoder(&self) -> RbResult<RHash> {
+        let sorted_map = RHash::new();
+
+        for (key, value) in self.tokenizer.borrow().get_added_tokens_decoder() {
+            sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
+        }
+
+        Ok(sorted_map)
+    }
 }
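The new get_added_tokens_decoder rounds out the AddedToken work, mirroring the tokenizer's id-to-AddedToken map into a Ruby hash. One caveat the sorted_map name hints at: a Rust HashMap iterates in arbitrary order, so code that needs a stable id-ordered view has to sort explicitly. A std-only sketch of that (with &str standing in for the token values):

use std::collections::HashMap;

fn main() {
    // &str stands in for tk::AddedToken here.
    let decoder: HashMap<u32, &str> =
        HashMap::from([(100, "<pad>"), (0, "<unk>"), (101, "<mask>")]);

    // HashMap iteration order is arbitrary; sort ids for a stable view.
    let mut ids: Vec<u32> = decoder.keys().copied().collect();
    ids.sort_unstable();
    for id in ids {
        println!("{id} => {}", decoder[&id]);
    }
}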