tokenizers 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/Cargo.lock +33 -74
  4. data/README.md +4 -0
  5. data/ext/tokenizers/Cargo.toml +4 -2
  6. data/ext/tokenizers/src/decoders.rs +275 -6
  7. data/ext/tokenizers/src/encoding.rs +78 -3
  8. data/ext/tokenizers/src/error.rs +2 -2
  9. data/ext/tokenizers/src/lib.rs +88 -17
  10. data/ext/tokenizers/src/models.rs +372 -11
  11. data/ext/tokenizers/src/normalizers.rs +435 -7
  12. data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
  13. data/ext/tokenizers/src/processors.rs +210 -0
  14. data/ext/tokenizers/src/tokenizer.rs +448 -20
  15. data/ext/tokenizers/src/trainers.rs +749 -0
  16. data/ext/tokenizers/src/utils/mod.rs +5 -0
  17. data/ext/tokenizers/src/utils/normalization.rs +85 -0
  18. data/ext/tokenizers/src/utils/regex.rs +22 -0
  19. data/lib/tokenizers/char_bpe_tokenizer.rb +11 -8
  20. data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
  21. data/lib/tokenizers/decoders/ctc.rb +9 -0
  22. data/lib/tokenizers/decoders/metaspace.rb +9 -0
  23. data/lib/tokenizers/decoders/word_piece.rb +9 -0
  24. data/lib/tokenizers/encoding.rb +19 -0
  25. data/lib/tokenizers/from_pretrained.rb +1 -1
  26. data/lib/tokenizers/models/bpe.rb +9 -0
  27. data/lib/tokenizers/models/unigram.rb +9 -0
  28. data/lib/tokenizers/models/word_level.rb +13 -0
  29. data/lib/tokenizers/models/word_piece.rb +9 -0
  30. data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
  31. data/lib/tokenizers/normalizers/strip.rb +9 -0
  32. data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
  33. data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
  34. data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
  35. data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
  36. data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
  37. data/lib/tokenizers/processors/byte_level.rb +9 -0
  38. data/lib/tokenizers/processors/roberta_processing.rb +9 -0
  39. data/lib/tokenizers/processors/template_processing.rb +9 -0
  40. data/lib/tokenizers/tokenizer.rb +45 -0
  41. data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
  42. data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
  43. data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
  44. data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
  45. data/lib/tokenizers/version.rb +1 -1
  46. data/lib/tokenizers.rb +49 -7
  47. metadata +32 -3
@@ -1,14 +1,442 @@
1
- use tk::normalizers::BertNormalizer;
1
+ use std::sync::{Arc, RwLock};
2
2
 
