tokenizers 0.2.3 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +13 -0
  3. data/Cargo.lock +33 -74
  4. data/README.md +4 -0
  5. data/ext/tokenizers/Cargo.toml +4 -2
  6. data/ext/tokenizers/src/decoders.rs +275 -6
  7. data/ext/tokenizers/src/encoding.rs +3 -2
  8. data/ext/tokenizers/src/error.rs +2 -2
  9. data/ext/tokenizers/src/lib.rs +64 -17
  10. data/ext/tokenizers/src/models.rs +372 -11
  11. data/ext/tokenizers/src/normalizers.rs +435 -7
  12. data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
  13. data/ext/tokenizers/src/processors.rs +210 -0
  14. data/ext/tokenizers/src/tokenizer.rs +437 -23
  15. data/ext/tokenizers/src/trainers.rs +749 -0
  16. data/ext/tokenizers/src/utils/mod.rs +5 -0
  17. data/ext/tokenizers/src/utils/normalization.rs +85 -0
  18. data/ext/tokenizers/src/utils/regex.rs +22 -0
  19. data/lib/tokenizers/char_bpe_tokenizer.rb +9 -6
  20. data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
  21. data/lib/tokenizers/decoders/ctc.rb +9 -0
  22. data/lib/tokenizers/decoders/metaspace.rb +9 -0
  23. data/lib/tokenizers/decoders/word_piece.rb +9 -0
  24. data/lib/tokenizers/from_pretrained.rb +2 -2
  25. data/lib/tokenizers/models/bpe.rb +9 -0
  26. data/lib/tokenizers/models/unigram.rb +9 -0
  27. data/lib/tokenizers/models/word_level.rb +13 -0
  28. data/lib/tokenizers/models/word_piece.rb +9 -0
  29. data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
  30. data/lib/tokenizers/normalizers/strip.rb +9 -0
  31. data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
  32. data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
  33. data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
  34. data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
  35. data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
  36. data/lib/tokenizers/processors/byte_level.rb +9 -0
  37. data/lib/tokenizers/processors/roberta_processing.rb +9 -0
  38. data/lib/tokenizers/processors/template_processing.rb +9 -0
  39. data/lib/tokenizers/tokenizer.rb +40 -7
  40. data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
  41. data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
  42. data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
  43. data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
  44. data/lib/tokenizers/version.rb +1 -1
  45. data/lib/tokenizers.rb +42 -2
  46. metadata +30 -3
@@ -1,14 +1,478 @@
1
+ use std::sync::{Arc, RwLock};
2
+
3
+ use magnus::typed_data::DataTypeBuilder;
4
+ use magnus::{
5
+ function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object,
6
+ RArray, RClass, RModule, TypedData,
7
+ };
8
+
9
+ use serde::ser::SerializeStruct;
10
+ use serde::{Deserialize, Serialize, Serializer};
11
+
1
12
  use tk::pre_tokenizers::bert::BertPreTokenizer;
13
+ use tk::pre_tokenizers::byte_level::ByteLevel;
14
+ use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
15
+ use tk::pre_tokenizers::digits::Digits;
16
+ use tk::pre_tokenizers::metaspace::Metaspace;
17
+ use tk::pre_tokenizers::punctuation::Punctuation;
18
+ use tk::pre_tokenizers::split::Split;
19
+ use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
20
+ use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
21
+ use tk::pre_tokenizers::PreTokenizerWrapper;
22
+ use tk::tokenizer::Offsets;
23
+ use tk::{PreTokenizedString, PreTokenizer};
24
+
25
+ use super::utils::*;
26
+ use super::{RbError, RbResult};
27
+
28
/// Ruby-facing wrapper around a pre-tokenizer.
///
/// Every pre-tokenizer constructor in this module returns this one typed-data struct;
/// the concrete Ruby class an instance gets is chosen from the wrapped value in
/// `TypedData::class_for`.
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
pub struct RbPreTokenizer {
    // `flatten` inlines the wrapper's own serialized fields into this struct, so no extra
    // `pretok` key appears in the serialized output.
    #[serde(flatten)]
    pub(crate) pretok: RbPreTokenizerTypeWrapper,
}
33
+
34
+ impl RbPreTokenizer {
35
+ fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
36
+ let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
37
+
38
+ self.pretok.pre_tokenize(&mut pretokenized).map_err(RbError::from)?;
39
+
40
+ Ok(pretokenized
41
+ .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
42
+ .into_iter()
43
+ .map(|(s, o, _)| (s.to_owned(), o))
44
+ .collect())
45
+ }
46
+ }
2
47
 
