tokenizers 0.5.2 → 0.5.4

This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
@@ -1,8 +1,8 @@
 use std::sync::{Arc, RwLock};
 
 use magnus::{
-    data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
-    RArray, RClass, RModule, Ruby, TryConvert, TypedData,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
+    DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
 };
 
 use serde::ser::SerializeStruct;
@@ -22,7 +22,7 @@ use tk::tokenizer::Offsets;
 use tk::{PreTokenizedString, PreTokenizer};
 
 use super::utils::*;
-use super::{PRE_TOKENIZERS, RbError, RbResult};
+use super::{RbError, RbResult, PRE_TOKENIZERS};
 
 #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
 pub struct RbPreTokenizer {
@@ -34,7 +34,9 @@ impl RbPreTokenizer {
     fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
         let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
 
-        self.pretok.pre_tokenize(&mut pretokenized).map_err(RbError::from)?;
+        self.pretok
+            .pre_tokenize(&mut pretokenized)
+            .map_err(RbError::from)?;
 
         Ok(pretokenized
             .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
@@ -195,11 +197,7 @@ impl RbDigits {
 pub struct RbMetaspace {}
 
 impl RbMetaspace {
-    fn new(
-        replacement: char,
-        prepend_scheme: String,
-        split: bool,
-    ) -> RbResult<RbPreTokenizer> {
+    fn new(replacement: char, prepend_scheme: String, split: bool) -> RbResult<RbPreTokenizer> {
         let prepend_scheme = from_string(prepend_scheme)?;
         Ok(Metaspace::new(replacement, prepend_scheme, split).into())
     }
@@ -216,8 +214,14 @@ impl RbPunctuation {
 pub struct RbSplit {}
 
 impl RbSplit {
-    pub fn new(pattern: RbPattern, behavior: RbSplitDelimiterBehavior, invert: bool) -> RbResult<RbPreTokenizer> {
-        Split::new(pattern, behavior.into(), invert).map(|v| v.into()).map_err(RbError::from)
+    pub fn new(
+        pattern: RbPattern,
+        behavior: RbSplitDelimiterBehavior,
+        invert: bool,
+    ) -> RbResult<RbPreTokenizer> {
+        Split::new(pattern, behavior.into(), invert)
+            .map(|v| v.into())
+            .map_err(RbError::from)
     }
 }
 
@@ -258,16 +262,18 @@ pub struct RbSequence {}
 impl RbSequence {
     fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
         let mut sequence = Vec::with_capacity(pre_tokenizers.len());
-        for n in pre_tokenizers.into_iter() {
+        for n in pre_tokenizers {
             let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n)?;
             match &pretokenizer.pretok {
                 RbPreTokenizerTypeWrapper::Sequence(inner) => {
-                    sequence.extend(inner.iter().cloned())
+                    sequence.extend(inner.iter().cloned());
                 }
                 RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
-        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(sequence)))
+        Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
+            sequence,
+        )))
     }
 }
 
@@ -277,10 +283,13 @@ pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
         "never" => PrependScheme::Never,
         "always" => PrependScheme::Always,
         _ => {
-            return Err(Error::new(exception::arg_error(), format!(
-                "{} is an unknown variant, should be one of ['first', 'never', 'always']",
-                string
-            )));
+            return Err(Error::new(
+                exception::arg_error(),
+                format!(
+                    "{} is an unknown variant, should be one of ['first', 'never', 'always']",
+                    string
+                ),
+            ));
         }
     };
     Ok(scheme)
@@ -381,7 +390,10 @@ impl PreTokenizer for RbPreTokenizerWrapper {
 unsafe impl TypedData for RbPreTokenizer {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("PreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -389,28 +401,41 @@ unsafe impl TypedData for RbPreTokenizer {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Sequence")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("BertPreTokenizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("ByteLevel")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("CharDelimiterSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -420,12 +445,18 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Metaspace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Punctuation")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -435,17 +466,26 @@ unsafe impl TypedData for RbPreTokenizer {
             class
         });
         static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("UnicodeScripts")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("Whitespace")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PRE_TOKENIZERS)
+                .const_get("WhitespaceSplit")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -472,7 +512,10 @@ unsafe impl TypedData for RbPreTokenizer {
 
 pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
-    pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
+    pre_tokenizer.define_method(
+        "pre_tokenize_str",
+        method!(RbPreTokenizer::pre_tokenize_str, 1),
+    )?;
 
     let class = module.define_class("Sequence", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbSequence::new, 1))?;
@@ -483,27 +526,63 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let class = module.define_class("ByteLevel", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
     class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
-    class.define_method("add_prefix_space", method!(RbPreTokenizer::byte_level_add_prefix_space, 0))?;
-    class.define_method("add_prefix_space=", method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1))?;
-    class.define_method("use_regex", method!(RbPreTokenizer::byte_level_use_regex, 0))?;
-    class.define_method("use_regex=", method!(RbPreTokenizer::byte_level_set_use_regex, 1))?;
+    class.define_method(
+        "add_prefix_space",
+        method!(RbPreTokenizer::byte_level_add_prefix_space, 0),
+    )?;
+    class.define_method(
+        "add_prefix_space=",
+        method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1),
+    )?;
+    class.define_method(
+        "use_regex",
+        method!(RbPreTokenizer::byte_level_use_regex, 0),
+    )?;
+    class.define_method(
+        "use_regex=",
+        method!(RbPreTokenizer::byte_level_set_use_regex, 1),
+    )?;
 
     let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
     class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
-    class.define_method("delimiter", method!(RbPreTokenizer::char_delimiter_split_delimiter, 0))?;
-    class.define_method("delimiter=", method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1))?;
+    class.define_method(
+        "delimiter",
+        method!(RbPreTokenizer::char_delimiter_split_delimiter, 0),
+    )?;
+    class.define_method(
+        "delimiter=",
+        method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1),
+    )?;
 
     let class = module.define_class("Digits", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
-    class.define_method("individual_digits", method!(RbPreTokenizer::digits_individual_digits, 0))?;
-    class.define_method("individual_digits=", method!(RbPreTokenizer::digits_set_individual_digits, 1))?;
+    class.define_method(
+        "individual_digits",
+        method!(RbPreTokenizer::digits_individual_digits, 0),
+    )?;
+    class.define_method(
+        "individual_digits=",
+        method!(RbPreTokenizer::digits_set_individual_digits, 1),
+    )?;
 
     let class = module.define_class("Metaspace", pre_tokenizer)?;
     class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
-    class.define_method("prepend_scheme", method!(RbPreTokenizer::metaspace_prepend_scheme, 0))?;
-    class.define_method("prepend_scheme=", method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1))?;
-    class.define_method("replacement", method!(RbPreTokenizer::metaspace_replacement, 0))?;
-    class.define_method("replacement=", method!(RbPreTokenizer::metaspace_set_replacement, 1))?;
+    class.define_method(
+        "prepend_scheme",
+        method!(RbPreTokenizer::metaspace_prepend_scheme, 0),
+    )?;
+    class.define_method(
+        "prepend_scheme=",
+        method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1),
+    )?;
+    class.define_method(
+        "replacement",
+        method!(RbPreTokenizer::metaspace_replacement, 0),
+    )?;
+    class.define_method(
+        "replacement=",
+        method!(RbPreTokenizer::metaspace_set_replacement, 1),
+    )?;
     class.define_method("split", method!(RbPreTokenizer::metaspace_split, 0))?;
     class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;
 
@@ -1,8 +1,8 @@
 use std::sync::Arc;
 
 use magnus::{
-    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
-    Ruby, TryConvert, TypedData, Value,
+    data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
+    RClass, RModule, Ruby, TryConvert, TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::processors::bert::BertProcessing;
@@ -12,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
 use tk::processors::PostProcessorWrapper;
 use tk::{Encoding, PostProcessor};
 
-use super::{PROCESSORS, RbResult};
+use super::{RbResult, PROCESSORS};
 
 #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
 pub struct RbPostProcessor {
@@ -106,7 +106,6 @@ impl RbByteLevel {
         }
         RbPostProcessor::new(Arc::new(byte_level.into()))
     }
-
 }
 
 pub struct RbRobertaProcessing {}
@@ -117,7 +116,7 @@ impl RbRobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) -> RbPostProcessor { 
+    ) -> RbPostProcessor {
         let proc = RobertaProcessing::new(sep, cls)
             .trim_offsets(trim_offsets)
             .add_prefix_space(add_prefix_space);
@@ -153,7 +152,10 @@ impl RbTemplateProcessing {
 unsafe impl TypedData for RbPostProcessor {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("PostProcessor")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -161,13 +163,17 @@ unsafe impl TypedData for RbPostProcessor {
     }
 
     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
         &DATA_TYPE
     }
 
     fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("BertProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -177,12 +183,18 @@ unsafe impl TypedData for RbPostProcessor {
             class
         });
         static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("RobertaProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
         static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
+            let class: RClass = ruby
+                .get_inner(&PROCESSORS)
+                .const_get("TemplateProcessing")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -6,8 +6,8 @@ use std::str::FromStr;
 use magnus::prelude::*;
 use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
 use tk::tokenizer::{
-    Model, PaddingDirection, PaddingParams, PaddingStrategy,
-    TruncationDirection, TruncationParams, TruncationStrategy, TokenizerImpl
+    Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
+    TruncationParams, TruncationStrategy,
 };
 use tk::AddedToken;
 
@@ -22,9 +22,10 @@ use super::processors::RbPostProcessor;
 use super::trainers::RbTrainer;
 use super::{RbError, RbResult};
 
+#[magnus::wrap(class = "Tokenizers::AddedToken")]
 pub struct RbAddedToken {
     pub content: String,
-    pub is_special_token: bool,
+    pub special: bool,
     pub single_word: Option<bool>,
     pub lstrip: Option<bool>,
     pub rstrip: Option<bool>,
@@ -32,10 +33,10 @@ pub struct RbAddedToken {
 }
 
 impl RbAddedToken {
-    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+    pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
         Self {
             content: content.into(),
-            is_special_token: is_special_token.unwrap_or(false),
+            special: special.unwrap_or(false),
             single_word: None,
             lstrip: None,
             rstrip: None,
@@ -44,7 +45,7 @@ impl RbAddedToken {
     }
 
     pub fn get_token(&self) -> tk::tokenizer::AddedToken {
-        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+        let mut token = tk::AddedToken::from(&self.content, self.special);
 
         if let Some(sw) = self.single_word {
             token = token.single_word(sw);
@@ -71,11 +72,73 @@ impl From<tk::AddedToken> for RbAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-            is_special_token: !token.normalized,
+            special: !token.normalized,
         }
     }
 }
 
+impl RbAddedToken {
+    pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
+        let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
+
+        let value: Value = kwargs.delete(Symbol::new("single_word"))?;
+        if !value.is_nil() {
+            token.single_word = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("lstrip"))?;
+        if !value.is_nil() {
+            token.lstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("rstrip"))?;
+        if !value.is_nil() {
+            token.rstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("normalized"))?;
+        if !value.is_nil() {
+            token.normalized = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("special"))?;
+        if !value.is_nil() {
+            token.special = TryConvert::try_convert(value)?;
+        }
+
+        if !kwargs.is_empty() {
+            // TODO improve message
+            return Err(Error::new(exception::arg_error(), "unknown keyword"));
+        }
+
+        Ok(token)
+    }
+
+    pub fn get_content(&self) -> String {
+        self.content.to_string()
+    }
+
+    pub fn get_rstrip(&self) -> bool {
+        self.get_token().rstrip
+    }
+
+    pub fn get_lstrip(&self) -> bool {
+        self.get_token().lstrip
+    }
+
+    pub fn get_single_word(&self) -> bool {
+        self.get_token().single_word
+    }
+
+    pub fn get_normalized(&self) -> bool {
+        self.get_token().normalized
+    }
+
+    pub fn get_special(&self) -> bool {
+        self.get_token().special
+    }
+}
+
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 
 impl<'s> TryConvert for TextInputSequence<'s> {
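
A note on the kwargs handling introduced in the hunk above: magnus does not destructure Ruby keyword arguments for you, so the new RbAddedToken::new drains each recognized key out of the RHash with delete and then treats anything left in the hash as an unknown keyword. A minimal standalone sketch of the same pattern, for reference; the Options struct and the :color key are hypothetical, invented only to illustrate:

    use magnus::{exception, prelude::*, Error, RHash, Symbol, TryConvert, Value};

    // Hypothetical options struct, used only to demonstrate the pattern.
    struct Options {
        color: Option<String>,
    }

    fn options_from_kwargs(kwargs: RHash) -> Result<Options, Error> {
        let mut options = Options { color: None };

        // `delete` removes the key from the hash and returns nil when absent.
        let value: Value = kwargs.delete(Symbol::new("color"))?;
        if !value.is_nil() {
            options.color = TryConvert::try_convert(value)?;
        }

        // Anything still in the hash was not a recognized keyword.
        if !kwargs.is_empty() {
            return Err(Error::new(exception::arg_error(), "unknown keyword"));
        }

        Ok(options)
    }

Destructively consuming the hash makes the final is_empty check a cheap way to reject typos in keyword names.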
@@ -221,7 +284,10 @@ impl RbTokenizer {
     }
 
     pub fn to_str(&self, pretty: bool) -> RbResult<String> {
-        self.tokenizer.borrow().to_string(pretty).map_err(RbError::from)
+        self.tokenizer
+            .borrow()
+            .to_string(pretty)
+            .map_err(RbError::from)
     }
 
     pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
@@ -320,7 +386,11 @@ impl RbTokenizer {
             .map_err(RbError::from)
     }
 
-    pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
+    pub fn decode_batch(
+        &self,
+        sequences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> RbResult<Vec<String>> {
         let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         self.tokenizer
             .borrow()
@@ -392,7 +462,12 @@ impl RbTokenizer {
             params.direction = match dir_str.as_str() {
                 "left" => PaddingDirection::Left,
                 "right" => PaddingDirection::Right,
-                _ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
+                _ => {
+                    return Err(Error::new(
+                        exception::arg_error(),
+                        "The direction value must be 'left' or 'right'",
+                    ))
+                }
             }
         }
 
@@ -438,24 +513,27 @@ impl RbTokenizer {
     }
 
     pub fn padding(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer.borrow().get_padding().map_or(Ok(None), |params| {
-            let ret_hash = RHash::new();
-
-            ret_hash.aset(
-                "length",
-                match params.strategy {
-                    tk::PaddingStrategy::BatchLongest => None,
-                    tk::PaddingStrategy::Fixed(size) => Some(size),
-                },
-            )?;
-            ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
-            ret_hash.aset("pad_id", params.pad_id)?;
-            ret_hash.aset("pad_token", &*params.pad_token)?;
-            ret_hash.aset("pad_type_id", params.pad_type_id)?;
-            ret_hash.aset("direction", params.direction.as_ref())?;
-
-            Ok(Some(ret_hash))
-        })
+        self.tokenizer
+            .borrow()
+            .get_padding()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
+
+                ret_hash.aset(
+                    "length",
+                    match params.strategy {
+                        tk::PaddingStrategy::BatchLongest => None,
+                        tk::PaddingStrategy::Fixed(size) => Some(size),
+                    },
+                )?;
+                ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
+                ret_hash.aset("pad_id", params.pad_id)?;
+                ret_hash.aset("pad_token", &*params.pad_token)?;
+                ret_hash.aset("pad_type_id", params.pad_type_id)?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
+
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
@@ -476,7 +554,10 @@ impl RbTokenizer {
                 "longest_first" => TruncationStrategy::LongestFirst,
                 "only_first" => TruncationStrategy::OnlyFirst,
                 "only_second" => TruncationStrategy::OnlySecond,
-                _ => return Err(Error::new(exception::arg_error(), "The strategy value must be 'longest_first', 'only_first', or 'only_second'")),
+                _ => return Err(Error::new(
+                    exception::arg_error(),
+                    "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
+                )),
             }
         }
 
@@ -486,7 +567,12 @@ impl RbTokenizer {
             params.direction = match dir_str.as_str() {
                 "left" => TruncationDirection::Left,
                 "right" => TruncationDirection::Right,
-                _ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
+                _ => {
+                    return Err(Error::new(
+                        exception::arg_error(),
+                        "The direction value must be 'left' or 'right'",
+                    ))
+                }
             }
         }
 
@@ -496,7 +582,10 @@ impl RbTokenizer {
         }
 
         if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
-            return Err(Error::new(exception::arg_error(), error_message.to_string()));
+            return Err(Error::new(
+                exception::arg_error(),
+                error_message.to_string(),
+            ));
         }
 
         Ok(())
@@ -510,16 +599,19 @@ impl RbTokenizer {
     }
 
     pub fn truncation(&self) -> RbResult<Option<RHash>> {
-        self.tokenizer.borrow().get_truncation().map_or(Ok(None), |params| {
-            let ret_hash = RHash::new();
+        self.tokenizer
+            .borrow()
+            .get_truncation()
+            .map_or(Ok(None), |params| {
+                let ret_hash = RHash::new();
 
-            ret_hash.aset("max_length", params.max_length)?;
-            ret_hash.aset("stride", params.stride)?;
-            ret_hash.aset("strategy", params.strategy.as_ref())?;
-            ret_hash.aset("direction", params.direction.as_ref())?;
+                ret_hash.aset("max_length", params.max_length)?;
+                ret_hash.aset("stride", params.stride)?;
+                ret_hash.aset("strategy", params.strategy.as_ref())?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
 
-            Ok(Some(ret_hash))
-        })
+                Ok(Some(ret_hash))
+            })
     }
 
     pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
@@ -536,4 +628,14 @@ impl RbTokenizer {
     pub fn vocab_size(&self, with_added_tokens: bool) -> usize {
         self.tokenizer.borrow().get_vocab_size(with_added_tokens)
     }
+
+    pub fn get_added_tokens_decoder(&self) -> RbResult<RHash> {
+        let sorted_map = RHash::new();
+
+        for (key, value) in self.tokenizer.borrow().get_added_tokens_decoder() {
+            sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
+        }
+
+        Ok(sorted_map)
+    }
 }
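
The functional additions in this release sit in this last file: RbAddedToken is now exposed to Ruby via #[magnus::wrap(class = "Tokenizers::AddedToken")], its is_special_token field is renamed to special to match the field name in the upstream tokenizers crate, and get_added_tokens_decoder surfaces the tokenizer's id-to-added-token map as a Ruby hash of id => AddedToken. For orientation, a sketch of what the wrapped call does against the upstream crate directly, assuming its public AddedToken fields; iteration order of the returned map is not guaranteed:

    use tokenizers::Tokenizer;

    fn print_added_tokens(tokenizer: &Tokenizer) {
        // get_added_tokens_decoder returns an owned id -> AddedToken map
        // covering every token added on top of the base vocabulary.
        for (id, token) in tokenizer.get_added_tokens_decoder() {
            println!("{} => {} (special: {})", id, token.content, token.special);
        }
    }

The binding copies each entry into a fresh RHash so Ruby callers get plain hash semantics rather than a view into the tokenizer's internal state.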