3
- #[magnus::wrap(class = "Tokenizers::BertNormalizer")]
4
- pub struct RbBertNormalizer {
5
- pub normalizer: BertNormalizer,
3
+ use magnus::typed_data::DataTypeBuilder;
4
+ use magnus::{
5
+ function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
6
+ TypedData,
7
+ };
8
+ use serde::ser::SerializeStruct;
9
+ use serde::{Deserialize, Serialize, Serializer};
10
+ use tk::normalizers::{
11
+ BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Strip, StripAccents,
12
+ NFC, NFD, NFKC, NFKD,
13
+ };
14
+ use tk::{NormalizedString, Normalizer};
15
+
16
+ use super::utils::*;
17
+ use super::{RbError, RbResult};
18
+
19
+ #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
20
+ pub struct RbNormalizer {
21
+ #[serde(flatten)]
22
+ pub(crate) normalizer: RbNormalizerTypeWrapper,
23
+ }
24
+
25
+ impl RbNormalizer {
26
+ pub(crate) fn new(normalizer: RbNormalizerTypeWrapper) -> Self {
27
+ RbNormalizer { normalizer }
28
+ }
29
+
30
+ pub fn normalize_str(&self, sequence: String) -> RbResult<String> {
31
+ let mut normalized = NormalizedString::from(sequence);
32
+ self.normalizer.normalize(&mut normalized).map_err(RbError::from)?;
33
+ Ok(normalized.get().to_owned())
34
+ }
35
+ }
36
+
37
+ impl Normalizer for RbNormalizer {
38
+ fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
39
+ self.normalizer.normalize(normalized)
40
+ }
41
+ }
42
+
43
/// Reads field `$name` of the single wrapped `tk` normalizer held by `$self`.
///
/// * `$self` – an `RbNormalizer` (anything with a `normalizer: RbNormalizerTypeWrapper` field).
/// * `$variant` – the `NormalizerWrapper` variant the caller expects (e.g. `BertNormalizer`).
/// * `$name` – the field of that variant's payload to read; the field must be `Copy`.
///
/// Panics (via `unreachable!`) if `$self` holds a sequence or a different variant.
macro_rules! getter {
    ($self: ident, $variant: ident, $name: ident) => {{
        // Plain binding (no `ref`) under a `&` scrutinee: valid in every edition,
        // including the 2024 match-ergonomics rules.
        if let RbNormalizerTypeWrapper::Single(norm) = &$self.normalizer {
            let wrapper = norm.read().unwrap();
            // Borrow the payload through the guard rather than moving the whole
            // variant out of it; only the (Copy) field value is copied out, so
            // this also works for payloads that are not `Copy`.
            if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = &*wrapper {
                o.$name
            } else {
                unreachable!()
            }
        } else {
            unreachable!()
        }
    }};
}
57
+
58
/// Assigns `$value` to field `$name` of the single wrapped `tk` normalizer held by `$self`.
///
/// * `$self` – an `RbNormalizer` (anything with a `normalizer: RbNormalizerTypeWrapper` field).
/// * `$variant` – the `NormalizerWrapper` variant the caller expects (e.g. `StripNormalizer`).
/// * `$name` – the payload field to overwrite.
/// * `$value` – evaluated only when the variant matches, while the write lock is held.
///
/// If `$self` holds a sequence or a different variant, nothing happens.
macro_rules! setter {
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        // Plain binding (no `ref`) under a `&` scrutinee: valid in every edition,
        // including the 2024 match-ergonomics rules.
        if let RbNormalizerTypeWrapper::Single(norm) = &$self.normalizer {
            let mut wrapper = norm.write().unwrap();
            if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = &mut *wrapper {
                o.$name = $value;
            }
        }
    }};
}
68
+
69
+ impl RbNormalizer {
70
+
71
+ fn bert_clean_text(&self) -> bool {
72
+ getter!(self, BertNormalizer, clean_text)
73
+ }
74
+
75
+ fn bert_set_clean_text(&self, clean_text: bool) {
76
+ setter!(self, BertNormalizer, clean_text, clean_text);
77
+ }
78
+
79
+ fn bert_handle_chinese_chars(&self) -> bool {
80
+ getter!(self, BertNormalizer, handle_chinese_chars)
81
+ }
82
+
83
+ fn bert_set_handle_chinese_chars(&self, handle_chinese_chars: bool) {
84
+ setter!(
85
+ self,
86
+ BertNormalizer,
87
+ handle_chinese_chars,
88
+ handle_chinese_chars
89
+ );
90
+ }
91
+
92
+ fn bert_strip_accents(&self) -> Option<bool> {
93
+ getter!(self, BertNormalizer, strip_accents)
94
+ }
95
+
96
+ fn bert_set_strip_accents(&self, strip_accents: Option<bool>) {
97
+ setter!(self, BertNormalizer, strip_accents, strip_accents);
98
+ }
99
+
100
+ fn bert_lowercase(&self) -> bool {
101
+ getter!(self, BertNormalizer, lowercase)
102
+ }
103
+
104
+ fn bert_set_lowercase(&self, lowercase: bool) {
105
+ setter!(self, BertNormalizer, lowercase, lowercase)
106
+ }
107
+
108
+ fn strip_left(&self) -> bool {
109
+ getter!(self, StripNormalizer, strip_left)
110
+ }
111
+
112
+ fn strip_set_left(&self, left: bool) {
113
+ setter!(self, StripNormalizer, strip_left, left)
114
+ }
115
+
116
+ fn strip_right(&self) -> bool {
117
+ getter!(self, StripNormalizer, strip_right)
118
+ }
119
+
120
+ fn strip_set_right(&self, right: bool) {
121
+ setter!(self, StripNormalizer, strip_right, right)
122
+ }
6
123
  }
