tokenizers 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/Cargo.lock +33 -74
  4. data/README.md +4 -0
  5. data/ext/tokenizers/Cargo.toml +4 -2
  6. data/ext/tokenizers/src/decoders.rs +275 -6
  7. data/ext/tokenizers/src/encoding.rs +78 -3
  8. data/ext/tokenizers/src/error.rs +2 -2
  9. data/ext/tokenizers/src/lib.rs +88 -17
  10. data/ext/tokenizers/src/models.rs +372 -11
  11. data/ext/tokenizers/src/normalizers.rs +435 -7
  12. data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
  13. data/ext/tokenizers/src/processors.rs +210 -0
  14. data/ext/tokenizers/src/tokenizer.rs +448 -20
  15. data/ext/tokenizers/src/trainers.rs +749 -0
  16. data/ext/tokenizers/src/utils/mod.rs +5 -0
  17. data/ext/tokenizers/src/utils/normalization.rs +85 -0
  18. data/ext/tokenizers/src/utils/regex.rs +22 -0
  19. data/lib/tokenizers/char_bpe_tokenizer.rb +11 -8
  20. data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
  21. data/lib/tokenizers/decoders/ctc.rb +9 -0
  22. data/lib/tokenizers/decoders/metaspace.rb +9 -0
  23. data/lib/tokenizers/decoders/word_piece.rb +9 -0
  24. data/lib/tokenizers/encoding.rb +19 -0
  25. data/lib/tokenizers/from_pretrained.rb +1 -1
  26. data/lib/tokenizers/models/bpe.rb +9 -0
  27. data/lib/tokenizers/models/unigram.rb +9 -0
  28. data/lib/tokenizers/models/word_level.rb +13 -0
  29. data/lib/tokenizers/models/word_piece.rb +9 -0
  30. data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
  31. data/lib/tokenizers/normalizers/strip.rb +9 -0
  32. data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
  33. data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
  34. data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
  35. data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
  36. data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
  37. data/lib/tokenizers/processors/byte_level.rb +9 -0
  38. data/lib/tokenizers/processors/roberta_processing.rb +9 -0
  39. data/lib/tokenizers/processors/template_processing.rb +9 -0
  40. data/lib/tokenizers/tokenizer.rb +45 -0
  41. data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
  42. data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
  43. data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
  44. data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
  45. data/lib/tokenizers/version.rb +1 -1
  46. data/lib/tokenizers.rb +49 -7
  47. metadata +32 -3
@@ -1,14 +1,478 @@
1
+ use std::sync::{Arc, RwLock};
2
+
3
+ use magnus::typed_data::DataTypeBuilder;
4
+ use magnus::{
5
+ function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object,
6
+ RArray, RClass, RModule, TypedData,
7
+ };
8
+
9
+ use serde::ser::SerializeStruct;
10
+ use serde::{Deserialize, Serialize, Serializer};
11
+
1
12
  use tk::pre_tokenizers::bert::BertPreTokenizer;
13
+ use tk::pre_tokenizers::byte_level::ByteLevel;
14
+ use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
15
+ use tk::pre_tokenizers::digits::Digits;
16
+ use tk::pre_tokenizers::metaspace::Metaspace;
17
+ use tk::pre_tokenizers::punctuation::Punctuation;
18
+ use tk::pre_tokenizers::split::Split;
19
+ use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
20
+ use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
21
+ use tk::pre_tokenizers::PreTokenizerWrapper;
22
+ use tk::tokenizer::Offsets;
23
+ use tk::{PreTokenizedString, PreTokenizer};
24
+
25
+ use super::utils::*;
26
+ use super::{RbError, RbResult};
27
+
28
+ #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
29
+ pub struct RbPreTokenizer {
30
+ #[serde(flatten)]
31
+ pub(crate) pretok: RbPreTokenizerTypeWrapper,
32
+ }
33
+
34
+ impl RbPreTokenizer {
35
+ fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
36
+ let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
37
+
38
+ self.pretok.pre_tokenize(&mut pretokenized).map_err(RbError::from)?;
39
+
40
+ Ok(pretokenized
41
+ .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
42
+ .into_iter()
43
+ .map(|(s, o, _)| (s.to_owned(), o))
44
+ .collect())
45
+ }
46
+ }
2
47
 
