tokenizers 0.6.3 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/Cargo.lock +21 -22
- data/ext/tokenizers/Cargo.toml +3 -2
- data/ext/tokenizers/src/decoders.rs +31 -28
- data/ext/tokenizers/src/encoding.rs +42 -11
- data/ext/tokenizers/src/error.rs +10 -5
- data/ext/tokenizers/src/lib.rs +4 -91
- data/ext/tokenizers/src/models.rs +21 -21
- data/ext/tokenizers/src/normalizers.rs +15 -15
- data/ext/tokenizers/src/pre_tokenizers.rs +15 -15
- data/ext/tokenizers/src/processors.rs +145 -15
- data/ext/tokenizers/src/ruby.rs +51 -0
- data/ext/tokenizers/src/tokenizer.rs +381 -244
- data/ext/tokenizers/src/trainers.rs +55 -49
- data/ext/tokenizers/src/utils/normalization.rs +2 -1
- data/ext/tokenizers/src/utils/regex.rs +2 -2
- data/lib/tokenizers/from_pretrained.rb +6 -2
- data/lib/tokenizers/processors/sequence.rb +9 -0
- data/lib/tokenizers/tokenizer.rb +4 -0
- data/lib/tokenizers/version.rb +1 -1
- metadata +5 -3
|
@@ -20,8 +20,8 @@ use tk::{Model, Token};
|
|
|
20
20
|
use super::{RbError, RbResult, MODELS};
|
|
21
21
|
|
|
22
22
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
|
23
|
+
#[serde(transparent)]
|
|
23
24
|
pub struct RbModel {
|
|
24
|
-
#[serde(flatten)]
|
|
25
25
|
pub model: Arc<RwLock<ModelWrapper>>,
|
|
26
26
|
}
|
|
27
27
|
|
|
@@ -158,7 +158,7 @@ macro_rules! setter {
|
|
|
158
158
|
}
|
|
159
159
|
|
|
160
160
|
impl RbModel {
|
|
161
|
-
pub fn
|
|
161
|
+
pub fn bpe_get_dropout(&self) -> Option<f32> {
|
|
162
162
|
getter!(self, BPE, dropout)
|
|
163
163
|
}
|
|
164
164
|
|
|
@@ -166,7 +166,7 @@ impl RbModel {
|
|
|
166
166
|
setter!(self, BPE, dropout, dropout);
|
|
167
167
|
}
|
|
168
168
|
|
|
169
|
-
pub fn
|
|
169
|
+
pub fn bpe_get_unk_token(&self) -> Option<String> {
|
|
170
170
|
getter!(self, BPE, unk_token.clone())
|
|
171
171
|
}
|
|
172
172
|
|
|
@@ -174,7 +174,7 @@ impl RbModel {
|
|
|
174
174
|
setter!(self, BPE, unk_token, unk_token);
|
|
175
175
|
}
|
|
176
176
|
|
|
177
|
-
pub fn
|
|
177
|
+
pub fn bpe_get_fuse_unk(&self) -> bool {
|
|
178
178
|
getter!(self, BPE, fuse_unk)
|
|
179
179
|
}
|
|
180
180
|
|
|
@@ -182,7 +182,7 @@ impl RbModel {
|
|
|
182
182
|
setter!(self, BPE, fuse_unk, fuse_unk);
|
|
183
183
|
}
|
|
184
184
|
|
|
185
|
-
pub fn
|
|
185
|
+
pub fn bpe_get_byte_fallback(&self) -> bool {
|
|
186
186
|
getter!(self, BPE, byte_fallback)
|
|
187
187
|
}
|
|
188
188
|
|
|
@@ -190,7 +190,7 @@ impl RbModel {
|
|
|
190
190
|
setter!(self, BPE, byte_fallback, byte_fallback);
|
|
191
191
|
}
|
|
192
192
|
|
|
193
|
-
pub fn
|
|
193
|
+
pub fn bpe_get_continuing_subword_prefix(&self) -> Option<String> {
|
|
194
194
|
getter!(self, BPE, continuing_subword_prefix.clone())
|
|
195
195
|
}
|
|
196
196
|
|
|
@@ -203,7 +203,7 @@ impl RbModel {
|
|
|
203
203
|
);
|
|
204
204
|
}
|
|
205
205
|
|
|
206
|
-
pub fn
|
|
206
|
+
pub fn bpe_get_end_of_word_suffix(&self) -> Option<String> {
|
|
207
207
|
getter!(self, BPE, end_of_word_suffix.clone())
|
|
208
208
|
}
|
|
209
209
|
|
|
@@ -211,7 +211,7 @@ impl RbModel {
|
|
|
211
211
|
setter!(self, BPE, end_of_word_suffix, end_of_word_suffix);
|
|
212
212
|
}
|
|
213
213
|
|
|
214
|
-
pub fn
|
|
214
|
+
pub fn word_level_get_unk_token(&self) -> String {
|
|
215
215
|
getter!(self, WordLevel, unk_token.clone())
|
|
216
216
|
}
|
|
217
217
|
|
|
@@ -219,7 +219,7 @@ impl RbModel {
|
|
|
219
219
|
setter!(self, WordLevel, unk_token, unk_token);
|
|
220
220
|
}
|
|
221
221
|
|
|
222
|
-
pub fn
|
|
222
|
+
pub fn word_piece_get_unk_token(&self) -> String {
|
|
223
223
|
getter!(self, WordPiece, unk_token.clone())
|
|
224
224
|
}
|
|
225
225
|
|
|
@@ -227,7 +227,7 @@ impl RbModel {
|
|
|
227
227
|
setter!(self, WordPiece, unk_token, unk_token);
|
|
228
228
|
}
|
|
229
229
|
|
|
230
|
-
pub fn
|
|
230
|
+
pub fn word_piece_get_continuing_subword_prefix(&self) -> String {
|
|
231
231
|
getter!(self, WordPiece, continuing_subword_prefix.clone())
|
|
232
232
|
}
|
|
233
233
|
|
|
@@ -240,7 +240,7 @@ impl RbModel {
|
|
|
240
240
|
);
|
|
241
241
|
}
|
|
242
242
|
|
|
243
|
-
pub fn
|
|
243
|
+
pub fn word_piece_get_max_input_chars_per_word(&self) -> usize {
|
|
244
244
|
getter!(self, WordPiece, max_input_chars_per_word.clone())
|
|
245
245
|
}
|
|
246
246
|
|
|
@@ -405,13 +405,13 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
405
405
|
let class = module.define_class("BPE", model)?;
|
|
406
406
|
class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
|
|
407
407
|
class.define_singleton_method("_from_file", function!(RbBPE::from_file, 3))?;
|
|
408
|
-
class.define_method("dropout", method!(RbModel::
|
|
408
|
+
class.define_method("dropout", method!(RbModel::bpe_get_dropout, 0))?;
|
|
409
409
|
class.define_method("dropout=", method!(RbModel::bpe_set_dropout, 1))?;
|
|
410
|
-
class.define_method("unk_token", method!(RbModel::
|
|
410
|
+
class.define_method("unk_token", method!(RbModel::bpe_get_unk_token, 0))?;
|
|
411
411
|
class.define_method("unk_token=", method!(RbModel::bpe_set_unk_token, 1))?;
|
|
412
412
|
class.define_method(
|
|
413
413
|
"continuing_subword_prefix",
|
|
414
|
-
method!(RbModel::
|
|
414
|
+
method!(RbModel::bpe_get_continuing_subword_prefix, 0),
|
|
415
415
|
)?;
|
|
416
416
|
class.define_method(
|
|
417
417
|
"continuing_subword_prefix=",
|
|
@@ -419,15 +419,15 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
419
419
|
)?;
|
|
420
420
|
class.define_method(
|
|
421
421
|
"end_of_word_suffix",
|
|
422
|
-
method!(RbModel::
|
|
422
|
+
method!(RbModel::bpe_get_end_of_word_suffix, 0),
|
|
423
423
|
)?;
|
|
424
424
|
class.define_method(
|
|
425
425
|
"end_of_word_suffix=",
|
|
426
426
|
method!(RbModel::bpe_set_end_of_word_suffix, 1),
|
|
427
427
|
)?;
|
|
428
|
-
class.define_method("fuse_unk", method!(RbModel::
|
|
428
|
+
class.define_method("fuse_unk", method!(RbModel::bpe_get_fuse_unk, 0))?;
|
|
429
429
|
class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
|
|
430
|
-
class.define_method("byte_fallback", method!(RbModel::
|
|
430
|
+
class.define_method("byte_fallback", method!(RbModel::bpe_get_byte_fallback, 0))?;
|
|
431
431
|
class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
|
|
432
432
|
|
|
433
433
|
let class = module.define_class("Unigram", model)?;
|
|
@@ -437,17 +437,17 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
437
437
|
class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
|
|
438
438
|
class.define_singleton_method("_from_file", function!(RbWordLevel::from_file, 2))?;
|
|
439
439
|
class.define_singleton_method("read_file", function!(RbWordLevel::read_file, 1))?;
|
|
440
|
-
class.define_method("unk_token", method!(RbModel::
|
|
440
|
+
class.define_method("unk_token", method!(RbModel::word_level_get_unk_token, 0))?;
|
|
441
441
|
class.define_method("unk_token=", method!(RbModel::word_level_set_unk_token, 1))?;
|
|
442
442
|
|
|
443
443
|
let class = module.define_class("WordPiece", model)?;
|
|
444
444
|
class.define_singleton_method("_new", function!(RbWordPiece::new, 2))?;
|
|
445
445
|
class.define_singleton_method("_from_file", function!(RbWordPiece::from_file, 2))?;
|
|
446
|
-
class.define_method("unk_token", method!(RbModel::
|
|
446
|
+
class.define_method("unk_token", method!(RbModel::word_piece_get_unk_token, 0))?;
|
|
447
447
|
class.define_method("unk_token=", method!(RbModel::word_piece_set_unk_token, 1))?;
|
|
448
448
|
class.define_method(
|
|
449
449
|
"continuing_subword_prefix",
|
|
450
|
-
method!(RbModel::
|
|
450
|
+
method!(RbModel::word_piece_get_continuing_subword_prefix, 0),
|
|
451
451
|
)?;
|
|
452
452
|
class.define_method(
|
|
453
453
|
"continuing_subword_prefix=",
|
|
@@ -455,7 +455,7 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
455
455
|
)?;
|
|
456
456
|
class.define_method(
|
|
457
457
|
"max_input_chars_per_word",
|
|
458
|
-
method!(RbModel::
|
|
458
|
+
method!(RbModel::word_piece_get_max_input_chars_per_word, 0),
|
|
459
459
|
)?;
|
|
460
460
|
class.define_method(
|
|
461
461
|
"max_input_chars_per_word=",
|
|
@@ -16,8 +16,8 @@ use super::utils::*;
|
|
|
16
16
|
use super::{RbError, RbResult, NORMALIZERS};
|
|
17
17
|
|
|
18
18
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
|
19
|
+
#[serde(transparent)]
|
|
19
20
|
pub struct RbNormalizer {
|
|
20
|
-
#[serde(flatten)]
|
|
21
21
|
pub(crate) normalizer: RbNormalizerTypeWrapper,
|
|
22
22
|
}
|
|
23
23
|
|
|
@@ -69,7 +69,7 @@ macro_rules! setter {
|
|
|
69
69
|
}
|
|
70
70
|
|
|
71
71
|
impl RbNormalizer {
|
|
72
|
-
fn
|
|
72
|
+
fn bert_get_clean_text(&self) -> bool {
|
|
73
73
|
getter!(self, BertNormalizer, clean_text)
|
|
74
74
|
}
|
|
75
75
|
|
|
@@ -77,7 +77,7 @@ impl RbNormalizer {
|
|
|
77
77
|
setter!(self, BertNormalizer, clean_text, clean_text);
|
|
78
78
|
}
|
|
79
79
|
|
|
80
|
-
fn
|
|
80
|
+
fn bert_get_handle_chinese_chars(&self) -> bool {
|
|
81
81
|
getter!(self, BertNormalizer, handle_chinese_chars)
|
|
82
82
|
}
|
|
83
83
|
|
|
@@ -90,7 +90,7 @@ impl RbNormalizer {
|
|
|
90
90
|
);
|
|
91
91
|
}
|
|
92
92
|
|
|
93
|
-
fn
|
|
93
|
+
fn bert_get_strip_accents(&self) -> Option<bool> {
|
|
94
94
|
getter!(self, BertNormalizer, strip_accents)
|
|
95
95
|
}
|
|
96
96
|
|
|
@@ -98,7 +98,7 @@ impl RbNormalizer {
|
|
|
98
98
|
setter!(self, BertNormalizer, strip_accents, strip_accents);
|
|
99
99
|
}
|
|
100
100
|
|
|
101
|
-
fn
|
|
101
|
+
fn bert_get_lowercase(&self) -> bool {
|
|
102
102
|
getter!(self, BertNormalizer, lowercase)
|
|
103
103
|
}
|
|
104
104
|
|
|
@@ -106,7 +106,7 @@ impl RbNormalizer {
|
|
|
106
106
|
setter!(self, BertNormalizer, lowercase, lowercase);
|
|
107
107
|
}
|
|
108
108
|
|
|
109
|
-
fn
|
|
109
|
+
fn prepend_get_prepend(&self) -> String {
|
|
110
110
|
getter!(self, Prepend, prepend)
|
|
111
111
|
}
|
|
112
112
|
|
|
@@ -114,7 +114,7 @@ impl RbNormalizer {
|
|
|
114
114
|
setter!(self, Prepend, prepend, prepend);
|
|
115
115
|
}
|
|
116
116
|
|
|
117
|
-
fn
|
|
117
|
+
fn strip_get_left(&self) -> bool {
|
|
118
118
|
getter!(self, StripNormalizer, strip_left)
|
|
119
119
|
}
|
|
120
120
|
|
|
@@ -122,7 +122,7 @@ impl RbNormalizer {
|
|
|
122
122
|
setter!(self, StripNormalizer, strip_left, left);
|
|
123
123
|
}
|
|
124
124
|
|
|
125
|
-
fn
|
|
125
|
+
fn strip_get_right(&self) -> bool {
|
|
126
126
|
getter!(self, StripNormalizer, strip_right)
|
|
127
127
|
}
|
|
128
128
|
|
|
@@ -476,11 +476,11 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
476
476
|
|
|
477
477
|
let class = module.define_class("BertNormalizer", normalizer)?;
|
|
478
478
|
class.define_singleton_method("_new", function!(RbBertNormalizer::new, 4))?;
|
|
479
|
-
class.define_method("clean_text", method!(RbNormalizer::
|
|
479
|
+
class.define_method("clean_text", method!(RbNormalizer::bert_get_clean_text, 0))?;
|
|
480
480
|
class.define_method("clean_text=", method!(RbNormalizer::bert_set_clean_text, 1))?;
|
|
481
481
|
class.define_method(
|
|
482
482
|
"handle_chinese_chars",
|
|
483
|
-
method!(RbNormalizer::
|
|
483
|
+
method!(RbNormalizer::bert_get_handle_chinese_chars, 0),
|
|
484
484
|
)?;
|
|
485
485
|
class.define_method(
|
|
486
486
|
"handle_chinese_chars=",
|
|
@@ -488,13 +488,13 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
488
488
|
)?;
|
|
489
489
|
class.define_method(
|
|
490
490
|
"strip_accents",
|
|
491
|
-
method!(RbNormalizer::
|
|
491
|
+
method!(RbNormalizer::bert_get_strip_accents, 0),
|
|
492
492
|
)?;
|
|
493
493
|
class.define_method(
|
|
494
494
|
"strip_accents=",
|
|
495
495
|
method!(RbNormalizer::bert_set_strip_accents, 1),
|
|
496
496
|
)?;
|
|
497
|
-
class.define_method("lowercase", method!(RbNormalizer::
|
|
497
|
+
class.define_method("lowercase", method!(RbNormalizer::bert_get_lowercase, 0))?;
|
|
498
498
|
class.define_method("lowercase=", method!(RbNormalizer::bert_set_lowercase, 1))?;
|
|
499
499
|
|
|
500
500
|
let class = module.define_class("Lowercase", normalizer)?;
|
|
@@ -523,14 +523,14 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
523
523
|
|
|
524
524
|
let class = module.define_class("Prepend", normalizer)?;
|
|
525
525
|
class.define_singleton_method("_new", function!(RbPrepend::new, 1))?;
|
|
526
|
-
class.define_method("prepend", method!(RbNormalizer::
|
|
526
|
+
class.define_method("prepend", method!(RbNormalizer::prepend_get_prepend, 0))?;
|
|
527
527
|
class.define_method("prepend=", method!(RbNormalizer::prepend_set_prepend, 1))?;
|
|
528
528
|
|
|
529
529
|
let class = module.define_class("Strip", normalizer)?;
|
|
530
530
|
class.define_singleton_method("_new", function!(RbStrip::new, 2))?;
|
|
531
|
-
class.define_method("left", method!(RbNormalizer::
|
|
531
|
+
class.define_method("left", method!(RbNormalizer::strip_get_left, 0))?;
|
|
532
532
|
class.define_method("left=", method!(RbNormalizer::strip_set_left, 1))?;
|
|
533
|
-
class.define_method("right", method!(RbNormalizer::
|
|
533
|
+
class.define_method("right", method!(RbNormalizer::strip_get_right, 0))?;
|
|
534
534
|
class.define_method("right=", method!(RbNormalizer::strip_set_right, 1))?;
|
|
535
535
|
|
|
536
536
|
let class = module.define_class("StripAccents", normalizer)?;
|
|
@@ -25,8 +25,8 @@ use super::utils::*;
|
|
|
25
25
|
use super::{RbError, RbResult, PRE_TOKENIZERS};
|
|
26
26
|
|
|
27
27
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
|
28
|
+
#[serde(transparent)]
|
|
28
29
|
pub struct RbPreTokenizer {
|
|
29
|
-
#[serde(flatten)]
|
|
30
30
|
pub(crate) pretok: RbPreTokenizerTypeWrapper,
|
|
31
31
|
}
|
|
32
32
|
|
|
@@ -88,7 +88,7 @@ impl RbPreTokenizer {
|
|
|
88
88
|
RbPreTokenizer { pretok }
|
|
89
89
|
}
|
|
90
90
|
|
|
91
|
-
fn
|
|
91
|
+
fn byte_level_get_add_prefix_space(&self) -> bool {
|
|
92
92
|
getter!(self, ByteLevel, add_prefix_space)
|
|
93
93
|
}
|
|
94
94
|
|
|
@@ -96,7 +96,7 @@ impl RbPreTokenizer {
|
|
|
96
96
|
setter!(self, ByteLevel, add_prefix_space, add_prefix_space);
|
|
97
97
|
}
|
|
98
98
|
|
|
99
|
-
fn
|
|
99
|
+
fn byte_level_get_use_regex(&self) -> bool {
|
|
100
100
|
getter!(self, ByteLevel, use_regex)
|
|
101
101
|
}
|
|
102
102
|
|
|
@@ -104,7 +104,7 @@ impl RbPreTokenizer {
|
|
|
104
104
|
setter!(self, ByteLevel, use_regex, use_regex);
|
|
105
105
|
}
|
|
106
106
|
|
|
107
|
-
fn
|
|
107
|
+
fn char_delimiter_split_get_delimiter(&self) -> String {
|
|
108
108
|
getter!(self, Delimiter, delimiter.to_string())
|
|
109
109
|
}
|
|
110
110
|
|
|
@@ -112,7 +112,7 @@ impl RbPreTokenizer {
|
|
|
112
112
|
setter!(self, Delimiter, delimiter, delimiter);
|
|
113
113
|
}
|
|
114
114
|
|
|
115
|
-
fn
|
|
115
|
+
fn digits_get_individual_digits(&self) -> bool {
|
|
116
116
|
getter!(self, Digits, individual_digits)
|
|
117
117
|
}
|
|
118
118
|
|
|
@@ -120,7 +120,7 @@ impl RbPreTokenizer {
|
|
|
120
120
|
setter!(self, Digits, individual_digits, individual_digits);
|
|
121
121
|
}
|
|
122
122
|
|
|
123
|
-
fn
|
|
123
|
+
fn metaspace_get_replacement(&self) -> String {
|
|
124
124
|
getter!(self, Metaspace, get_replacement().to_string())
|
|
125
125
|
}
|
|
126
126
|
|
|
@@ -128,7 +128,7 @@ impl RbPreTokenizer {
|
|
|
128
128
|
setter!(self, Metaspace, @set_replacement, replacement);
|
|
129
129
|
}
|
|
130
130
|
|
|
131
|
-
fn
|
|
131
|
+
fn metaspace_get_split(&self) -> bool {
|
|
132
132
|
getter!(self, Metaspace, get_split())
|
|
133
133
|
}
|
|
134
134
|
|
|
@@ -136,7 +136,7 @@ impl RbPreTokenizer {
|
|
|
136
136
|
setter!(self, Metaspace, @set_split, split);
|
|
137
137
|
}
|
|
138
138
|
|
|
139
|
-
fn
|
|
139
|
+
fn metaspace_get_prepend_scheme(&self) -> String {
|
|
140
140
|
// Assuming Metaspace has a method to get the prepend_scheme as a string
|
|
141
141
|
let scheme: PrependScheme = getter!(self, Metaspace, get_prepend_scheme());
|
|
142
142
|
match scheme {
|
|
@@ -528,7 +528,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
528
528
|
class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
|
|
529
529
|
class.define_method(
|
|
530
530
|
"add_prefix_space",
|
|
531
|
-
method!(RbPreTokenizer::
|
|
531
|
+
method!(RbPreTokenizer::byte_level_get_add_prefix_space, 0),
|
|
532
532
|
)?;
|
|
533
533
|
class.define_method(
|
|
534
534
|
"add_prefix_space=",
|
|
@@ -536,7 +536,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
536
536
|
)?;
|
|
537
537
|
class.define_method(
|
|
538
538
|
"use_regex",
|
|
539
|
-
method!(RbPreTokenizer::
|
|
539
|
+
method!(RbPreTokenizer::byte_level_get_use_regex, 0),
|
|
540
540
|
)?;
|
|
541
541
|
class.define_method(
|
|
542
542
|
"use_regex=",
|
|
@@ -547,7 +547,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
547
547
|
class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
|
|
548
548
|
class.define_method(
|
|
549
549
|
"delimiter",
|
|
550
|
-
method!(RbPreTokenizer::
|
|
550
|
+
method!(RbPreTokenizer::char_delimiter_split_get_delimiter, 0),
|
|
551
551
|
)?;
|
|
552
552
|
class.define_method(
|
|
553
553
|
"delimiter=",
|
|
@@ -558,7 +558,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
558
558
|
class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
|
|
559
559
|
class.define_method(
|
|
560
560
|
"individual_digits",
|
|
561
|
-
method!(RbPreTokenizer::
|
|
561
|
+
method!(RbPreTokenizer::digits_get_individual_digits, 0),
|
|
562
562
|
)?;
|
|
563
563
|
class.define_method(
|
|
564
564
|
"individual_digits=",
|
|
@@ -569,7 +569,7 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
569
569
|
class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
|
|
570
570
|
class.define_method(
|
|
571
571
|
"prepend_scheme",
|
|
572
|
-
method!(RbPreTokenizer::
|
|
572
|
+
method!(RbPreTokenizer::metaspace_get_prepend_scheme, 0),
|
|
573
573
|
)?;
|
|
574
574
|
class.define_method(
|
|
575
575
|
"prepend_scheme=",
|
|
@@ -577,13 +577,13 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
577
577
|
)?;
|
|
578
578
|
class.define_method(
|
|
579
579
|
"replacement",
|
|
580
|
-
method!(RbPreTokenizer::
|
|
580
|
+
method!(RbPreTokenizer::metaspace_get_replacement, 0),
|
|
581
581
|
)?;
|
|
582
582
|
class.define_method(
|
|
583
583
|
"replacement=",
|
|
584
584
|
method!(RbPreTokenizer::metaspace_set_replacement, 1),
|
|
585
585
|
)?;
|
|
586
|
-
class.define_method("split", method!(RbPreTokenizer::
|
|
586
|
+
class.define_method("split", method!(RbPreTokenizer::metaspace_get_split, 0))?;
|
|
587
587
|
class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;
|
|
588
588
|
|
|
589
589
|
let class = module.define_class("Punctuation", pre_tokenizer)?;
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
use std::sync::Arc;
|
|
2
|
+
use std::sync::RwLock;
|
|
2
3
|
|
|
3
4
|
use magnus::{
|
|
4
5
|
data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
|
|
5
|
-
RClass, RModule, Ruby, TryConvert, TypedData, Value,
|
|
6
|
+
RArray, RClass, RModule, Ruby, TryConvert, TypedData, Value,
|
|
6
7
|
};
|
|
7
|
-
use serde::
|
|
8
|
+
use serde::ser::SerializeStruct;
|
|
9
|
+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|
8
10
|
use tk::processors::bert::BertProcessing;
|
|
9
11
|
use tk::processors::byte_level::ByteLevel;
|
|
10
12
|
use tk::processors::roberta::RobertaProcessing;
|
|
@@ -15,17 +17,28 @@ use tk::{Encoding, PostProcessor};
|
|
|
15
17
|
use super::{RbResult, PROCESSORS};
|
|
16
18
|
|
|
17
19
|
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
|
20
|
+
#[serde(transparent)]
|
|
18
21
|
pub struct RbPostProcessor {
|
|
19
|
-
|
|
20
|
-
pub processor: Arc<PostProcessorWrapper>,
|
|
22
|
+
pub processor: RbPostProcessorTypeWrapper,
|
|
21
23
|
}
|
|
22
24
|
|
|
23
25
|
impl RbPostProcessor {
|
|
24
|
-
pub fn new(processor:
|
|
26
|
+
pub fn new(processor: RbPostProcessorTypeWrapper) -> Self {
|
|
25
27
|
RbPostProcessor { processor }
|
|
26
28
|
}
|
|
27
29
|
}
|
|
28
30
|
|
|
31
|
+
impl<I> From<I> for RbPostProcessor
|
|
32
|
+
where
|
|
33
|
+
I: Into<PostProcessorWrapper>,
|
|
34
|
+
{
|
|
35
|
+
fn from(processor: I) -> Self {
|
|
36
|
+
RbPostProcessor {
|
|
37
|
+
processor: processor.into().into(),
|
|
38
|
+
}
|
|
39
|
+
}
|
|
40
|
+
}
|
|
41
|
+
|
|
29
42
|
impl PostProcessor for RbPostProcessor {
|
|
30
43
|
fn added_tokens(&self, is_pair: bool) -> usize {
|
|
31
44
|
self.processor.added_tokens(is_pair)
|
|
@@ -41,6 +54,92 @@ impl PostProcessor for RbPostProcessor {
|
|
|
41
54
|
}
|
|
42
55
|
}
|
|
43
56
|
|
|
57
|
+
#[derive(Clone)]
|
|
58
|
+
pub(crate) enum RbPostProcessorTypeWrapper {
|
|
59
|
+
Sequence(Vec<Arc<RwLock<PostProcessorWrapper>>>),
|
|
60
|
+
Single(Arc<RwLock<PostProcessorWrapper>>),
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
impl PostProcessor for RbPostProcessorTypeWrapper {
|
|
64
|
+
fn added_tokens(&self, is_pair: bool) -> usize {
|
|
65
|
+
match self {
|
|
66
|
+
RbPostProcessorTypeWrapper::Single(inner) => inner
|
|
67
|
+
.read()
|
|
68
|
+
.expect("RwLock synchronisation primitive is poisoned, cannot get subtype of RbPostProcessor")
|
|
69
|
+
.added_tokens(is_pair),
|
|
70
|
+
RbPostProcessorTypeWrapper::Sequence(inner) => inner.iter().map(|p| {
|
|
71
|
+
p.read()
|
|
72
|
+
.expect("RwLock synchronisation primitive is poisoned, cannot get subtype of RbPostProcessor")
|
|
73
|
+
.added_tokens(is_pair)
|
|
74
|
+
}).sum::<usize>(),
|
|
75
|
+
}
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
fn process_encodings(
|
|
79
|
+
&self,
|
|
80
|
+
mut encodings: Vec<Encoding>,
|
|
81
|
+
add_special_tokens: bool,
|
|
82
|
+
) -> tk::Result<Vec<Encoding>> {
|
|
83
|
+
match self {
|
|
84
|
+
RbPostProcessorTypeWrapper::Single(inner) => inner
|
|
85
|
+
.read()
|
|
86
|
+
.expect("RwLock synchronisation primitive is poisoned, cannot get subtype of RbPreTokenizer")
|
|
87
|
+
.process_encodings(encodings, add_special_tokens),
|
|
88
|
+
RbPostProcessorTypeWrapper::Sequence(inner) => {
|
|
89
|
+
for processor in inner.iter() {
|
|
90
|
+
encodings = processor
|
|
91
|
+
.read()
|
|
92
|
+
.expect("RwLock synchronisation primitive is poisoned, cannot get subtype of RbPreTokenizer")
|
|
93
|
+
.process_encodings(encodings, add_special_tokens)?;
|
|
94
|
+
}
|
|
95
|
+
Ok(encodings)
|
|
96
|
+
},
|
|
97
|
+
}
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
impl<'de> Deserialize<'de> for RbPostProcessorTypeWrapper {
|
|
102
|
+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
103
|
+
where
|
|
104
|
+
D: Deserializer<'de>,
|
|
105
|
+
{
|
|
106
|
+
let wrapper = PostProcessorWrapper::deserialize(deserializer)?;
|
|
107
|
+
Ok(wrapper.into())
|
|
108
|
+
}
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
impl Serialize for RbPostProcessorTypeWrapper {
|
|
112
|
+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
113
|
+
where
|
|
114
|
+
S: Serializer,
|
|
115
|
+
{
|
|
116
|
+
match self {
|
|
117
|
+
RbPostProcessorTypeWrapper::Sequence(seq) => {
|
|
118
|
+
let mut ser = serializer.serialize_struct("Sequence", 2)?;
|
|
119
|
+
ser.serialize_field("type", "Sequence")?;
|
|
120
|
+
ser.serialize_field("processors", seq)?;
|
|
121
|
+
ser.end()
|
|
122
|
+
}
|
|
123
|
+
RbPostProcessorTypeWrapper::Single(inner) => inner.serialize(serializer),
|
|
124
|
+
}
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
impl<I> From<I> for RbPostProcessorTypeWrapper
|
|
129
|
+
where
|
|
130
|
+
I: Into<PostProcessorWrapper>,
|
|
131
|
+
{
|
|
132
|
+
fn from(processor: I) -> Self {
|
|
133
|
+
let processor = processor.into();
|
|
134
|
+
match processor {
|
|
135
|
+
PostProcessorWrapper::Sequence(seq) => RbPostProcessorTypeWrapper::Sequence(
|
|
136
|
+
seq.into_iter().map(|p| Arc::new(RwLock::new(p))).collect(),
|
|
137
|
+
),
|
|
138
|
+
_ => RbPostProcessorTypeWrapper::Single(Arc::new(RwLock::new(processor.clone()))),
|
|
139
|
+
}
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
|
|
44
143
|
#[derive(Clone, Debug)]
|
|
45
144
|
pub struct RbSpecialToken(SpecialToken);
|
|
46
145
|
|
|
@@ -91,7 +190,7 @@ pub struct RbBertProcessing {}
|
|
|
91
190
|
|
|
92
191
|
impl RbBertProcessing {
|
|
93
192
|
pub fn new(sep: (String, u32), cls: (String, u32)) -> RbPostProcessor {
|
|
94
|
-
|
|
193
|
+
BertProcessing::new(sep, cls).into()
|
|
95
194
|
}
|
|
96
195
|
}
|
|
97
196
|
|
|
@@ -104,7 +203,7 @@ impl RbByteLevel {
|
|
|
104
203
|
if let Some(to) = trim_offsets {
|
|
105
204
|
byte_level = byte_level.trim_offsets(to);
|
|
106
205
|
}
|
|
107
|
-
|
|
206
|
+
byte_level.into()
|
|
108
207
|
}
|
|
109
208
|
}
|
|
110
209
|
|
|
@@ -120,7 +219,7 @@ impl RbRobertaProcessing {
|
|
|
120
219
|
let proc = RobertaProcessing::new(sep, cls)
|
|
121
220
|
.trim_offsets(trim_offsets)
|
|
122
221
|
.add_prefix_space(add_prefix_space);
|
|
123
|
-
|
|
222
|
+
proc.into()
|
|
124
223
|
}
|
|
125
224
|
}
|
|
126
225
|
|
|
@@ -145,7 +244,27 @@ impl RbTemplateProcessing {
|
|
|
145
244
|
}
|
|
146
245
|
let processor = builder.build().unwrap(); //.map_err(RbError::from)?;
|
|
147
246
|
|
|
148
|
-
Ok(
|
|
247
|
+
Ok(processor.into())
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
pub struct RbSequence {}
|
|
252
|
+
|
|
253
|
+
impl RbSequence {
|
|
254
|
+
fn new(processors_rb: RArray) -> RbResult<RbPostProcessor> {
|
|
255
|
+
let mut processors = Vec::with_capacity(processors_rb.len());
|
|
256
|
+
for n in processors_rb {
|
|
257
|
+
let processor = <&RbPostProcessor>::try_convert(n)?;
|
|
258
|
+
match &processor.processor {
|
|
259
|
+
RbPostProcessorTypeWrapper::Sequence(inner) => {
|
|
260
|
+
processors.extend(inner.iter().cloned())
|
|
261
|
+
}
|
|
262
|
+
RbPostProcessorTypeWrapper::Single(inner) => processors.push(inner.clone()),
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
Ok(RbPostProcessor::new(RbPostProcessorTypeWrapper::Sequence(
|
|
266
|
+
processors,
|
|
267
|
+
)))
|
|
149
268
|
}
|
|
150
269
|
}
|
|
151
270
|
|
|
@@ -198,12 +317,20 @@ unsafe impl TypedData for RbPostProcessor {
|
|
|
198
317
|
class.undef_default_alloc_func();
|
|
199
318
|
class
|
|
200
319
|
});
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
320
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
|
321
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("Sequence").unwrap();
|
|
322
|
+
class.undef_default_alloc_func();
|
|
323
|
+
class
|
|
324
|
+
});
|
|
325
|
+
match &value.processor {
|
|
326
|
+
RbPostProcessorTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
|
327
|
+
PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
|
|
328
|
+
PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
|
329
|
+
PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
|
|
330
|
+
PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
|
|
331
|
+
_ => todo!(),
|
|
332
|
+
},
|
|
333
|
+
RbPostProcessorTypeWrapper::Sequence(_) => ruby.get_inner(&SEQUENCE),
|
|
207
334
|
}
|
|
208
335
|
}
|
|
209
336
|
}
|
|
@@ -223,5 +350,8 @@ pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
|
|
223
350
|
let class = module.define_class("TemplateProcessing", post_processor)?;
|
|
224
351
|
class.define_singleton_method("_new", function!(RbTemplateProcessing::new, 3))?;
|
|
225
352
|
|
|
353
|
+
let class = module.define_class("Sequence", post_processor)?;
|
|
354
|
+
class.define_singleton_method("_new", function!(RbSequence::new, 1))?;
|
|
355
|
+
|
|
226
356
|
Ok(())
|
|
227
357
|
}
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
use std::ffi::c_void;
|
|
2
|
+
use std::ptr::null_mut;
|
|
3
|
+
|
|
4
|
+
use magnus::Ruby;
|
|
5
|
+
use rb_sys::rb_thread_call_without_gvl;
|
|
6
|
+
|
|
7
|
+
/// Extension trait for running Rust code with Ruby's Global VM Lock (GVL) released.
pub trait GvlExt {
    /// Runs `func` while the GVL is not held and returns its result.
    ///
    /// Both the closure and its result must be `Send`, since they cross the point where
    /// the lock is released and re-acquired.
    fn detach<T, F>(&self, func: F) -> T
    where
        T: Send,
        F: Send + FnOnce() -> T;
}
|
|
13
|
+
|
|
14
|
+
impl GvlExt for Ruby {
|
|
15
|
+
fn detach<T, F>(&self, func: F) -> T
|
|
16
|
+
where
|
|
17
|
+
F: Send + FnOnce() -> T,
|
|
18
|
+
T: Send,
|
|
19
|
+
{
|
|
20
|
+
let mut data = CallbackData {
|
|
21
|
+
func: Some(func),
|
|
22
|
+
result: None,
|
|
23
|
+
};
|
|
24
|
+
|
|
25
|
+
unsafe {
|
|
26
|
+
rb_thread_call_without_gvl(
|
|
27
|
+
Some(call_without_gvl::<F, T>),
|
|
28
|
+
&mut data as *mut _ as *mut c_void,
|
|
29
|
+
None,
|
|
30
|
+
null_mut(),
|
|
31
|
+
);
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
data.result.unwrap()
|
|
35
|
+
}
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
struct CallbackData<F, T> {
|
|
39
|
+
func: Option<F>,
|
|
40
|
+
result: Option<T>,
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
extern "C" fn call_without_gvl<F, T>(data: *mut c_void) -> *mut c_void
|
|
44
|
+
where
|
|
45
|
+
F: FnOnce() -> T,
|
|
46
|
+
{
|
|
47
|
+
let data = unsafe { &mut *(data as *mut CallbackData<F, T>) };
|
|
48
|
+
let func = data.func.take().unwrap();
|
|
49
|
+
data.result = Some(func());
|
|
50
|
+
null_mut()
|
|
51
|
+
}
|