3
- #[magnus::wrap(class = "Tokenizers::BertPreTokenizer")]
4
- pub struct RbBertPreTokenizer {
5
- pub pretok: BertPreTokenizer,
48
/// Reads from the concrete `tokenizers` pre-tokenizer wrapped by an `RbPreTokenizer`.
///
/// `$variant` names the `PreTokenizerWrapper` variant the caller expects. The trailing
/// tokens are appended to the unwrapped value, so both a field access
/// (`getter!(self, Digits, individual_digits)`) and a method chain
/// (`getter!(self, Metaspace, get_replacement().to_string())`) work.
///
/// Panics via `unreachable!()` if the wrapper is a `Sequence` or holds a different variant.
/// NOTE(review): this relies on each accessor only being registered on the Ruby class that
/// matches `$variant` (see `pre_tokenizers`).
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        if let RbPreTokenizerTypeWrapper::Single(ref single) = &$self.pretok {
            // Takes a read lock on the shared wrapper for the duration of the access;
            // `unwrap` panics if the lock was poisoned.
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref pretok)) =
                *single.read().unwrap() {
                pretok.$($name)+
            } else {
                unreachable!()
            }
        } else {
            unreachable!()
        }
    }};
}
7
62
 
63
/// Mutates the concrete `tokenizers` pre-tokenizer wrapped by an `RbPreTokenizer`.
///
/// Two forms:
/// - `setter!(self, Variant, field, value)` assigns `value` to a public field;
/// - `setter!(self, Variant, @method, value)` calls `method(value)` instead, for
///   pre-tokenizers that only expose a setter method.
///
/// Unlike `getter!`, a mismatch (a `Sequence`, or a different variant) is silently ignored.
macro_rules! setter {
    // Field-assignment form.
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        if let RbPreTokenizerTypeWrapper::Single(ref single) = &$self.pretok {
            // Write lock: the wrapper is shared through `Arc<RwLock<..>>`, so mutation goes
            // through the lock even though we only hold `&self`.
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) =
                *single.write().unwrap()
            {
                pretok.$name = $value;
            }
        }
    }};
    // Method-call form, selected by the leading `@` (which cannot match the `ident` arm).
    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
        if let RbPreTokenizerTypeWrapper::Single(ref single) = &$self.pretok {
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) =
                *single.write().unwrap()
            {
                pretok.$name($value);
            }
        }
    }};
}
83
+
84
+ impl RbPreTokenizer {
85
+ #[allow(dead_code)]
86
+ pub(crate) fn new(pretok: RbPreTokenizerTypeWrapper) -> Self {
87
+ RbPreTokenizer { pretok }
88
+ }
89
+
90
+ fn byte_level_add_prefix_space(&self) -> bool {
91
+ getter!(self, ByteLevel, add_prefix_space)
92
+ }
93
+
94
+ fn byte_level_set_add_prefix_space(&self, add_prefix_space: bool) {
95
+ setter!(self, ByteLevel, add_prefix_space, add_prefix_space);
96
+ }
97
+
98
+ fn byte_level_use_regex(&self) -> bool {
99
+ getter!(self, ByteLevel, use_regex)
100
+ }
101
+
102
+ fn byte_level_set_use_regex(&self, use_regex: bool) {
103
+ setter!(self, ByteLevel, use_regex, use_regex);
104
+ }
105
+
106
+ fn char_delimiter_split_delimiter(&self) -> String {
107
+ getter!(self, Delimiter, delimiter.to_string())
108
+ }
109
+
110
+ fn char_delimiter_split_set_delimiter(&self, delimiter: char) {
111
+ setter!(self, Delimiter, delimiter, delimiter);
112
+ }
113
+
114
+ fn digits_individual_digits(&self) -> bool {
115
+ getter!(self, Digits, individual_digits)
116
+ }
117
+
118
+ fn digits_set_individual_digits(&self, individual_digits: bool) {
119
+ setter!(self, Digits, individual_digits, individual_digits);
120
+ }
121
+
122
+ fn metaspace_add_prefix_space(&self) -> bool {
123
+ getter!(self, Metaspace, add_prefix_space)
124
+ }
125
+
126
+ fn metaspace_set_add_prefix_space(&self, add_prefix_space: bool) {
127
+ setter!(self, Metaspace, add_prefix_space, add_prefix_space);
128
+ }
129
+
130
+ fn metaspace_replacement(&self) -> String {
131
+ getter!(self, Metaspace, get_replacement().to_string())
132
+ }
133
+
134
+ fn metaspace_set_replacement(&self, replacement: char) {
135
+ setter!(self, Metaspace, @set_replacement, replacement);
136
+ }
137
+ }
138
+
139
+ impl PreTokenizer for RbPreTokenizer {
140
+ fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
141
+ self.pretok.pre_tokenize(normalized)
142
+ }
143
+ }
144
+
145
+ pub struct RbByteLevel {}
146
+
147
+ impl RbByteLevel {
148
+ pub fn new(add_prefix_space: bool, use_regex: bool) -> RbPreTokenizer {
149
+ ByteLevel::default()
150
+ .add_prefix_space(add_prefix_space)
151
+ .use_regex(use_regex)
152
+ .into()
153
+ }
154
+
155
+ fn alphabet() -> Vec<String> {
156
+ ByteLevel::alphabet()
157
+ .into_iter()
158
+ .map(|c| c.to_string())
159
+ .collect()
160
+ }
161
+ }
162
+
163
+ pub struct RbCharDelimiterSplit {}
164
+
165
+ impl RbCharDelimiterSplit {
166
+ pub fn new(delimiter: char) -> RbPreTokenizer {
167
+ CharDelimiterSplit::new(delimiter).into()
168
+ }
169
+ }
170
+
171
+ pub struct RbDigits {}
172
+
173
+ impl RbDigits {
174
+ fn new(individual_digits: bool) -> RbPreTokenizer {
175
+ Digits::new(individual_digits).into()
176
+ }
177
+ }
178
+
179
+ pub struct RbMetaspace {}
180
+
181
+ impl RbMetaspace {
182
+ fn new(
183
+ replacement: char,
184
+ add_prefix_space: bool,
185
+ ) -> RbPreTokenizer {
186
+ Metaspace::new(replacement, add_prefix_space).into()
187
+ }
188
+ }
189
+
190
+ pub struct RbPunctuation {}
191
+
192
+ impl RbPunctuation {
193
+ pub fn new(behavior: RbSplitDelimiterBehavior) -> RbResult<RbPreTokenizer> {
194
+ Ok(Punctuation::new(behavior.into()).into())
195
+ }
196
+ }
197
+
198
+ pub struct RbSplit {}
199
+
200
+ impl RbSplit {
201
+ pub fn new(pattern: RbPattern, behavior: RbSplitDelimiterBehavior, invert: bool) -> RbResult<RbPreTokenizer> {
202
+ Split::new(pattern, behavior.into(), invert).map(|v| v.into()).map_err(RbError::from)
203
+ }
204
+ }
205
+
206
+ pub struct RbUnicodeScripts {}
207
+
208
+ impl RbUnicodeScripts {
209
+ pub fn new() -> RbPreTokenizer {
210
+ UnicodeScripts::new().into()
211
+ }
212
+ }
213
+
214
+ pub struct RbWhitespace {}
215
+
216
+ impl RbWhitespace {
217
+ pub fn new() -> RbPreTokenizer {
218
+ Whitespace::default().into()
219
+ }
220
+ }
221
+
222
+ pub struct RbWhitespaceSplit {}
223
+
224
+ impl RbWhitespaceSplit {
225
+ pub fn new() -> RbPreTokenizer {
226
+ WhitespaceSplit.into()
227
+ }
228
+ }
229
+
230
+ pub struct RbBertPreTokenizer {}
231
+
8
232
  impl RbBertPreTokenizer {
9
- pub fn new() -> Self {
10
- RbBertPreTokenizer {
11
- pretok: BertPreTokenizer,
233
+ pub fn new() -> RbPreTokenizer {
234
+ BertPreTokenizer.into()
235
+ }
236
+ }
237
+
238
+ pub struct RbSequence {}
239
+
240
+ impl RbSequence {
241
+ fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
242
+ let mut sequence = Vec::with_capacity(pre_tokenizers.len());
243
+ for n in pre_tokenizers.each() {
244
+ let pretokenizer: &RbPreTokenizer = n?.try_convert()?;
245
+ match &pretokenizer.pretok {
246
+ RbPreTokenizerTypeWrapper::Sequence(inner) => {
247
+ sequence.extend(inner.iter().cloned())
248
+ }
249
+ RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
250
+ }
251
+ }
252
+ Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(sequence)))
253
+ }
254
+ }
255
+
256
/// One pre-tokenizer held by an `RbPreTokenizer`.
///
/// Only `tokenizers`' built-in pre-tokenizers are wrapped today. The commented-out
/// `Custom` variant presumably reserves room for user-defined (Ruby) pre-tokenizers,
/// which are not supported yet.
#[derive(Clone, Deserialize)]
// `untagged`: no extra enum tag is read; the inner value's own format is used directly.
#[serde(untagged)]
pub(crate) enum RbPreTokenizerWrapper {
    // Custom(CustomPreTokenizer),
    Wrapped(PreTokenizerWrapper),
}
262
+
263
+ impl Serialize for RbPreTokenizerWrapper {
264
+ fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
265
+ where
266
+ S: Serializer,
267
+ {
268
+ match self {
269
+ RbPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
270
+ // RbPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
271
+ }
272
+ }
273
+ }
274
+
275
/// Either one pre-tokenizer or an ordered list of them.
///
/// Each member sits behind `Arc<RwLock<..>>`. This lets the attribute setters mutate it
/// through a shared `&self`, and lets a `Sequence` share members with the Ruby objects it
/// was built from.
///
/// With `#[serde(untagged)]`, deserialization tries the variants in declaration order.
/// Only a bare array can become `Sequence`; anything else is tried as `Single`.
/// NOTE(review): `Serialize` writes `Sequence` as a `{ "type": "Sequence", "pretokenizers": [..] }`
/// object, which this derive will not read back as `Sequence`. It would be tried as `Single`
/// (i.e. a `tokenizers` sequence) instead — confirm this asymmetry is intended.
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum RbPreTokenizerTypeWrapper {
    Sequence(Vec<Arc<RwLock<RbPreTokenizerWrapper>>>),
    Single(Arc<RwLock<RbPreTokenizerWrapper>>),
}
281
+
282
+ impl Serialize for RbPreTokenizerTypeWrapper {
283
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
284
+ where
285
+ S: Serializer,
286
+ {
287
+ match self {
288
+ RbPreTokenizerTypeWrapper::Sequence(seq) => {
289
+ let mut ser = serializer.serialize_struct("Sequence", 2)?;
290
+ ser.serialize_field("type", "Sequence")?;
291
+ ser.serialize_field("pretokenizers", seq)?;
292
+ ser.end()
293
+ }
294
+ RbPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer),
295
+ }
296
+ }
297
+ }
298
+
299
+ impl<I> From<I> for RbPreTokenizerWrapper
300
+ where
301
+ I: Into<PreTokenizerWrapper>,
302
+ {
303
+ fn from(pretok: I) -> Self {
304
+ RbPreTokenizerWrapper::Wrapped(pretok.into())
305
+ }
306
+ }
307
+
308
+ impl<I> From<I> for RbPreTokenizerTypeWrapper
309
+ where
310
+ I: Into<RbPreTokenizerWrapper>,
311
+ {
312
+ fn from(pretok: I) -> Self {
313
+ RbPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into())))
314
+ }
315
+ }
316
+
317
+ impl<I> From<I> for RbPreTokenizer
318
+ where
319
+ I: Into<PreTokenizerWrapper>,
320
+ {
321
+ fn from(pretok: I) -> Self {
322
+ RbPreTokenizer {
323
+ pretok: pretok.into().into(),
324
+ }
325
+ }
326
+ }
327
+
328
+ impl PreTokenizer for RbPreTokenizerTypeWrapper {
329
+ fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
330
+ match self {
331
+ RbPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok),
332
+ RbPreTokenizerTypeWrapper::Sequence(inner) => inner
333
+ .iter()
334
+ .try_for_each(|n| n.read().unwrap().pre_tokenize(pretok)),
12
335
  }
