tokenizers 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/Cargo.lock +33 -74
  4. data/README.md +4 -0
  5. data/ext/tokenizers/Cargo.toml +4 -2
  6. data/ext/tokenizers/src/decoders.rs +275 -6
  7. data/ext/tokenizers/src/encoding.rs +78 -3
  8. data/ext/tokenizers/src/error.rs +2 -2
  9. data/ext/tokenizers/src/lib.rs +88 -17
  10. data/ext/tokenizers/src/models.rs +372 -11
  11. data/ext/tokenizers/src/normalizers.rs +435 -7
  12. data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
  13. data/ext/tokenizers/src/processors.rs +210 -0
  14. data/ext/tokenizers/src/tokenizer.rs +448 -20
  15. data/ext/tokenizers/src/trainers.rs +749 -0
  16. data/ext/tokenizers/src/utils/mod.rs +5 -0
  17. data/ext/tokenizers/src/utils/normalization.rs +85 -0
  18. data/ext/tokenizers/src/utils/regex.rs +22 -0
  19. data/lib/tokenizers/char_bpe_tokenizer.rb +11 -8
  20. data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
  21. data/lib/tokenizers/decoders/ctc.rb +9 -0
  22. data/lib/tokenizers/decoders/metaspace.rb +9 -0
  23. data/lib/tokenizers/decoders/word_piece.rb +9 -0
  24. data/lib/tokenizers/encoding.rb +19 -0
  25. data/lib/tokenizers/from_pretrained.rb +1 -1
  26. data/lib/tokenizers/models/bpe.rb +9 -0
  27. data/lib/tokenizers/models/unigram.rb +9 -0
  28. data/lib/tokenizers/models/word_level.rb +13 -0
  29. data/lib/tokenizers/models/word_piece.rb +9 -0
  30. data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
  31. data/lib/tokenizers/normalizers/strip.rb +9 -0
  32. data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
  33. data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
  34. data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
  35. data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
  36. data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
  37. data/lib/tokenizers/processors/byte_level.rb +9 -0
  38. data/lib/tokenizers/processors/roberta_processing.rb +9 -0
  39. data/lib/tokenizers/processors/template_processing.rb +9 -0
  40. data/lib/tokenizers/tokenizer.rb +45 -0
  41. data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
  42. data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
  43. data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
  44. data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
  45. data/lib/tokenizers/version.rb +1 -1
  46. data/lib/tokenizers.rb +49 -7
  47. metadata +32 -3
data/ext/tokenizers/src/normalizers.rs
@@ -1,14 +1,442 @@
1
- use tk::normalizers::BertNormalizer;
1
+ use std::sync::{Arc, RwLock};
2
2
 