3
- #[magnus::wrap(class = "Tokenizers::BertPreTokenizer")]
4
- pub struct RbBertPreTokenizer {
5
- pub pretok: BertPreTokenizer,
48
/// Evaluates `pretok.<tokens>` against the single wrapped
/// `PreTokenizerWrapper::$variant` held by `$self`.
///
/// The trailing tokens may be a field name or a method-call chain
/// (e.g. `get_replacement().to_string()`). The read lock is held only while
/// the expression is evaluated.
///
/// Panics via `unreachable!()` if `$self` holds a sequence or a different
/// variant. The getters are only registered on the Ruby subclass that matches
/// the variant.
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        match &$self.pretok {
            RbPreTokenizerTypeWrapper::Single(single) => {
                let guard = single.read().unwrap();
                match &*guard {
                    RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(pretok)) => {
                        pretok.$($name)+
                    }
                    _ => unreachable!(),
                }
            }
            RbPreTokenizerTypeWrapper::Sequence(_) => unreachable!(),
        }
    }};
}
7
62
 
63
/// Mutates the single wrapped `PreTokenizerWrapper::$variant` held by `$self`
/// under its write lock.
///
/// * `setter!(self, Variant, field, value)` assigns `value` to `field`.
/// * `setter!(self, Variant, @method, value)` calls `method(value)` instead,
///   for types that expose a setter method rather than a public field.
///
/// If `$self` holds a sequence or a different variant, nothing happens and
/// `$value` is not evaluated.
macro_rules! setter {
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        if let RbPreTokenizerTypeWrapper::Single(single) = &$self.pretok {
            let mut guard = single.write().unwrap();
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(pretok)) =
                &mut *guard
            {
                pretok.$name = $value;
            }
        }
    }};
    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
        if let RbPreTokenizerTypeWrapper::Single(single) = &$self.pretok {
            let mut guard = single.write().unwrap();
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(pretok)) =
                &mut *guard
            {
                pretok.$name($value);
            }
        }
    }};
}
83
+
84
+ impl RbPreTokenizer {
85
+ #[allow(dead_code)]
86
+ pub(crate) fn new(pretok: RbPreTokenizerTypeWrapper) -> Self {
87
+ RbPreTokenizer { pretok }
88
+ }
89
+
90
+ fn byte_level_add_prefix_space(&self) -> bool {
91
+ getter!(self, ByteLevel, add_prefix_space)
92
+ }
93
+
94
+ fn byte_level_set_add_prefix_space(&self, add_prefix_space: bool) {
95
+ setter!(self, ByteLevel, add_prefix_space, add_prefix_space);
96
+ }
97
+
98
+ fn byte_level_use_regex(&self) -> bool {
99
+ getter!(self, ByteLevel, use_regex)
100
+ }
101
+
102
+ fn byte_level_set_use_regex(&self, use_regex: bool) {
103
+ setter!(self, ByteLevel, use_regex, use_regex);
104
+ }
105
+
106
+ fn char_delimiter_split_delimiter(&self) -> String {
107
+ getter!(self, Delimiter, delimiter.to_string())
108
+ }
109
+
110
+ fn char_delimiter_split_set_delimiter(&self, delimiter: char) {
111
+ setter!(self, Delimiter, delimiter, delimiter);
112
+ }
113
+
114
+ fn digits_individual_digits(&self) -> bool {
115
+ getter!(self, Digits, individual_digits)
116
+ }
117
+
118
+ fn digits_set_individual_digits(&self, individual_digits: bool) {
119
+ setter!(self, Digits, individual_digits, individual_digits);
120
+ }
121
+
122
+ fn metaspace_add_prefix_space(&self) -> bool {
123
+ getter!(self, Metaspace, add_prefix_space)
124
+ }
125
+
126
+ fn metaspace_set_add_prefix_space(&self, add_prefix_space: bool) {
127
+ setter!(self, Metaspace, add_prefix_space, add_prefix_space);
128
+ }
129
+
130
+ fn metaspace_replacement(&self) -> String {
131
+ getter!(self, Metaspace, get_replacement().to_string())
132
+ }
133
+
134
+ fn metaspace_set_replacement(&self, replacement: char) {
135
+ setter!(self, Metaspace, @set_replacement, replacement);
136
+ }
137
+ }
138
+
139
+ impl PreTokenizer for RbPreTokenizer {
140
+ fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
141
+ self.pretok.pre_tokenize(normalized)
142
+ }
143
+ }
144
+
145
+ pub struct RbByteLevel {}
146
+
147
+ impl RbByteLevel {
148
+ pub fn new(add_prefix_space: bool, use_regex: bool) -> RbPreTokenizer {
149
+ ByteLevel::default()
150
+ .add_prefix_space(add_prefix_space)
151
+ .use_regex(use_regex)
152
+ .into()
153
+ }
154
+
155
+ fn alphabet() -> Vec<String> {
156
+ ByteLevel::alphabet()
157
+ .into_iter()
158
+ .map(|c| c.to_string())
159
+ .collect()
160
+ }
161
+ }
162
+
163
+ pub struct RbCharDelimiterSplit {}
164
+
165
+ impl RbCharDelimiterSplit {
166
+ pub fn new(delimiter: char) -> RbPreTokenizer {
167
+ CharDelimiterSplit::new(delimiter).into()
168
+ }
169
+ }
170
+
171
+ pub struct RbDigits {}
172
+
173
+ impl RbDigits {
174
+ fn new(individual_digits: bool) -> RbPreTokenizer {
175
+ Digits::new(individual_digits).into()
176
+ }
177
+ }
178
+
179
+ pub struct RbMetaspace {}
180
+
181
+ impl RbMetaspace {
182
+ fn new(
183
+ replacement: char,
184
+ add_prefix_space: bool,
185
+ ) -> RbPreTokenizer {
186
+ Metaspace::new(replacement, add_prefix_space).into()
187
+ }
188
+ }
189
+
190
+ pub struct RbPunctuation {}
191
+
192
+ impl RbPunctuation {
193
+ pub fn new(behavior: RbSplitDelimiterBehavior) -> RbResult<RbPreTokenizer> {
194
+ Ok(Punctuation::new(behavior.into()).into())
195
+ }
196
+ }
197
+
198
+ pub struct RbSplit {}
199
+
200
+ impl RbSplit {
201
+ pub fn new(pattern: RbPattern, behavior: RbSplitDelimiterBehavior, invert: bool) -> RbResult<RbPreTokenizer> {
202
+ Split::new(pattern, behavior.into(), invert).map(|v| v.into()).map_err(RbError::from)
203
+ }
204
+ }
205
+
206
+ pub struct RbUnicodeScripts {}
207
+
208
+ impl RbUnicodeScripts {
209
+ pub fn new() -> RbPreTokenizer {
210
+ UnicodeScripts::new().into()
211
+ }
212
+ }
213
+
214
+ pub struct RbWhitespace {}
215
+
216
+ impl RbWhitespace {
217
+ pub fn new() -> RbPreTokenizer {
218
+ Whitespace::default().into()
219
+ }
220
+ }
221
+
222
+ pub struct RbWhitespaceSplit {}
223
+
224
+ impl RbWhitespaceSplit {
225
+ pub fn new() -> RbPreTokenizer {
226
+ WhitespaceSplit.into()
227
+ }
228
+ }
229
+
230
+ pub struct RbBertPreTokenizer {}
231
+
8
232
  impl RbBertPreTokenizer {
9
- pub fn new() -> Self {
10
- RbBertPreTokenizer {
11
- pretok: BertPreTokenizer,
233
+ pub fn new() -> RbPreTokenizer {
234
+ BertPreTokenizer.into()
235
+ }
236
+ }
237
+
238
+ pub struct RbSequence {}
239
+
240
+ impl RbSequence {
241
+ fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
242
+ let mut sequence = Vec::with_capacity(pre_tokenizers.len());
243
+ for n in pre_tokenizers.each() {
244
+ let pretokenizer: &RbPreTokenizer = n?.try_convert()?;
245
+ match &pretokenizer.pretok {
246
+ RbPreTokenizerTypeWrapper::Sequence(inner) => {
247
+ sequence.extend(inner.iter().cloned())
248
+ }
249
+ RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
250
+ }
251
+ }
252
+ Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(sequence)))
253
+ }
254
+ }
255
+
256
+ #[derive(Clone, Deserialize)]
257
+ #[serde(untagged)]
258
+ pub(crate) enum RbPreTokenizerWrapper {
259
+ // Custom(CustomPreTokenizer),
260
+ Wrapped(PreTokenizerWrapper),
261
+ }
262
+
263
+ impl Serialize for RbPreTokenizerWrapper {
264
+ fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
265
+ where
266
+ S: Serializer,
267
+ {
268
+ match self {
269
+ RbPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
270
+ // RbPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
271
+ }
272
+ }
273
+ }
274
+
275
+ #[derive(Clone, Deserialize)]
276
+ #[serde(untagged)]
277
+ pub(crate) enum RbPreTokenizerTypeWrapper {
278
+ Sequence(Vec<Arc<RwLock<RbPreTokenizerWrapper>>>),
279
+ Single(Arc<RwLock<RbPreTokenizerWrapper>>),
280
+ }
281
+
282
+ impl Serialize for RbPreTokenizerTypeWrapper {
283
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
284
+ where
285
+ S: Serializer,
286
+ {
287
+ match self {
288
+ RbPreTokenizerTypeWrapper::Sequence(seq) => {
289
+ let mut ser = serializer.serialize_struct("Sequence", 2)?;
290
+ ser.serialize_field("type", "Sequence")?;
291
+ ser.serialize_field("pretokenizers", seq)?;
292
+ ser.end()
293
+ }
294
+ RbPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer),
295
+ }
296
+ }
297
+ }
298
+
299
+ impl<I> From<I> for RbPreTokenizerWrapper
300
+ where
301
+ I: Into<PreTokenizerWrapper>,
302
+ {
303
+ fn from(pretok: I) -> Self {
304
+ RbPreTokenizerWrapper::Wrapped(pretok.into())
305
+ }
306
+ }
307
+
308
+ impl<I> From<I> for RbPreTokenizerTypeWrapper
309
+ where
310
+ I: Into<RbPreTokenizerWrapper>,
311
+ {
312
+ fn from(pretok: I) -> Self {
313
+ RbPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into())))
314
+ }
315
+ }
316
+
317
+ impl<I> From<I> for RbPreTokenizer
318
+ where
319
+ I: Into<PreTokenizerWrapper>,
320
+ {
321
+ fn from(pretok: I) -> Self {
322
+ RbPreTokenizer {
323
+ pretok: pretok.into().into(),
324
+ }
325
+ }
326
+ }
327
+
328
+ impl PreTokenizer for RbPreTokenizerTypeWrapper {
329
+ fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
330
+ match self {
331
+ RbPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok),
332
+ RbPreTokenizerTypeWrapper::Sequence(inner) => inner
333
+ .iter()
334
+ .try_for_each(|n| n.read().unwrap().pre_tokenize(pretok)),
12
335
  }
