tokenizers 0.6.3 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,8 +20,8 @@ use tk::{Model, Token};
20
20
  use super::{RbError, RbResult, MODELS};
21
21
 
22
22
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
23
+ #[serde(transparent)]
23
24
  pub struct RbModel {
24
- #[serde(flatten)]
25
25
  pub model: Arc<RwLock<ModelWrapper>>,
26
26
  }
27
27
 
@@ -158,7 +158,7 @@ macro_rules! setter {
158
158
  }
159
159
 
160
160
  impl RbModel {
161
- pub fn bpe_dropout(&self) -> Option<f32> {
161
+ pub fn bpe_get_dropout(&self) -> Option<f32> {
162
162
  getter!(self, BPE, dropout)
163
163
  }
164
164
 
@@ -166,7 +166,7 @@ impl RbModel {
166
166
  setter!(self, BPE, dropout, dropout);
167
167
  }
168
168
 
169
- pub fn bpe_unk_token(&self) -> Option<String> {
169
+ pub fn bpe_get_unk_token(&self) -> Option<String> {
170
170
  getter!(self, BPE, unk_token.clone())
171
171
  }
172
172
 
@@ -174,7 +174,7 @@ impl RbModel {
174
174
  setter!(self, BPE, unk_token, unk_token);
175
175
  }
176
176
 
177
- pub fn bpe_fuse_unk(&self) -> bool {
177
+ pub fn bpe_get_fuse_unk(&self) -> bool {
178
178
  getter!(self, BPE, fuse_unk)
179
179
  }
180
180
 
@@ -182,7 +182,7 @@ impl RbModel {
182
182
  setter!(self, BPE, fuse_unk, fuse_unk);
183
183
  }
184
184
 
185
- pub fn bpe_byte_fallback(&self) -> bool {
185
+ pub fn bpe_get_byte_fallback(&self) -> bool {
186
186
  getter!(self, BPE, byte_fallback)
187
187
  }
188
188
 
@@ -190,7 +190,7 @@ impl RbModel {
190
190
  setter!(self, BPE, byte_fallback, byte_fallback);
191
191
  }
192
192
 
193
- pub fn bpe_continuing_subword_prefix(&self) -> Option<String> {
193
+ pub fn bpe_get_continuing_subword_prefix(&self) -> Option<String> {
194
194
  getter!(self, BPE, continuing_subword_prefix.clone())
195
195
  }
196
196
 
@@ -203,7 +203,7 @@ impl RbModel {
203
203
  );
204
204
  }
205
205
 
206
- pub fn bpe_end_of_word_suffix(&self) -> Option<String> {
206
+ pub fn bpe_get_end_of_word_suffix(&self) -> Option<String> {
207
207
  getter!(self, BPE, end_of_word_suffix.clone())
208
208
  }
209
209
 
@@ -211,7 +211,7 @@ impl RbModel {
211
211
  setter!(self, BPE, end_of_word_suffix, end_of_word_suffix);
212
212
  }
213
213
 
214
- pub fn word_level_unk_token(&self) -> String {
214
+ pub fn word_level_get_unk_token(&self) -> String {
215
215
  getter!(self, WordLevel, unk_token.clone())
216
216
  }
217
217
 
@@ -219,7 +219,7 @@ impl RbModel {
219
219
  setter!(self, WordLevel, unk_token, unk_token);
220
220
  }
221
221
 
222
- pub fn word_piece_unk_token(&self) -> String {
222
+ pub fn word_piece_get_unk_token(&self) -> String {
223
223
  getter!(self, WordPiece, unk_token.clone())
224
224
  }
225
225
 
@@ -227,7 +227,7 @@ impl RbModel {
227
227
  setter!(self, WordPiece, unk_token, unk_token);
228
228
  }
229
229
 
230
- pub fn word_piece_continuing_subword_prefix(&self) -> String {
230
+ pub fn word_piece_get_continuing_subword_prefix(&self) -> String {
231
231
  getter!(self, WordPiece, continuing_subword_prefix.clone())
232
232
  }
233
233
 
@@ -240,7 +240,7 @@ impl RbModel {
240
240
  );
241
241
  }
242
242
 
243
- pub fn word_piece_max_input_chars_per_word(&self) -> usize {
243
+ pub fn word_piece_get_max_input_chars_per_word(&self) -> usize {
244
244
  getter!(self, WordPiece, max_input_chars_per_word.clone())
245
245
  }
246
246
 
@@ -405,13 +405,13 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
405
405
  let class = module.define_class("BPE", model)?;
406
406
  class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
407
407
  class.define_singleton_method("_from_file", function!(RbBPE::from_file, 3))?;
408
- class.define_method("dropout", method!(RbModel::bpe_dropout, 0))?;
408
+ class.define_method("dropout", method!(RbModel::bpe_get_dropout, 0))?;
409
409
  class.define_method("dropout=", method!(RbModel::bpe_set_dropout, 1))?;
410
- class.define_method("unk_token", method!(RbModel::bpe_unk_token, 0))?;
410
+ class.define_method("unk_token", method!(RbModel::bpe_get_unk_token, 0))?;
411
411
  class.define_method("unk_token=", method!(RbModel::bpe_set_unk_token, 1))?;
412
412
  class.define_method(
413
413
  "continuing_subword_prefix",
414
- method!(RbModel::bpe_continuing_subword_prefix, 0),
414
+ method!(RbModel::bpe_get_continuing_subword_prefix, 0),
415
415
  )?;
416
416
  class.define_method(
417
417
  "continuing_subword_prefix=",
@@ -419,15 +419,15 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
419
419
  )?;
420
420
  class.define_method(
421
421
  "end_of_word_suffix",
422
- method!(RbModel::bpe_end_of_word_suffix, 0),
422
+ method!(RbModel::bpe_get_end_of_word_suffix, 0),
423
423
  )?;
424
424
  class.define_method(
425
425
  "end_of_word_suffix=",
426
426
  method!(RbModel::bpe_set_end_of_word_suffix, 1),
427
427
  )?;
428
- class.define_method("fuse_unk", method!(RbModel::bpe_fuse_unk, 0))?;
428
+ class.define_method("fuse_unk", method!(RbModel::bpe_get_fuse_unk, 0))?;
429
429
  class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
430
- class.define_method("byte_fallback", method!(RbModel::bpe_byte_fallback, 0))?;
430
+ class.define_method("byte_fallback", method!(RbModel::bpe_get_byte_fallback, 0))?;
431
431
  class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
432
432
 
433
433
  let class = module.define_class("Unigram", model)?;
@@ -437,17 +437,17 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
437
437
  class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
438
438
  class.define_singleton_method("_from_file", function!(RbWordLevel::from_file, 2))?;
439
439
  class.define_singleton_method("read_file", function!(RbWordLevel::read_file, 1))?;
440
- class.define_method("unk_token", method!(RbModel::word_level_unk_token, 0))?;
440
+ class.define_method("unk_token", method!(RbModel::word_level_get_unk_token, 0))?;
441
441
  class.define_method("unk_token=", method!(RbModel::word_level_set_unk_token, 1))?;
442
442
 
443
443
  let class = module.define_class("WordPiece", model)?;
444
444
  class.define_singleton_method("_new", function!(RbWordPiece::new, 2))?;
445
445
  class.define_singleton_method("_from_file", function!(RbWordPiece::from_file, 2))?;
446
- class.define_method("unk_token", method!(RbModel::word_piece_unk_token, 0))?;
446
+ class.define_method("unk_token", method!(RbModel::word_piece_get_unk_token, 0))?;
447
447
  class.define_method("unk_token=", method!(RbModel::word_piece_set_unk_token, 1))?;
448
448
  class.define_method(
449
449
  "continuing_subword_prefix",
450
- method!(RbModel::word_piece_continuing_subword_prefix, 0),
450
+ method!(RbModel::word_piece_get_continuing_subword_prefix, 0),
451
451
  )?;
452
452
  class.define_method(
453
453
  "continuing_subword_prefix=",
@@ -455,7 +455,7 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
455
455
  )?;
456
456
  class.define_method(
457
457
  "max_input_chars_per_word",
458
- method!(RbModel::word_piece_max_input_chars_per_word, 0),
458
+ method!(RbModel::word_piece_get_max_input_chars_per_word, 0),
459
459
  )?;
460
460
  class.define_method(
461
461
  "max_input_chars_per_word=",
@@ -16,8 +16,8 @@ use super::utils::*;
16
16
  use super::{RbError, RbResult, NORMALIZERS};
17
17
 
18
18
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
19
+ #[serde(transparent)]
19
20
  pub struct RbNormalizer {
20
- #[serde(flatten)]
21
21
  pub(crate) normalizer: RbNormalizerTypeWrapper,
22
22
  }
23
23
 
@@ -69,7 +69,7 @@ macro_rules! setter {
69
69
  }
70
70
 
71
71
  impl RbNormalizer {
72
- fn bert_clean_text(&self) -> bool {
72
+ fn bert_get_clean_text(&self) -> bool {
73
73
  getter!(self, BertNormalizer, clean_text)
74
74
  }
75
75
 
@@ -77,7 +77,7 @@ impl RbNormalizer {
77
77
  setter!(self, BertNormalizer, clean_text, clean_text);
78
78
  }
79
79
 
80
- fn bert_handle_chinese_chars(&self) -> bool {
80
+ fn bert_get_handle_chinese_chars(&self) -> bool {
81
81
  getter!(self, BertNormalizer, handle_chinese_chars)
82
82
  }
83
83
 
@@ -90,7 +90,7 @@ impl RbNormalizer {
90
90
  );
91
91
  }
92
92
 
93
- fn bert_strip_accents(&self) -> Option<bool> {
93
+ fn bert_get_strip_accents(&self) -> Option<bool> {
94
94
  getter!(self, BertNormalizer, strip_accents)
95
95
  }
96
96
 
@@ -98,7 +98,7 @@ impl RbNormalizer {
98
98
  setter!(self, BertNormalizer, strip_accents, strip_accents);
99
99
  }
100
100
 
101
- fn bert_lowercase(&self) -> bool {
101
+ fn bert_get_lowercase(&self) -> bool {
102
102
  getter!(self, BertNormalizer, lowercase)
103
103
  }
104
104
 
@@ -106,7 +106,7 @@ impl RbNormalizer {
106
106
  setter!(self, BertNormalizer, lowercase, lowercase);
107
107
  }
108
108
 
109
- fn prepend_prepend(&self) -> String {
109
+ fn prepend_get_prepend(&self) -> String {
110
110
  getter!(self, Prepend, prepend)
111
111
  }
112
112
 
@@ -114,7 +114,7 @@ impl RbNormalizer {
114
114
  setter!(self, Prepend, prepend, prepend);
115
115
  }
116
116
 
117
- fn strip_left(&self) -> bool {
117
+ fn strip_get_left(&self) -> bool {
118
118
  getter!(self, StripNormalizer, strip_left)
119
119
  }
120
120
 
@@ -122,7 +122,7 @@ impl RbNormalizer {
122
122
  setter!(self, StripNormalizer, strip_left, left);
123
123
  }
124
124
 
125
- fn strip_right(&self) -> bool {
125
+ fn strip_get_right(&self) -> bool {
126
126
  getter!(self, StripNormalizer, strip_right)
127
127
  }
128
128
 
@@ -476,11 +476,11 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
476
476
 
477
477
  let class = module.define_class("BertNormalizer", normalizer)?;
478
478
  class.define_singleton_method("_new", function!(RbBertNormalizer::new, 4))?;
479
- class.define_method("clean_text", method!(RbNormalizer::bert_clean_text, 0))?;
479
+ class.define_method("clean_text", method!(RbNormalizer::bert_get_clean_text, 0))?;
480
480
  class.define_method("clean_text=", method!(RbNormalizer::bert_set_clean_text, 1))?;
481
481
  class.define_method(
482
482
  "handle_chinese_chars",
483
- method!(RbNormalizer::bert_handle_chinese_chars, 0),
483
+ method!(RbNormalizer::bert_get_handle_chinese_chars, 0),
484
484
  )?;
485
485
  class.define_method(
486
486
  "handle_chinese_chars=",
@@ -488,13 +488,13 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
488
488
  )?;
489
489
  class.define_method(
490
490
  "strip_accents",
491
- method!(RbNormalizer::bert_strip_accents, 0),
491
+ method!(RbNormalizer::bert_get_strip_accents, 0),
492
492
  )?;
493
493
  class.define_method(
494
494
  "strip_accents=",
495
495
  method!(RbNormalizer::bert_set_strip_accents, 1),
496
496
  )?;
497
- class.define_method("lowercase", method!(RbNormalizer::bert_lowercase, 0))?;
497
+ class.define_method("lowercase", method!(RbNormalizer::bert_get_lowercase, 0))?;
498
498
  class.define_method("lowercase=", method!(RbNormalizer::bert_set_lowercase, 1))?;
499
499
 
500
500
  let class = module.define_class("Lowercase", normalizer)?;
@@ -523,14 +523,14 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
523
523
 
524
524
  let class = module.define_class("Prepend", normalizer)?;
525
525
  class.define_singleton_method("_new", function!(RbPrepend::new, 1))?;
526
- class.define_method("prepend", method!(RbNormalizer::prepend_prepend, 0))?;
526
+ class.define_method("prepend", method!(RbNormalizer::prepend_get_prepend, 0))?;
527
527
  class.define_method("prepend=", method!(RbNormalizer::prepend_set_prepend, 1))?;
528
528
 
529
529
  let class = module.define_class("Strip", normalizer)?;
530
530
  class.define_singleton_method("_new", function!(RbStrip::new, 2))?;
531
- class.define_method("left", method!(RbNormalizer::strip_left, 0))?;
531
+ class.define_method("left", method!(RbNormalizer::strip_get_left, 0))?;
532
532
  class.define_method("left=", method!(RbNormalizer::strip_set_left, 1))?;
533
- class.define_method("right", method!(RbNormalizer::strip_right, 0))?;
533
+ class.define_method("right", method!(RbNormalizer::strip_get_right, 0))?;
534
534
  class.define_method("right=", method!(RbNormalizer::strip_set_right, 1))?;
535
535
 
536
536
  let class = module.define_class("StripAccents", normalizer)?;
@@ -25,8 +25,8 @@ use super::utils::*;
25
25
  use super::{RbError, RbResult, PRE_TOKENIZERS};
26
26
 
27
27
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
28
+ #[serde(transparent)]
28
29
  pub struct RbPreTokenizer {
29
- #[serde(flatten)]
30
30
  pub(crate) pretok: RbPreTokenizerTypeWrapper,
31
31
  }
32
32
 
@@ -88,7 +88,7 @@ impl RbPreTokenizer {
88
88
  RbPreTokenizer { pretok }
89
89
  }
90
90
 
91
- fn byte_level_add_prefix_space(&self) -> bool {
91
+ fn byte_level_get_add_prefix_space(&self) -> bool {
92
92
  getter!(self, ByteLevel, add_prefix_space)
93
93
  }
94
94
 
@@ -96,7 +96,7 @@ impl RbPreTokenizer {
96
96
  setter!(self, ByteLevel, add_prefix_space, add_prefix_space);
97
97
  }
98
98
 
99
- fn byte_level_use_regex(&self) -> bool {
99
+ fn byte_level_get_use_regex(&self) -> bool {
100
100
  getter!(self, ByteLevel, use_regex)
101
101
  }
102
102
 
@@ -104,7 +104,7 @@ impl RbPreTokenizer {
104
104
  setter!(self, ByteLevel, use_regex, use_regex);
105
105
  }
106
106
 
107
- fn char_delimiter_split_delimiter(&self) -> String {
107
+ fn char_delimiter_split_get_delimiter(&self) -> String {
108
108
  getter!(self, Delimiter, delimiter.to_string())
109
109
  }
110
110
 
@@ -112,7 +112,7 @@ impl RbPreTokenizer {
112
112
  setter!(self, Delimiter, delimiter, delimiter);
113
113
  }
114
114
 
115
- fn digits_individual_digits(&self) -> bool {
115
+ fn digits_get_individual_digits(&self) -> bool {
116
116
  getter!(self, Digits, individual_digits)
117
117
  }
118
118
 
@@ -120,7 +120,7 @@ impl RbPreTokenizer {
120
120
  setter!(self, Digits, individual_digits, individual_digits);
121
121
  }
122
122
 
123
- fn metaspace_replacement(&self) -> String {
123
+ fn metaspace_get_replacement(&self) -> String {
124
124
  getter!(self, Metaspace, get_replacement().to_string())
125
125
  }
126
126
 
@@ -128,7 +128,7 @@ impl RbPreTokenizer {
128
128
  setter!(self, Metaspace, @set_replacement, replacement);
129
129
  }
130
130
 
131
- fn metaspace_split(&self) -> bool {
131
+ fn metaspace_get_split(&self) -> bool {
132
132
  getter!(self, Metaspace, get_split())
133
133
  }
134
134
 
@@ -136,7 +136,7 @@ impl RbPreTokenizer {
136
136
  setter!(self, Metaspace, @set_split, split);
137
137
  }
138
138
 
139
- fn metaspace_prepend_scheme(&self) -> String {
139
+ fn metaspace_get_prepend_scheme(&self) -> String {
140
140
  // Assuming Metaspace has a method to get the prepend_scheme as a string
141
141
  let scheme: PrependScheme = getter!(self, Metaspace, get_prepend_scheme());
142
142
  match scheme {
@@ -528,7 +528,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
528
528
  class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
529
529
  class.define_method(
530
530
  "add_prefix_space",
531
- method!(RbPreTokenizer::byte_level_add_prefix_space, 0),
531
+ method!(RbPreTokenizer::byte_level_get_add_prefix_space, 0),
532
532
  )?;
533
533
  class.define_method(
534
534
  "add_prefix_space=",
@@ -536,7 +536,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
536
536
  )?;
537
537
  class.define_method(
538
538
  "use_regex",
539
- method!(RbPreTokenizer::byte_level_use_regex, 0),
539
+ method!(RbPreTokenizer::byte_level_get_use_regex, 0),
540
540
  )?;
541
541
  class.define_method(
542
542
  "use_regex=",
@@ -547,7 +547,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
547
547
  class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
548
548
  class.define_method(
549
549
  "delimiter",
550
- method!(RbPreTokenizer::char_delimiter_split_delimiter, 0),
550
+ method!(RbPreTokenizer::char_delimiter_split_get_delimiter, 0),
551
551
  )?;
552
552
  class.define_method(
553
553
  "delimiter=",
@@ -558,7 +558,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
558
558
  class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
559
559
  class.define_method(
560
560
  "individual_digits",
561
- method!(RbPreTokenizer::digits_individual_digits, 0),
561
+ method!(RbPreTokenizer::digits_get_individual_digits, 0),
562
562
  )?;
563
563
  class.define_method(
564
564
  "individual_digits=",
@@ -569,7 +569,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
569
569
  class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
570
570
  class.define_method(
571
571
  "prepend_scheme",
572
- method!(RbPreTokenizer::metaspace_prepend_scheme, 0),
572
+ method!(RbPreTokenizer::metaspace_get_prepend_scheme, 0),
573
573
  )?;
574
574
  class.define_method(
575
575
  "prepend_scheme=",
@@ -577,13 +577,13 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
577
577
  )?;
578
578
  class.define_method(
579
579
  "replacement",
580
- method!(RbPreTokenizer::metaspace_replacement, 0),
580
+ method!(RbPreTokenizer::metaspace_get_replacement, 0),
581
581
  )?;
582
582
  class.define_method(
583
583
  "replacement=",
584
584
  method!(RbPreTokenizer::metaspace_set_replacement, 1),
585
585
  )?;
586
- class.define_method("split", method!(RbPreTokenizer::metaspace_split, 0))?;
586
+ class.define_method("split", method!(RbPreTokenizer::metaspace_get_split, 0))?;
587
587
  class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;
588
588
 
589
589
  let class = module.define_class("Punctuation", pre_tokenizer)?;
@@ -1,10 +1,12 @@
1
1
  use std::sync::Arc;
2
+ use std::sync::RwLock;
2
3
 
3
4
  use magnus::{
4
5
  data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
5
- RClass, RModule, Ruby, TryConvert, TypedData, Value,
6
+ RArray, RClass, RModule, Ruby, TryConvert, TypedData, Value,
6
7
  };
7
- use serde::{Deserialize, Serialize};
8
+ use serde::ser::SerializeStruct;
9
+ use serde::{Deserialize, Deserializer, Serialize, Serializer};
8
10
  use tk::processors::bert::BertProcessing;
9
11
  use tk::processors::byte_level::ByteLevel;
10
12
  use tk::processors::roberta::RobertaProcessing;
@@ -15,17 +17,28 @@ use tk::{Encoding, PostProcessor};
15
17
  use super::{RbResult, PROCESSORS};
16
18
 
17
19
  #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
20
+ #[serde(transparent)]
18
21
  pub struct RbPostProcessor {
19
- #[serde(flatten)]
20
- pub processor: Arc<PostProcessorWrapper>,
22
+ pub processor: RbPostProcessorTypeWrapper,
21
23
  }
22
24
 
23
25
  impl RbPostProcessor {
24
- pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
26
+ pub fn new(processor: RbPostProcessorTypeWrapper) -> Self {
25
27
  RbPostProcessor { processor }
26
28
  }
27
29
  }
28
30
 
31
+ impl<I> From<I> for RbPostProcessor
32
+ where
33
+ I: Into<PostProcessorWrapper>,
34
+ {
35
+ fn from(processor: I) -> Self {
36
+ RbPostProcessor {
37
+ processor: processor.into().into(),
38
+ }
39
+ }
40
+ }
41
+
29
42
  impl PostProcessor for RbPostProcessor {
30
43
  fn added_tokens(&self, is_pair: bool) -> usize {
31
44
  self.processor.added_tokens(is_pair)
@@ -41,6 +54,92 @@ impl PostProcessor for RbPostProcessor {
41
54
  }
42
55
  }
43
56
 
57
+ #[derive(Clone)]
58
+ pub(crate) enum RbPostProcessorTypeWrapper {
59
+ Sequence(Vec<Arc<RwLock<PostProcessorWrapper>>>),
60
+ Single(Arc<RwLock<PostProcessorWrapper>>),
61
+ }
62
+
63
+ impl PostProcessor for RbPostProcessorTypeWrapper {
64
+ fn added_tokens(&self, is_pair: bool) -> usize {
65
+ match self {
66
+ RbPostProcessorTypeWrapper::Single(inner) => inner
67
+ .read()
68
+ .expect("RwLock synchronisation primitive is poisoned, cannot get subtype of RbPostProcessor")
69
+ .added_tokens(is_pair),
70
+ RbPostProcessorTypeWrapper::Sequence(inner) => inner.iter().map(|p| {
71
+ p.read()
72
+ .expect("RwLock synchronisation primitive is poisoned, cannot get subtype of RbPostProcessor")
73
+ .added_tokens(is_pair)
74
+ }).sum::<usize>(),
75
+ }
76
+ }
77
+
78
+ fn process_encodings(
79
+ &self,
80
+ mut encodings: Vec<Encoding>,
81
+ add_special_tokens: bool,
82
+ ) -> tk::Result<Vec<Encoding>> {
83
+ match self {
84
+ RbPostProcessorTypeWrapper::Single(inner) => inner
85
+ .read()
86
+ .expect("RwLock synchronisation primitive is poisoned, cannot get subtype of RbPreTokenizer")
87
+ .process_encodings(encodings, add_special_tokens),
88
+ RbPostProcessorTypeWrapper::Sequence(inner) => {
89
+ for processor in inner.iter() {
90
+ encodings = processor
91
+ .read()
92
+ .expect("RwLock synchronisation primitive is poisoned, cannot get subtype of RbPreTokenizer")
93
+ .process_encodings(encodings, add_special_tokens)?;
94
+ }
95
+ Ok(encodings)
96
+ },
97
+ }
98
+ }
99
+ }
100
+
101
+ impl<'de> Deserialize<'de> for RbPostProcessorTypeWrapper {
102
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
103
+ where
104
+ D: Deserializer<'de>,
105
+ {
106
+ let wrapper = PostProcessorWrapper::deserialize(deserializer)?;
107
+ Ok(wrapper.into())
108
+ }
109
+ }
110
+
111
+ impl Serialize for RbPostProcessorTypeWrapper {
112
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
113
+ where
114
+ S: Serializer,
115
+ {
116
+ match self {
117
+ RbPostProcessorTypeWrapper::Sequence(seq) => {
118
+ let mut ser = serializer.serialize_struct("Sequence", 2)?;
119
+ ser.serialize_field("type", "Sequence")?;
120
+ ser.serialize_field("processors", seq)?;
121
+ ser.end()
122
+ }
123
+ RbPostProcessorTypeWrapper::Single(inner) => inner.serialize(serializer),
124
+ }
125
+ }
126
+ }
127
+
128
+ impl<I> From<I> for RbPostProcessorTypeWrapper
129
+ where
130
+ I: Into<PostProcessorWrapper>,
131
+ {
132
+ fn from(processor: I) -> Self {
133
+ let processor = processor.into();
134
+ match processor {
135
+ PostProcessorWrapper::Sequence(seq) => RbPostProcessorTypeWrapper::Sequence(
136
+ seq.into_iter().map(|p| Arc::new(RwLock::new(p))).collect(),
137
+ ),
138
+ _ => RbPostProcessorTypeWrapper::Single(Arc::new(RwLock::new(processor.clone()))),
139
+ }
140
+ }
141
+ }
142
+
44
143
  #[derive(Clone, Debug)]
45
144
  pub struct RbSpecialToken(SpecialToken);
46
145
 
@@ -91,7 +190,7 @@ pub struct RbBertProcessing {}
91
190
 
92
191
  impl RbBertProcessing {
93
192
  pub fn new(sep: (String, u32), cls: (String, u32)) -> RbPostProcessor {
94
- RbPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into()))
193
+ BertProcessing::new(sep, cls).into()
95
194
  }
96
195
  }
97
196
 
@@ -104,7 +203,7 @@ impl RbByteLevel {
104
203
  if let Some(to) = trim_offsets {
105
204
  byte_level = byte_level.trim_offsets(to);
106
205
  }
107
- RbPostProcessor::new(Arc::new(byte_level.into()))
206
+ byte_level.into()
108
207
  }
109
208
  }
110
209
 
@@ -120,7 +219,7 @@ impl RbRobertaProcessing {
120
219
  let proc = RobertaProcessing::new(sep, cls)
121
220
  .trim_offsets(trim_offsets)
122
221
  .add_prefix_space(add_prefix_space);
123
- RbPostProcessor::new(Arc::new(proc.into()))
222
+ proc.into()
124
223
  }
125
224
  }
126
225
 
@@ -145,7 +244,27 @@ impl RbTemplateProcessing {
145
244
  }
146
245
  let processor = builder.build().unwrap(); //.map_err(RbError::from)?;
147
246
 
148
- Ok(RbPostProcessor::new(Arc::new(processor.into())))
247
+ Ok(processor.into())
248
+ }
249
+ }
250
+
251
+ pub struct RbSequence {}
252
+
253
+ impl RbSequence {
254
+ fn new(processors_rb: RArray) -> RbResult<RbPostProcessor> {
255
+ let mut processors = Vec::with_capacity(processors_rb.len());
256
+ for n in processors_rb {
257
+ let processor = <&RbPostProcessor>::try_convert(n)?;
258
+ match &processor.processor {
259
+ RbPostProcessorTypeWrapper::Sequence(inner) => {
260
+ processors.extend(inner.iter().cloned())
261
+ }
262
+ RbPostProcessorTypeWrapper::Single(inner) => processors.push(inner.clone()),
263
+ }
264
+ }
265
+ Ok(RbPostProcessor::new(RbPostProcessorTypeWrapper::Sequence(
266
+ processors,
267
+ )))
149
268
  }
150
269
  }
151
270
 
@@ -198,12 +317,20 @@ unsafe impl TypedData for RbPostProcessor {
198
317
  class.undef_default_alloc_func();
199
318
  class
200
319
  });
201
- match *value.processor {
202
- PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
203
- PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
204
- PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
205
- PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
206
- _ => todo!(),
320
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
321
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("Sequence").unwrap();
322
+ class.undef_default_alloc_func();
323
+ class
324
+ });
325
+ match &value.processor {
326
+ RbPostProcessorTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
327
+ PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
328
+ PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
329
+ PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
330
+ PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
331
+ _ => todo!(),
332
+ },
333
+ RbPostProcessorTypeWrapper::Sequence(_) => ruby.get_inner(&SEQUENCE),
207
334
  }
208
335
  }
209
336
  }
@@ -223,5 +350,8 @@ pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
223
350
  let class = module.define_class("TemplateProcessing", post_processor)?;
224
351
  class.define_singleton_method("_new", function!(RbTemplateProcessing::new, 3))?;
225
352
 
353
+ let class = module.define_class("Sequence", post_processor)?;
354
+ class.define_singleton_method("_new", function!(RbSequence::new, 1))?;
355
+
226
356
  Ok(())
227
357
  }
@@ -0,0 +1,51 @@
1
+ use std::ffi::c_void;
2
+ use std::ptr::null_mut;
3
+
4
+ use magnus::Ruby;
5
+ use rb_sys::rb_thread_call_without_gvl;
6
+
7
/// Extension for running Rust work without holding Ruby's Global VM Lock (GVL).
///
/// The closure and its result must both be `Send`, because the implementation hands
/// the work to Ruby's `rb_thread_call_without_gvl`. While the GVL is released the
/// closure should not touch Ruby objects or call into the Ruby C API.
pub trait GvlExt {
    /// Runs `func` with the GVL released and returns whatever `func` returns.
    fn detach<T: Send, F: Send + FnOnce() -> T>(&self, func: F) -> T;
}
13
+
14
+ impl GvlExt for Ruby {
15
+ fn detach<T, F>(&self, func: F) -> T
16
+ where
17
+ F: Send + FnOnce() -> T,
18
+ T: Send,
19
+ {
20
+ let mut data = CallbackData {
21
+ func: Some(func),
22
+ result: None,
23
+ };
24
+
25
+ unsafe {
26
+ rb_thread_call_without_gvl(
27
+ Some(call_without_gvl::<F, T>),
28
+ &mut data as *mut _ as *mut c_void,
29
+ None,
30
+ null_mut(),
31
+ );
32
+ }
33
+
34
+ data.result.unwrap()
35
+ }
36
+ }
37
+
38
+ struct CallbackData<F, T> {
39
+ func: Option<F>,
40
+ result: Option<T>,
41
+ }
42
+
43
+ extern "C" fn call_without_gvl<F, T>(data: *mut c_void) -> *mut c_void
44
+ where
45
+ F: FnOnce() -> T,
46
+ {
47
+ let data = unsafe { &mut *(data as *mut CallbackData<F, T>) };
48
+ let func = data.func.take().unwrap();
49
+ data.result = Some(func());
50
+ null_mut()
51
+ }