3
- #[magnus::wrap(class = "Tokenizers::BertNormalizer")]
4
- pub struct RbBertNormalizer {
5
- pub normalizer: BertNormalizer,
3
+ use magnus::typed_data::DataTypeBuilder;
4
+ use magnus::{
5
+ function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
6
+ TypedData,
7
+ };
8
+ use serde::ser::SerializeStruct;
9
+ use serde::{Deserialize, Serialize, Serializer};
10
+ use tk::normalizers::{
11
+ BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Strip, StripAccents,
12
+ NFC, NFD, NFKC, NFKD,
13
+ };
14
+ use tk::{NormalizedString, Normalizer};
15
+
16
+ use super::utils::*;
17
+ use super::{RbError, RbResult};
18
+
19
+ #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
20
+ pub struct RbNormalizer {
21
+ #[serde(flatten)]
22
+ pub(crate) normalizer: RbNormalizerTypeWrapper,
23
+ }
24
+
25
+ impl RbNormalizer {
26
+ pub(crate) fn new(normalizer: RbNormalizerTypeWrapper) -> Self {
27
+ RbNormalizer { normalizer }
28
+ }
29
+
30
+ pub fn normalize_str(&self, sequence: String) -> RbResult<String> {
31
+ let mut normalized = NormalizedString::from(sequence);
32
+ self.normalizer.normalize(&mut normalized).map_err(RbError::from)?;
33
+ Ok(normalized.get().to_owned())
34
+ }
35
+ }
36
+
37
+ impl Normalizer for RbNormalizer {
38
+ fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
39
+ self.normalizer.normalize(normalized)
40
+ }
41
+ }
42
+
43
+ macro_rules! getter {
44
+ ($self: ident, $variant: ident, $name: ident) => {{
45
+ if let RbNormalizerTypeWrapper::Single(ref norm) = &$self.normalizer {
46
+ let wrapper = norm.read().unwrap();
47
+ if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = *wrapper {
48
+ o.$name
49
+ } else {
50
+ unreachable!()
51
+ }
52
+ } else {
53
+ unreachable!()
54
+ }
55
+ }};
56
+ }
57
+
58
+ macro_rules! setter {
59
+ ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
60
+ if let RbNormalizerTypeWrapper::Single(ref norm) = &$self.normalizer {
61
+ let mut wrapper = norm.write().unwrap();
62
+ if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(ref mut o)) = *wrapper {
63
+ o.$name = $value;
64
+ }
65
+ }
66
+ }};
67
+ }
68
+
69
+ impl RbNormalizer {
70
+
71
+ fn bert_clean_text(&self) -> bool {
72
+ getter!(self, BertNormalizer, clean_text)
73
+ }
74
+
75
+ fn bert_set_clean_text(&self, clean_text: bool) {
76
+ setter!(self, BertNormalizer, clean_text, clean_text);
77
+ }
78
+
79
+ fn bert_handle_chinese_chars(&self) -> bool {
80
+ getter!(self, BertNormalizer, handle_chinese_chars)
81
+ }
82
+
83
+ fn bert_set_handle_chinese_chars(&self, handle_chinese_chars: bool) {
84
+ setter!(
85
+ self,
86
+ BertNormalizer,
87
+ handle_chinese_chars,
88
+ handle_chinese_chars
89
+ );
90
+ }
91
+
92
+ fn bert_strip_accents(&self) -> Option<bool> {
93
+ getter!(self, BertNormalizer, strip_accents)
94
+ }
95
+
96
+ fn bert_set_strip_accents(&self, strip_accents: Option<bool>) {
97
+ setter!(self, BertNormalizer, strip_accents, strip_accents);
98
+ }
99
+
100
+ fn bert_lowercase(&self) -> bool {
101
+ getter!(self, BertNormalizer, lowercase)
102
+ }
103
+
104
+ fn bert_set_lowercase(&self, lowercase: bool) {
105
+ setter!(self, BertNormalizer, lowercase, lowercase)
106
+ }
107
+
108
+ fn strip_left(&self) -> bool {
109
+ getter!(self, StripNormalizer, strip_left)
110
+ }
111
+
112
+ fn strip_set_left(&self, left: bool) {
113
+ setter!(self, StripNormalizer, strip_left, left)
114
+ }
115
+
116
+ fn strip_right(&self) -> bool {
117
+ getter!(self, StripNormalizer, strip_right)
118
+ }
119
+
120
+ fn strip_set_right(&self, right: bool) {
121
+ setter!(self, StripNormalizer, strip_right, right)
122
+ }
6
123
  }
7
124
 