13
336
  }
14
337
  }
338
+
339
+ impl PreTokenizer for RbPreTokenizerWrapper {
340
+ fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
341
+ match self {
342
+ RbPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok),
343
+ // RbPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok),
344
+ }
345
+ }
346
+ }
347
+
348
+ unsafe impl TypedData for RbPreTokenizer {
349
+ fn class() -> RClass {
350
+ *memoize!(RClass: {
351
+ let class: RClass = crate::pre_tokenizers().const_get("PreTokenizer").unwrap();
352
+ class.undef_alloc_func();
353
+ class
354
+ })
355
+ }
356
+
357
+ fn data_type() -> &'static DataType {
358
+ memoize!(DataType: DataTypeBuilder::<RbPreTokenizer>::new("Tokenizers::PreTokenizers::PreTokenizer").build())
359
+ }
360
+
361
+ fn class_for(value: &Self) -> RClass {
362
+ match &value.pretok {
363
+ RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
364
+ let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
365
+ class.undef_alloc_func();
366
+ class
367
+ }),
368
+ RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
369
+ RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
370
+ PreTokenizerWrapper::BertPreTokenizer(_) => *memoize!(RClass: {
371
+ let class: RClass = crate::pre_tokenizers().const_get("BertPreTokenizer").unwrap();
372
+ class.undef_alloc_func();
373
+ class
374
+ }),
375
+ PreTokenizerWrapper::ByteLevel(_) => *memoize!(RClass: {
376
+ let class: RClass = crate::pre_tokenizers().const_get("ByteLevel").unwrap();
377
+ class.undef_alloc_func();
378
+ class
379
+ }),
380
+ PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
381
+ let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
382
+ class.undef_alloc_func();
383
+ class
384
+ }),
385
+ PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
386
+ let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
387
+ class.undef_alloc_func();
388
+ class
389
+ }),
390
+ PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
391
+ let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
392
+ class.undef_alloc_func();
393
+ class
394
+ }),
395
+ PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
396
+ let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
397
+ class.undef_alloc_func();
398
+ class
399
+ }),
400
+ PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
401
+ let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
402
+ class.undef_alloc_func();
403
+ class
404
+ }),
405
+ PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
406
+ let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
407
+ class.undef_alloc_func();
408
+ class
409
+ }),
410
+ PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
411
+ let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
412
+ class.undef_alloc_func();
413
+ class
414
+ }),
415
+ PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
416
+ let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
417
+ class.undef_alloc_func();
418
+ class
419
+ }),
420
+ _ => todo!(),
421
+ },
422
+ },
423
+ }
424
+ }
425
+ }
426
+
427
+ pub fn pre_tokenizers(module: &RModule) -> RbResult<()> {
428
+ let pre_tokenizer = module.define_class("PreTokenizer", Default::default())?;
429
+ pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
430
+
431
+ let class = module.define_class("Sequence", pre_tokenizer)?;
432
+ class.define_singleton_method("new", function!(RbSequence::new, 1))?;
433
+
434
+ let class = module.define_class("BertPreTokenizer", pre_tokenizer)?;
435
+ class.define_singleton_method("new", function!(RbBertPreTokenizer::new, 0))?;
436
+
437
+ let class = module.define_class("ByteLevel", pre_tokenizer)?;
438
+ class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
439
+ class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
440
+ class.define_method("add_prefix_space", method!(RbPreTokenizer::byte_level_add_prefix_space, 0))?;
441
+ class.define_method("add_prefix_space=", method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1))?;
442
+ class.define_method("use_regex", method!(RbPreTokenizer::byte_level_use_regex, 0))?;
443
+ class.define_method("use_regex=", method!(RbPreTokenizer::byte_level_set_use_regex, 1))?;
444
+
445
+ let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
446
+ class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
447
+ class.define_method("delimiter", method!(RbPreTokenizer::char_delimiter_split_delimiter, 0))?;
448
+ class.define_method("delimiter=", method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1))?;
449
+
450
+ let class = module.define_class("Digits", pre_tokenizer)?;
451
+ class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
452
+ class.define_method("individual_digits", method!(RbPreTokenizer::digits_individual_digits, 0))?;
453
+ class.define_method("individual_digits=", method!(RbPreTokenizer::digits_set_individual_digits, 1))?;
454
+
455
+ let class = module.define_class("Metaspace", pre_tokenizer)?;
456
+ class.define_singleton_method("_new", function!(RbMetaspace::new, 2))?;
457
+ class.define_method("add_prefix_space", method!(RbPreTokenizer::metaspace_add_prefix_space, 0))?;
458
+ class.define_method("add_prefix_space=", method!(RbPreTokenizer::metaspace_set_add_prefix_space, 1))?;
459
+ class.define_method("replacement", method!(RbPreTokenizer::metaspace_replacement, 0))?;
460
+ class.define_method("replacement=", method!(RbPreTokenizer::metaspace_set_replacement, 1))?;
461
+
462
+ let class = module.define_class("Punctuation", pre_tokenizer)?;
463
+ class.define_singleton_method("_new", function!(RbPunctuation::new, 1))?;
464
+
465
+ let class = module.define_class("Split", pre_tokenizer)?;
466
+ class.define_singleton_method("_new", function!(RbSplit::new, 3))?;
467
+
468
+ let class = module.define_class("UnicodeScripts", pre_tokenizer)?;
469
+ class.define_singleton_method("new", function!(RbUnicodeScripts::new, 0))?;
470
+
471
+ let class = module.define_class("Whitespace", pre_tokenizer)?;
472
+ class.define_singleton_method("new", function!(RbWhitespace::new, 0))?;
473
+
474
+ let class = module.define_class("WhitespaceSplit", pre_tokenizer)?;
475
+ class.define_singleton_method("new", function!(RbWhitespaceSplit::new, 0))?;
476
+
477
+ Ok(())
478
+ }
@@ -0,0 +1,210 @@
1
+ use std::sync::Arc;
2
+
3
+ use magnus::typed_data::DataTypeBuilder;
4
+ use magnus::{
5
+ function, memoize, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
+ TryConvert, TypedData, Value,
7
+ };
8
+ use serde::{Deserialize, Serialize};
9
+ use tk::processors::bert::BertProcessing;
10
+ use tk::processors::byte_level::ByteLevel;
11
+ use tk::processors::roberta::RobertaProcessing;
12
+ use tk::processors::template::{SpecialToken, Template};
13
+ use tk::processors::PostProcessorWrapper;
14
+ use tk::{Encoding, PostProcessor};
15
+
16
+ use super::RbResult;
17
+
18
+ #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
19
+ pub struct RbPostProcessor {
20
+ #[serde(flatten)]
21
+ pub processor: Arc<PostProcessorWrapper>,
22
+ }
23
+
24
+ impl RbPostProcessor {
25
+ pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
26
+ RbPostProcessor { processor }
27
+ }
28
+ }
29
+
30
+ impl PostProcessor for RbPostProcessor {
31
+ fn added_tokens(&self, is_pair: bool) -> usize {
32
+ self.processor.added_tokens(is_pair)
33
+ }
34
+
35
+ fn process_encodings(
36
+ &self,
37
+ encodings: Vec<Encoding>,
38
+ add_special_tokens: bool,
39
+ ) -> tk::Result<Vec<Encoding>> {
40
+ self.processor
41
+ .process_encodings(encodings, add_special_tokens)
42
+ }
43
+ }
44
+
45
+ #[derive(Clone, Debug)]
46
+ pub struct RbSpecialToken(SpecialToken);
47
+
48
+ impl From<RbSpecialToken> for SpecialToken {
49
+ fn from(v: RbSpecialToken) -> Self {
50
+ v.0
51
+ }
52
+ }
53
+
54
+ impl TryConvert for RbSpecialToken {
55
+ fn try_convert(ob: Value) -> RbResult<Self> {
56
+ if let Ok(v) = ob.try_convert::<(String, u32)>() {
57
+ Ok(Self(v.into()))
58
+ } else if let Ok(v) = ob.try_convert::<(u32, String)>() {
59
+ Ok(Self(v.into()))
60
+ } else {
61
+ todo!()
62
+ }
63
+ }
64
+ }
65
+
66
+ #[derive(Clone, Debug)]
67
+ pub struct RbTemplate(Template);
68
+
69
+ impl From<RbTemplate> for Template {
70
+ fn from(v: RbTemplate) -> Self {
71
+ v.0
72
+ }
73
+ }
74
+
75
+ impl TryConvert for RbTemplate {
76
+ fn try_convert(ob: Value) -> RbResult<Self> {
77
+ if let Ok(s) = ob.try_convert::<String>() {
78
+ Ok(Self(
79
+ s.try_into().unwrap(), //.map_err(RbError::from)?,
80
+ ))
81
+ } else if let Ok(s) = ob.try_convert::<Vec<String>>() {
82
+ Ok(Self(
83
+ s.try_into().unwrap(), //.map_err(RbError::from)?,
84
+ ))
85
+ } else {
86
+ todo!()
87
+ }
88
+ }
89
+ }
90
+
91
+ pub struct RbBertProcessing {}
92
+
93
+ impl RbBertProcessing {
94
+ pub fn new(sep: (String, u32), cls: (String, u32)) -> RbPostProcessor {
95
+ RbPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into()))
96
+ }
97
+ }
98
+
99
+ pub struct RbByteLevel {}
100
+
101
+ impl RbByteLevel {
102
+ pub fn new(trim_offsets: Option<bool>) -> RbPostProcessor {
103
+ let mut byte_level = ByteLevel::default();
104
+
105
+ if let Some(to) = trim_offsets {
106
+ byte_level = byte_level.trim_offsets(to);
107
+ }
108
+ RbPostProcessor::new(Arc::new(byte_level.into()))
109
+ }
110
+
111
+ }
112
+
113
+ pub struct RbRobertaProcessing {}
114
+
115
+ impl RbRobertaProcessing {
116
+ fn new(
117
+ sep: (String, u32),
118
+ cls: (String, u32),
119
+ trim_offsets: bool,
120
+ add_prefix_space: bool,
121
+ ) -> RbPostProcessor {
122
+ let proc = RobertaProcessing::new(sep, cls)
123
+ .trim_offsets(trim_offsets)
124
+ .add_prefix_space(add_prefix_space);
125
+ RbPostProcessor::new(Arc::new(proc.into()))
126
+ }
127
+ }
128
+
129
+ pub struct RbTemplateProcessing {}
130
+
131
+ impl RbTemplateProcessing {
132
+ pub fn new(
133
+ single: Option<RbTemplate>,
134
+ pair: Option<RbTemplate>,
135
+ special_tokens: Option<Vec<(String, u32)>>,
136
+ ) -> RbResult<RbPostProcessor> {
137
+ let mut builder = tk::processors::template::TemplateProcessing::builder();
138
+
139
+ if let Some(seq) = single {
140
+ builder.single(seq.into());
141
+ }
142
+ if let Some(seq) = pair {
143
+ builder.pair(seq.into());
144
+ }
145
+ if let Some(sp) = special_tokens {
146
+ builder.special_tokens(sp);
147
+ }
148
+ let processor = builder.build().unwrap(); //.map_err(RbError::from)?;
149
+
150
+ Ok(RbPostProcessor::new(Arc::new(processor.into())))
151
+ }
152
+ }
153
+
154
+ unsafe impl TypedData for RbPostProcessor {
155
+ fn class() -> RClass {
156
+ *memoize!(RClass: {
157
+ let class: RClass = crate::processors().const_get("PostProcessor").unwrap();
158
+ class.undef_alloc_func();
159
+ class
160
+ })
161
+ }
162
+
163
+ fn data_type() -> &'static DataType {
164
+ memoize!(DataType: DataTypeBuilder::<RbPostProcessor>::new("Tokenizers::Processors::PostProcessor").build())
165
+ }
166
+
167
+ fn class_for(value: &Self) -> RClass {
168
+ match *value.processor {
169
+ PostProcessorWrapper::Bert(_) => *memoize!(RClass: {
170
+ let class: RClass = crate::processors().const_get("BertProcessing").unwrap();
171
+ class.undef_alloc_func();
172
+ class
173
+ }),
174
+ PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
175
+ let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
176
+ class.undef_alloc_func();
177
+ class
178
+ }),
179
+ PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
180
+ let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
181
+ class.undef_alloc_func();
182
+ class
183
+ }),
184
+ PostProcessorWrapper::Template(_) => *memoize!(RClass: {
185
+ let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
186
+ class.undef_alloc_func();
187
+ class
188
+ }),
189
+ _ => todo!(),
190
+ }
191
+ }
192
+ }
193
+
194
+ pub fn processors(module: &RModule) -> RbResult<()> {
195
+ let post_processor = module.define_class("PostProcessor", Default::default())?;
196
+
197
+ let class = module.define_class("BertProcessing", post_processor)?;
198
+ class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
199
+
200
+ let class = module.define_class("ByteLevel", post_processor)?;
201
+ class.define_singleton_method("_new", function!(RbByteLevel::new, 1))?;
202
+
203
+ let class = module.define_class("RobertaProcessing", post_processor)?;
204
+ class.define_singleton_method("_new", function!(RbRobertaProcessing::new, 4))?;
205
+
206
+ let class = module.define_class("TemplateProcessing", post_processor)?;
207
+ class.define_singleton_method("_new", function!(RbTemplateProcessing::new, 3))?;
208
+
209
+ Ok(())
210
+ }