7
124
 
125
+ pub struct RbBertNormalizer {}
126
+
8
127
  impl RbBertNormalizer {
9
- pub fn new() -> Self {
10
- RbBertNormalizer {
11
- normalizer: BertNormalizer::default(),
128
+ pub fn new(clean_text: bool, handle_chinese_chars: bool, strip_accents: Option<bool>, lowercase: bool) -> RbNormalizer {
129
+ BertNormalizer::new(clean_text, handle_chinese_chars, strip_accents, lowercase).into()
130
+ }
131
+ }
132
+
133
+ pub struct RbLowercase {}
134
+
135
+ impl RbLowercase {
136
+ pub fn new() -> RbNormalizer {
137
+ Lowercase.into()
138
+ }
139
+ }
140
+
141
+ pub struct RbNFC {}
142
+
143
+ impl RbNFC {
144
+ pub fn new() -> RbNormalizer {
145
+ NFC.into()
146
+ }
147
+ }
148
+
149
+ pub struct RbNFD {}
150
+
151
+ impl RbNFD {
152
+ pub fn new() -> RbNormalizer {
153
+ NFD.into()
154
+ }
155
+ }
156
+
157
+ pub struct RbNFKC {}
158
+
159
+ impl RbNFKC {
160
+ pub fn new() -> RbNormalizer {
161
+ NFKC.into()
162
+ }
163
+ }
164
+
165
+ pub struct RbNFKD {}
166
+
167
+ impl RbNFKD {
168
+ pub fn new() -> RbNormalizer {
169
+ NFKD.into()
170
+ }
171
+ }
172
+
173
+ pub struct RbNmt {}
174
+
175
+ impl RbNmt {
176
+ pub fn new() -> RbNormalizer {
177
+ Nmt.into()
178
+ }
179
+ }
180
+
181
+ pub struct RbReplace {}
182
+
183
+ impl RbReplace {
184
+ pub fn new(pattern: RbPattern, content: String) -> RbResult<RbNormalizer> {
185
+ Replace::new(pattern, content).map(|v| v.into()).map_err(RbError::from)
186
+ }
187
+ }
188
+
189
+ pub struct RbStrip {}
190
+
191
+ impl RbStrip {
192
+ pub fn new(left: bool, right: bool) -> RbNormalizer {
193
+ Strip::new(left, right).into()
194
+ }
195
+ }
196
+
197
+ pub struct RbStripAccents {}
198
+
199
+ impl RbStripAccents {
200
+ pub fn new() -> RbNormalizer {
201
+ StripAccents.into()
202
+ }
203
+ }
204
+
205
+ pub struct RbSequence {}
206
+
207
+ impl RbSequence {
208
+ fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
209
+ let mut sequence = Vec::with_capacity(normalizers.len());
210
+ for n in normalizers.each() {
211
+ let normalizer: &RbNormalizer = n?.try_convert()?;
212
+ match &normalizer.normalizer {
213
+ RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
214
+ RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
215
+ }
216
+ }
217
+ Ok(RbNormalizer::new(RbNormalizerTypeWrapper::Sequence(sequence)))
218
+ }
219
+ }
220
+
221
+ #[derive(Debug, Clone, Deserialize)]
222
+ #[serde(untagged)]
223
+ pub(crate) enum RbNormalizerWrapper {
224
+ // Custom(CustomNormalizer),
225
+ Wrapped(NormalizerWrapper),
226
+ }
227
+
228
+ impl Serialize for RbNormalizerWrapper {
229
+ fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
230
+ where
231
+ S: Serializer,
232
+ {
233
+ match self {
234
+ RbNormalizerWrapper::Wrapped(inner) => inner.serialize(serializer),
235
+ // RbNormalizerWrapper::Custom(inner) => inner.serialize(serializer),
12
236
  }
13
237
  }
14
238
  }
239
+
240
+ #[derive(Debug, Clone, Deserialize)]
241
+ #[serde(untagged)]
242
+ pub(crate) enum RbNormalizerTypeWrapper {
243
+ Sequence(Vec<Arc<RwLock<RbNormalizerWrapper>>>),
244
+ Single(Arc<RwLock<RbNormalizerWrapper>>),
245
+ }
246
+
247
+ impl Serialize for RbNormalizerTypeWrapper {
248
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
249
+ where
250
+ S: Serializer,
251
+ {
252
+ match self {
253
+ RbNormalizerTypeWrapper::Sequence(seq) => {
254
+ let mut ser = serializer.serialize_struct("Sequence", 2)?;
255
+ ser.serialize_field("type", "Sequence")?;
256
+ ser.serialize_field("normalizers", seq)?;
257
+ ser.end()
258
+ }
259
+ RbNormalizerTypeWrapper::Single(inner) => inner.serialize(serializer),
260
+ }
261
+ }
262
+ }
263
+
264
+ impl<I> From<I> for RbNormalizerWrapper
265
+ where
266
+ I: Into<NormalizerWrapper>,
267
+ {
268
+ fn from(norm: I) -> Self {
269
+ RbNormalizerWrapper::Wrapped(norm.into())
270
+ }
271
+ }
272
+
273
+ impl<I> From<I> for RbNormalizerTypeWrapper
274
+ where
275
+ I: Into<RbNormalizerWrapper>,
276
+ {
277
+ fn from(norm: I) -> Self {
278
+ RbNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
279
+ }
280
+ }
281
+
282
+ impl<I> From<I> for RbNormalizer
283
+ where
284
+ I: Into<NormalizerWrapper>,
285
+ {
286
+ fn from(norm: I) -> Self {
287
+ RbNormalizer {
288
+ normalizer: norm.into().into(),
289
+ }
290
+ }
291
+ }
292
+
293
+ impl Normalizer for RbNormalizerTypeWrapper {
294
+ fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
295
+ match self {
296
+ RbNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized),
297
+ RbNormalizerTypeWrapper::Sequence(inner) => inner
298
+ .iter()
299
+ .try_for_each(|n| n.read().unwrap().normalize(normalized)),
300
+ }
301
+ }
302
+ }
303
+
304
+ impl Normalizer for RbNormalizerWrapper {
305
+ fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
306
+ match self {
307
+ RbNormalizerWrapper::Wrapped(inner) => inner.normalize(normalized),
308
+ // RbNormalizerWrapper::Custom(inner) => inner.normalize(normalized),
309
+ }
310
+ }
311
+ }
312
+
313
+ unsafe impl TypedData for RbNormalizer {
314
+ fn class() -> RClass {
315
+ *memoize!(RClass: {
316
+ let class: RClass = crate::normalizers().const_get("Normalizer").unwrap();
317
+ class.undef_alloc_func();
318
+ class
319
+ })
320
+ }
321
+
322
+ fn data_type() -> &'static DataType {
323
+ memoize!(DataType: DataTypeBuilder::<RbNormalizer>::new("Tokenizers::Normalizers::Normalizer").build())
324
+ }
325
+
326
+ fn class_for(value: &Self) -> RClass {
327
+ match &value.normalizer {
328
+ RbNormalizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
329
+ let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
330
+ class.undef_alloc_func();
331
+ class
332
+ }),
333
+ RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
334
+ RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
335
+ NormalizerWrapper::BertNormalizer(_) => *memoize!(RClass: {
336
+ let class: RClass = crate::normalizers().const_get("BertNormalizer").unwrap();
337
+ class.undef_alloc_func();
338
+ class
339
+ }),
340
+ NormalizerWrapper::Lowercase(_) => *memoize!(RClass: {
341
+ let class: RClass = crate::normalizers().const_get("Lowercase").unwrap();
342
+ class.undef_alloc_func();
343
+ class
344
+ }),
345
+ NormalizerWrapper::NFD(_) => *memoize!(RClass: {
346
+ let class: RClass = crate::normalizers().const_get("NFD").unwrap();
347
+ class.undef_alloc_func();
348
+ class
349
+ }),
350
+ NormalizerWrapper::NFC(_) => *memoize!(RClass: {
351
+ let class: RClass = crate::normalizers().const_get("NFC").unwrap();
352
+ class.undef_alloc_func();
353
+ class
354
+ }),
355
+ NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
356
+ let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
357
+ class.undef_alloc_func();
358
+ class
359
+ }),
360
+ NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
361
+ let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
362
+ class.undef_alloc_func();
363
+ class
364
+ }),
365
+ NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
366
+ let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
367
+ class.undef_alloc_func();
368
+ class
369
+ }),
370
+ NormalizerWrapper::Replace(_) => *memoize!(RClass: {
371
+ let class: RClass = crate::normalizers().const_get("Replace").unwrap();
372
+ class.undef_alloc_func();
373
+ class
374
+ }),
375
+ NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
376
+ let class: RClass = crate::normalizers().const_get("Strip").unwrap();
377
+ class.undef_alloc_func();
378
+ class
379
+ }),
380
+ NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
381
+ let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
382
+ class.undef_alloc_func();
383
+ class
384
+ }),
385
+ _ => todo!(),
386
+ },
387
+ },
388
+ }
389
+ }
390
+ }
391
+
392
+ pub fn normalizers(module: &RModule) -> RbResult<()> {
393
+ let normalizer = module.define_class("Normalizer", Default::default())?;
394
+ normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
395
+
396
+ let class = module.define_class("Sequence", normalizer)?;
397
+ class.define_singleton_method("new", function!(RbSequence::new, 1))?;
398
+
399
+ let class = module.define_class("BertNormalizer", normalizer)?;
400
+ class.define_singleton_method("_new", function!(RbBertNormalizer::new, 4))?;
401
+ class.define_method("clean_text", method!(RbNormalizer::bert_clean_text, 0))?;
402
+ class.define_method("clean_text=", method!(RbNormalizer::bert_set_clean_text, 1))?;
403
+ class.define_method("handle_chinese_chars", method!(RbNormalizer::bert_handle_chinese_chars, 0))?;
404
+ class.define_method("handle_chinese_chars=", method!(RbNormalizer::bert_set_handle_chinese_chars, 1))?;
405
+ class.define_method("strip_accents", method!(RbNormalizer::bert_strip_accents, 0))?;
406
+ class.define_method("strip_accents=", method!(RbNormalizer::bert_set_strip_accents, 1))?;
407
+ class.define_method("lowercase", method!(RbNormalizer::bert_lowercase, 0))?;
408
+ class.define_method("lowercase=", method!(RbNormalizer::bert_set_lowercase, 1))?;
409
+
410
+ let class = module.define_class("Lowercase", normalizer)?;
411
+ class.define_singleton_method("new", function!(RbLowercase::new, 0))?;
412
+
413
+ let class = module.define_class("NFC", normalizer)?;
414
+ class.define_singleton_method("new", function!(RbNFC::new, 0))?;
415
+
416
+ let class = module.define_class("NFD", normalizer)?;
417
+ class.define_singleton_method("new", function!(RbNFD::new, 0))?;
418
+
419
+ let class = module.define_class("NFKC", normalizer)?;
420
+ class.define_singleton_method("new", function!(RbNFKC::new, 0))?;
421
+
422
+ let class = module.define_class("NFKD", normalizer)?;
423
+ class.define_singleton_method("new", function!(RbNFKD::new, 0))?;
424
+
425
+ let class = module.define_class("Nmt", normalizer)?;
426
+ class.define_singleton_method("new", function!(RbNmt::new, 0))?;
427
+
428
+ let class = module.define_class("Replace", normalizer)?;
429
+ class.define_singleton_method("new", function!(RbReplace::new, 2))?;
430
+
431
+ let class = module.define_class("Strip", normalizer)?;
432
+ class.define_singleton_method("_new", function!(RbStrip::new, 2))?;
433
+ class.define_method("left", method!(RbNormalizer::strip_left, 0))?;
434
+ class.define_method("left=", method!(RbNormalizer::strip_set_left, 1))?;
435
+ class.define_method("right", method!(RbNormalizer::strip_right, 0))?;
436
+ class.define_method("right=", method!(RbNormalizer::strip_set_right, 1))?;
437
+
438
+ let class = module.define_class("StripAccents", normalizer)?;
439
+ class.define_singleton_method("new", function!(RbStripAccents::new, 0))?;
440
+
441
+ Ok(())
442
+ }