13
336
  }
14
337
  }
338
+
339
+ impl PreTokenizer for RbPreTokenizerWrapper {
340
+ fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
341
+ match self {
342
+ RbPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok),
343
+ // RbPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok),
344
+ }
345
+ }
346
+ }
347
+
348
+ unsafe impl TypedData for RbPreTokenizer {
349
+ fn class() -> RClass {
350
+ *memoize!(RClass: {
351
+ let class: RClass = crate::pre_tokenizers().const_get("PreTokenizer").unwrap();
352
+ class.undef_alloc_func();
353
+ class
354
+ })
355
+ }
356
+
357
+ fn data_type() -> &'static DataType {
358
+ memoize!(DataType: DataTypeBuilder::<RbPreTokenizer>::new("Tokenizers::PreTokenizers::PreTokenizer").build())
359
+ }
360
+
361
+ fn class_for(value: &Self) -> RClass {
362
+ match &value.pretok {
363
+ RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
364
+ let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
365
+ class.undef_alloc_func();
366
+ class
367
+ }),
368
+ RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
369
+ RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
370
+ PreTokenizerWrapper::BertPreTokenizer(_) => *memoize!(RClass: {
371
+ let class: RClass = crate::pre_tokenizers().const_get("BertPreTokenizer").unwrap();
372
+ class.undef_alloc_func();
373
+ class
374
+ }),
375
+ PreTokenizerWrapper::ByteLevel(_) => *memoize!(RClass: {
376
+ let class: RClass = crate::pre_tokenizers().const_get("ByteLevel").unwrap();
377
+ class.undef_alloc_func();
378
+ class
379
+ }),
380
+ PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
381
+ let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
382
+ class.undef_alloc_func();
383
+ class
384
+ }),
385
+ PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
386
+ let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
387
+ class.undef_alloc_func();
388
+ class
389
+ }),
390
+ PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
391
+ let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
392
+ class.undef_alloc_func();
393
+ class
394
+ }),
395
+ PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
396
+ let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
397
+ class.undef_alloc_func();
398
+ class
399
+ }),
400
+ PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
401
+ let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
402
+ class.undef_alloc_func();
403
+ class
404
+ }),
405
+ PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
406
+ let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
407
+ class.undef_alloc_func();
408
+ class
409
+ }),
410
+ PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
411
+ let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
412
+ class.undef_alloc_func();
413
+ class
414
+ }),
415
+ PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
416
+ let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
417
+ class.undef_alloc_func();
418
+ class
419
+ }),
420
+ _ => todo!(),
421
+ },
422
+ },
423
+ }
424
+ }
425
+ }
426
+
427
+ pub fn pre_tokenizers(module: &RModule) -> RbResult<()> {
428
+ let pre_tokenizer = module.define_class("PreTokenizer", Default::default())?;
429
+ pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
430
+
431
+ let class = module.define_class("Sequence", pre_tokenizer)?;
432
+ class.define_singleton_method("new", function!(RbSequence::new, 1))?;
433
+
434
+ let class = module.define_class("BertPreTokenizer", pre_tokenizer)?;
435
+ class.define_singleton_method("new", function!(RbBertPreTokenizer::new, 0))?;
436
+
437
+ let class = module.define_class("ByteLevel", pre_tokenizer)?;
438
+ class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
439
+ class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
440
+ class.define_method("add_prefix_space", method!(RbPreTokenizer::byte_level_add_prefix_space, 0))?;
441
+ class.define_method("add_prefix_space=", method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1))?;
442
+ class.define_method("use_regex", method!(RbPreTokenizer::byte_level_use_regex, 0))?;
443
+ class.define_method("use_regex=", method!(RbPreTokenizer::byte_level_set_use_regex, 1))?;
444
+
445
+ let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
446
+ class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
447
+ class.define_method("delimiter", method!(RbPreTokenizer::char_delimiter_split_delimiter, 0))?;
448
+ class.define_method("delimiter=", method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1))?;
449
+
450
+ let class = module.define_class("Digits", pre_tokenizer)?;
451
+ class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
452
+ class.define_method("individual_digits", method!(RbPreTokenizer::digits_individual_digits, 0))?;
453
+ class.define_method("individual_digits=", method!(RbPreTokenizer::digits_set_individual_digits, 1))?;
454
+
455
+ let class = module.define_class("Metaspace", pre_tokenizer)?;
456
+ class.define_singleton_method("_new", function!(RbMetaspace::new, 2))?;
457
+ class.define_method("add_prefix_space", method!(RbPreTokenizer::metaspace_add_prefix_space, 0))?;
458
+ class.define_method("add_prefix_space=", method!(RbPreTokenizer::metaspace_set_add_prefix_space, 1))?;
459
+ class.define_method("replacement", method!(RbPreTokenizer::metaspace_replacement, 0))?;
460
+ class.define_method("replacement=", method!(RbPreTokenizer::metaspace_set_replacement, 1))?;
461
+
462
+ let class = module.define_class("Punctuation", pre_tokenizer)?;
463
+ class.define_singleton_method("_new", function!(RbPunctuation::new, 1))?;
464
+
465
+ let class = module.define_class("Split", pre_tokenizer)?;
466
+ class.define_singleton_method("_new", function!(RbSplit::new, 3))?;
467
+
468
+ let class = module.define_class("UnicodeScripts", pre_tokenizer)?;
469
+ class.define_singleton_method("new", function!(RbUnicodeScripts::new, 0))?;
470
+
471
+ let class = module.define_class("Whitespace", pre_tokenizer)?;
472
+ class.define_singleton_method("new", function!(RbWhitespace::new, 0))?;
473
+
474
+ let class = module.define_class("WhitespaceSplit", pre_tokenizer)?;
475
+ class.define_singleton_method("new", function!(RbWhitespaceSplit::new, 0))?;
476
+
477
+ Ok(())
478
+ }
@@ -0,0 +1,210 @@
1
+ use std::sync::Arc;
2
+
3
+ use magnus::typed_data::DataTypeBuilder;
4
+ use magnus::{
5
+ function, memoize, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
+ TryConvert, TypedData, Value,
7
+ };
8
+ use serde::{Deserialize, Serialize};
9
+ use tk::processors::bert::BertProcessing;
10
+ use tk::processors::byte_level::ByteLevel;
11
+ use tk::processors::roberta::RobertaProcessing;
12
+ use tk::processors::template::{SpecialToken, Template};
13
+ use tk::processors::PostProcessorWrapper;
14
+ use tk::{Encoding, PostProcessor};
15
+
16
+ use super::RbResult;
17
+
18
/// Ruby-facing wrapper around a `tokenizers` post-processor.
///
/// Every post-processor constructor in this module returns this one typed-data struct;
/// `TypedData::class_for` picks the concrete Ruby class from the wrapped variant.
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
pub struct RbPostProcessor {
    // `flatten` inlines the inner processor's serialized fields, so no extra `processor`
    // key appears. The processor is behind `Arc`, so `Clone` only bumps a refcount.
    #[serde(flatten)]
    pub processor: Arc<PostProcessorWrapper>,
}
23
+
24
+ impl RbPostProcessor {
25
+ pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
26
+ RbPostProcessor { processor }
27
+ }
28
+ }
29
+
30
+ impl PostProcessor for RbPostProcessor {
31
+ fn added_tokens(&self, is_pair: bool) -> usize {
32
+ self.processor.added_tokens(is_pair)
33
+ }
34
+
35
+ fn process_encodings(
36
+ &self,
37
+ encodings: Vec<Encoding>,
38
+ add_special_tokens: bool,
39
+ ) -> tk::Result<Vec<Encoding>> {
40
+ self.processor
41
+ .process_encodings(encodings, add_special_tokens)
42
+ }
43
+ }
44
+
45
+ #[derive(Clone, Debug)]
46
+ pub struct RbSpecialToken(SpecialToken);
47
+
48
+ impl From<RbSpecialToken> for SpecialToken {
49
+ fn from(v: RbSpecialToken) -> Self {
50
+ v.0
51
+ }
52
+ }
53
+
54
+ impl TryConvert for RbSpecialToken {
55
+ fn try_convert(ob: Value) -> RbResult<Self> {
56
+ if let Ok(v) = ob.try_convert::<(String, u32)>() {
57
+ Ok(Self(v.into()))
58
+ } else if let Ok(v) = ob.try_convert::<(u32, String)>() {
59
+ Ok(Self(v.into()))
60
+ } else {
61
+ todo!()
62
+ }
63
+ }
64
+ }
65
+
66
+ #[derive(Clone, Debug)]
67
+ pub struct RbTemplate(Template);
68
+
69
+ impl From<RbTemplate> for Template {
70
+ fn from(v: RbTemplate) -> Self {
71
+ v.0
72
+ }
73
+ }
74
+
75
+ impl TryConvert for RbTemplate {
76
+ fn try_convert(ob: Value) -> RbResult<Self> {
77
+ if let Ok(s) = ob.try_convert::<String>() {
78
+ Ok(Self(
79
+ s.try_into().unwrap(), //.map_err(RbError::from)?,
80
+ ))
81
+ } else if let Ok(s) = ob.try_convert::<Vec<String>>() {
82
+ Ok(Self(
83
+ s.try_into().unwrap(), //.map_err(RbError::from)?,
84
+ ))
85
+ } else {
86
+ todo!()
87
+ }
88
+ }
89
+ }
90
+
91
+ pub struct RbBertProcessing {}
92
+
93
+ impl RbBertProcessing {
94
+ pub fn new(sep: (String, u32), cls: (String, u32)) -> RbPostProcessor {
95
+ RbPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into()))
96
+ }
97
+ }
98
+
99
+ pub struct RbByteLevel {}
100
+
101
+ impl RbByteLevel {
102
+ pub fn new(trim_offsets: Option<bool>) -> RbPostProcessor {
103
+ let mut byte_level = ByteLevel::default();
104
+
105
+ if let Some(to) = trim_offsets {
106
+ byte_level = byte_level.trim_offsets(to);
107
+ }
108
+ RbPostProcessor::new(Arc::new(byte_level.into()))
109
+ }
110
+
111
+ }
112
+
113
+ pub struct RbRobertaProcessing {}
114
+
115
+ impl RbRobertaProcessing {
116
+ fn new(
117
+ sep: (String, u32),
118
+ cls: (String, u32),
119
+ trim_offsets: bool,
120
+ add_prefix_space: bool,
121
+ ) -> RbPostProcessor {
122
+ let proc = RobertaProcessing::new(sep, cls)
123
+ .trim_offsets(trim_offsets)
124
+ .add_prefix_space(add_prefix_space);
125
+ RbPostProcessor::new(Arc::new(proc.into()))
126
+ }
127
+ }
128
+
129
+ pub struct RbTemplateProcessing {}
130
+
131
+ impl RbTemplateProcessing {
132
+ pub fn new(
133
+ single: Option<RbTemplate>,
134
+ pair: Option<RbTemplate>,
135
+ special_tokens: Option<Vec<(String, u32)>>,
136
+ ) -> RbResult<RbPostProcessor> {
137
+ let mut builder = tk::processors::template::TemplateProcessing::builder();
138
+
139
+ if let Some(seq) = single {
140
+ builder.single(seq.into());
141
+ }
142
+ if let Some(seq) = pair {
143
+ builder.pair(seq.into());
144
+ }
145
+ if let Some(sp) = special_tokens {
146
+ builder.special_tokens(sp);
147
+ }
148
+ let processor = builder.build().unwrap(); //.map_err(RbError::from)?;
149
+
150
+ Ok(RbPostProcessor::new(Arc::new(processor.into())))
151
+ }
152
+ }
153
+
154
+ unsafe impl TypedData for RbPostProcessor {
155
+ fn class() -> RClass {
156
+ *memoize!(RClass: {
157
+ let class: RClass = crate::processors().const_get("PostProcessor").unwrap();
158
+ class.undef_alloc_func();
159
+ class
160
+ })
161
+ }
162
+
163
+ fn data_type() -> &'static DataType {
164
+ memoize!(DataType: DataTypeBuilder::<RbPostProcessor>::new("Tokenizers::Processors::PostProcessor").build())
165
+ }
166
+
167
+ fn class_for(value: &Self) -> RClass {
168
+ match *value.processor {
169
+ PostProcessorWrapper::Bert(_) => *memoize!(RClass: {
170
+ let class: RClass = crate::processors().const_get("BertProcessing").unwrap();
171
+ class.undef_alloc_func();
172
+ class
173
+ }),
174
+ PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
175
+ let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
176
+ class.undef_alloc_func();
177
+ class
178
+ }),
179
+ PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
180
+ let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
181
+ class.undef_alloc_func();
182
+ class
183
+ }),
184
+ PostProcessorWrapper::Template(_) => *memoize!(RClass: {
185
+ let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
186
+ class.undef_alloc_func();
187
+ class
188
+ }),
189
+ _ => todo!(),
190
+ }
191
+ }
192
+ }
193
+
194
+ pub fn processors(module: &RModule) -> RbResult<()> {
195
+ let post_processor = module.define_class("PostProcessor", Default::default())?;
196
+
197
+ let class = module.define_class("BertProcessing", post_processor)?;
198
+ class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
199
+
200
+ let class = module.define_class("ByteLevel", post_processor)?;
201
+ class.define_singleton_method("_new", function!(RbByteLevel::new, 1))?;
202
+
203
+ let class = module.define_class("RobertaProcessing", post_processor)?;
204
+ class.define_singleton_method("_new", function!(RbRobertaProcessing::new, 4))?;
205
+
206
+ let class = module.define_class("TemplateProcessing", post_processor)?;
207
+ class.define_singleton_method("_new", function!(RbTemplateProcessing::new, 3))?;
208
+
209
+ Ok(())
210
+ }