125
+ pub struct RbBertNormalizer {}
126
+
8
127
  impl RbBertNormalizer {
9
- pub fn new() -> Self {
10
- RbBertNormalizer {
11
- normalizer: BertNormalizer::default(),
128
+ pub fn new(clean_text: bool, handle_chinese_chars: bool, strip_accents: Option<bool>, lowercase: bool) -> RbNormalizer {
129
+ BertNormalizer::new(clean_text, handle_chinese_chars, strip_accents, lowercase).into()
130
+ }
131
+ }
132
+
133
+ pub struct RbLowercase {}
134
+
135
+ impl RbLowercase {
136
+ pub fn new() -> RbNormalizer {
137
+ Lowercase.into()
138
+ }
139
+ }
140
+
141
+ pub struct RbNFC {}
142
+
143
+ impl RbNFC {
144
+ pub fn new() -> RbNormalizer {
145
+ NFC.into()
146
+ }
147
+ }
148
+
149
+ pub struct RbNFD {}
150
+
151
+ impl RbNFD {
152
+ pub fn new() -> RbNormalizer {
153
+ NFD.into()
154
+ }
155
+ }
156
+
157
+ pub struct RbNFKC {}
158
+
159
+ impl RbNFKC {
160
+ pub fn new() -> RbNormalizer {
161
+ NFKC.into()
162
+ }
163
+ }
164
+
165
+ pub struct RbNFKD {}
166
+
167
+ impl RbNFKD {
168
+ pub fn new() -> RbNormalizer {
169
+ NFKD.into()
170
+ }
171
+ }
172
+
173
+ pub struct RbNmt {}
174
+
175
+ impl RbNmt {
176
+ pub fn new() -> RbNormalizer {
177
+ Nmt.into()
178
+ }
179
+ }
180
+
181
+ pub struct RbReplace {}
182
+
183
+ impl RbReplace {
184
+ pub fn new(pattern: RbPattern, content: String) -> RbResult<RbNormalizer> {
185
+ Replace::new(pattern, content).map(|v| v.into()).map_err(RbError::from)
186
+ }
187
+ }
188
+
189
+ pub struct RbStrip {}
190
+
191
+ impl RbStrip {
192
+ pub fn new(left: bool, right: bool) -> RbNormalizer {
193
+ Strip::new(left, right).into()
194
+ }
195
+ }
196
+
197
+ pub struct RbStripAccents {}
198
+
199
+ impl RbStripAccents {
200
+ pub fn new() -> RbNormalizer {
201
+ StripAccents.into()
202
+ }
203
+ }
204
+
205
+ pub struct RbSequence {}
206
+
207
+ impl RbSequence {
208
+ fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
209
+ let mut sequence = Vec::with_capacity(normalizers.len());
210
+ for n in normalizers.each() {
211
+ let normalizer: &RbNormalizer = n?.try_convert()?;
212
+ match &normalizer.normalizer {
213
+ RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
214
+ RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
215
+ }
216
+ }
217
+ Ok(RbNormalizer::new(RbNormalizerTypeWrapper::Sequence(sequence)))
218
+ }
219
+ }
220
+
221
+ #[derive(Debug, Clone, Deserialize)]
222
+ #[serde(untagged)]
223
+ pub(crate) enum RbNormalizerWrapper {
224
+ // Custom(CustomNormalizer),
225
+ Wrapped(NormalizerWrapper),
226
+ }
227
+
228
+ impl Serialize for RbNormalizerWrapper {
229
+ fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
230
+ where
231
+ S: Serializer,
232
+ {
233
+ match self {
234
+ RbNormalizerWrapper::Wrapped(inner) => inner.serialize(serializer),
235
+ // RbNormalizerWrapper::Custom(inner) => inner.serialize(serializer),
12
236
  }
13
237
  }
14
238
  }
239
+
240
+ #[derive(Debug, Clone, Deserialize)]
241
+ #[serde(untagged)]
242
+ pub(crate) enum RbNormalizerTypeWrapper {
243
+ Sequence(Vec<Arc<RwLock<RbNormalizerWrapper>>>),
244
+ Single(Arc<RwLock<RbNormalizerWrapper>>),
245
+ }
246
+
247
+ impl Serialize for RbNormalizerTypeWrapper {
248
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
249
+ where
250
+ S: Serializer,
251
+ {
252
+ match self {
253
+ RbNormalizerTypeWrapper::Sequence(seq) => {
254
+ let mut ser = serializer.serialize_struct("Sequence", 2)?;
255
+ ser.serialize_field("type", "Sequence")?;
256
+ ser.serialize_field("normalizers", seq)?;
257
+ ser.end()
258
+ }
259
+ RbNormalizerTypeWrapper::Single(inner) => inner.serialize(serializer),
260
+ }
261
+ }
262
+ }
263
+
264
+ impl<I> From<I> for RbNormalizerWrapper
265
+ where
266
+ I: Into<NormalizerWrapper>,
267
+ {
268
+ fn from(norm: I) -> Self {
269
+ RbNormalizerWrapper::Wrapped(norm.into())
270
+ }
271
+ }
272
+
273
+ impl<I> From<I> for RbNormalizerTypeWrapper
274
+ where
275
+ I: Into<RbNormalizerWrapper>,
276
+ {
277
+ fn from(norm: I) -> Self {
278
+ RbNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
279
+ }
280
+ }
281
+
282
+ impl<I> From<I> for RbNormalizer
283
+ where
284
+ I: Into<NormalizerWrapper>,
285
+ {
286
+ fn from(norm: I) -> Self {
287
+ RbNormalizer {
288
+ normalizer: norm.into().into(),
289
+ }
290
+ }
291
+ }
292
+
293
+ impl Normalizer for RbNormalizerTypeWrapper {
294
+ fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
295
+ match self {
296
+ RbNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized),
297
+ RbNormalizerTypeWrapper::Sequence(inner) => inner
298
+ .iter()
299
+ .try_for_each(|n| n.read().unwrap().normalize(normalized)),
300
+ }
301
+ }
302
+ }
303
+
304
+ impl Normalizer for RbNormalizerWrapper {
305
+ fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
306
+ match self {
307
+ RbNormalizerWrapper::Wrapped(inner) => inner.normalize(normalized),
308
+ // RbNormalizerWrapper::Custom(inner) => inner.normalize(normalized),
309
+ }
310
+ }
311
+ }
312
+
313
+ unsafe impl TypedData for RbNormalizer {
314
+ fn class() -> RClass {
315
+ *memoize!(RClass: {
316
+ let class: RClass = crate::normalizers().const_get("Normalizer").unwrap();
317
+ class.undef_alloc_func();
318
+ class
319
+ })
320
+ }
321
+
322
+ fn data_type() -> &'static DataType {
323
+ memoize!(DataType: DataTypeBuilder::<RbNormalizer>::new("Tokenizers::Normalizers::Normalizer").build())
324
+ }
325
+
326
+ fn class_for(value: &Self) -> RClass {
327
+ match &value.normalizer {
328
+ RbNormalizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
329
+ let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
330
+ class.undef_alloc_func();
331
+ class
332
+ }),
333
+ RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
334
+ RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
335
+ NormalizerWrapper::BertNormalizer(_) => *memoize!(RClass: {
336
+ let class: RClass = crate::normalizers().const_get("BertNormalizer").unwrap();
337
+ class.undef_alloc_func();
338
+ class
339
+ }),
340
+ NormalizerWrapper::Lowercase(_) => *memoize!(RClass: {
341
+ let class: RClass = crate::normalizers().const_get("Lowercase").unwrap();
342
+ class.undef_alloc_func();
343
+ class
344
+ }),
345
+ NormalizerWrapper::NFD(_) => *memoize!(RClass: {
346
+ let class: RClass = crate::normalizers().const_get("NFD").unwrap();
347
+ class.undef_alloc_func();
348
+ class
349
+ }),
350
+ NormalizerWrapper::NFC(_) => *memoize!(RClass: {
351
+ let class: RClass = crate::normalizers().const_get("NFC").unwrap();
352
+ class.undef_alloc_func();
353
+ class
354
+ }),
355
+ NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
356
+ let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
357
+ class.undef_alloc_func();
358
+ class
359
+ }),
360
+ NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
361
+ let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
362
+ class.undef_alloc_func();
363
+ class
364
+ }),
365
+ NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
366
+ let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
367
+ class.undef_alloc_func();
368
+ class
369
+ }),
370
+ NormalizerWrapper::Replace(_) => *memoize!(RClass: {
371
+ let class: RClass = crate::normalizers().const_get("Replace").unwrap();
372
+ class.undef_alloc_func();
373
+ class
374
+ }),
375
+ NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
376
+ let class: RClass = crate::normalizers().const_get("Strip").unwrap();
377
+ class.undef_alloc_func();
378
+ class
379
+ }),
380
+ NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
381
+ let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
382
+ class.undef_alloc_func();
383
+ class
384
+ }),
385
+ _ => todo!(),
386
+ },
387
+ },
388
+ }
389
+ }
390
+ }
391
+
392
+ pub fn normalizers(module: &RModule) -> RbResult<()> {
393
+ let normalizer = module.define_class("Normalizer", Default::default())?;
394
+ normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
395
+
396
+ let class = module.define_class("Sequence", normalizer)?;
397
+ class.define_singleton_method("new", function!(RbSequence::new, 1))?;
398
+
399
+ let class = module.define_class("BertNormalizer", normalizer)?;
400
+ class.define_singleton_method("_new", function!(RbBertNormalizer::new, 4))?;
401
+ class.define_method("clean_text", method!(RbNormalizer::bert_clean_text, 0))?;
402
+ class.define_method("clean_text=", method!(RbNormalizer::bert_set_clean_text, 1))?;
403
+ class.define_method("handle_chinese_chars", method!(RbNormalizer::bert_handle_chinese_chars, 0))?;
404
+ class.define_method("handle_chinese_chars=", method!(RbNormalizer::bert_set_handle_chinese_chars, 1))?;
405
+ class.define_method("strip_accents", method!(RbNormalizer::bert_strip_accents, 0))?;
406
+ class.define_method("strip_accents=", method!(RbNormalizer::bert_set_strip_accents, 1))?;
407
+ class.define_method("lowercase", method!(RbNormalizer::bert_lowercase, 0))?;
408
+ class.define_method("lowercase=", method!(RbNormalizer::bert_set_lowercase, 1))?;
409
+
410
+ let class = module.define_class("Lowercase", normalizer)?;
411
+ class.define_singleton_method("new", function!(RbLowercase::new, 0))?;
412
+
413
+ let class = module.define_class("NFC", normalizer)?;
414
+ class.define_singleton_method("new", function!(RbNFC::new, 0))?;
415
+
416
+ let class = module.define_class("NFD", normalizer)?;
417
+ class.define_singleton_method("new", function!(RbNFD::new, 0))?;
418
+
419
+ let class = module.define_class("NFKC", normalizer)?;
420
+ class.define_singleton_method("new", function!(RbNFKC::new, 0))?;
421
+
422
+ let class = module.define_class("NFKD", normalizer)?;
423
+ class.define_singleton_method("new", function!(RbNFKD::new, 0))?;
424
+
425
+ let class = module.define_class("Nmt", normalizer)?;
426
+ class.define_singleton_method("new", function!(RbNmt::new, 0))?;
427
+
428
+ let class = module.define_class("Replace", normalizer)?;
429
+ class.define_singleton_method("new", function!(RbReplace::new, 2))?;
430
+
431
+ let class = module.define_class("Strip", normalizer)?;
432
+ class.define_singleton_method("_new", function!(RbStrip::new, 2))?;
433
+ class.define_method("left", method!(RbNormalizer::strip_left, 0))?;
434
+ class.define_method("left=", method!(RbNormalizer::strip_set_left, 1))?;
435
+ class.define_method("right", method!(RbNormalizer::strip_right, 0))?;
436
+ class.define_method("right=", method!(RbNormalizer::strip_set_right, 1))?;
437
+
438
+ let class = module.define_class("StripAccents", normalizer)?;
439
+ class.define_singleton_method("new", function!(RbStripAccents::new, 0))?;
440
+
441
+ Ok(())
442
+ }