tokenizers 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,19 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
6
- TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
5
+ Ruby, TryConvert, TypedData,
7
6
  };
8
7
  use serde::ser::SerializeStruct;
9
8
  use serde::{Deserialize, Serialize, Serializer};
10
9
  use tk::normalizers::{
11
- BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Strip, StripAccents,
10
+ BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Prepend, Strip, StripAccents,
12
11
  NFC, NFD, NFKC, NFKD,
13
12
  };
14
13
  use tk::{NormalizedString, Normalizer};
15
14
 
16
15
  use super::utils::*;
17
- use super::{RbError, RbResult};
16
+ use super::{NORMALIZERS, RbError, RbResult};
18
17
 
19
18
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
20
19
  pub struct RbNormalizer {
@@ -44,7 +43,7 @@ macro_rules! getter {
44
43
  ($self: ident, $variant: ident, $name: ident) => {{
45
44
  if let RbNormalizerTypeWrapper::Single(ref norm) = &$self.normalizer {
46
45
  let wrapper = norm.read().unwrap();
47
- if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = *wrapper {
46
+ if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone() {
48
47
  o.$name
49
48
  } else {
50
49
  unreachable!()
@@ -105,6 +104,14 @@ impl RbNormalizer {
105
104
  setter!(self, BertNormalizer, lowercase, lowercase)
106
105
  }
107
106
 
107
+ fn prepend_prepend(&self) -> String {
108
+ getter!(self, Prepend, prepend)
109
+ }
110
+
111
+ fn prepend_set_prepend(&self, prepend: String) {
112
+ setter!(self, Prepend, prepend, prepend)
113
+ }
114
+
108
115
  fn strip_left(&self) -> bool {
109
116
  getter!(self, StripNormalizer, strip_left)
110
117
  }
@@ -186,6 +193,14 @@ impl RbReplace {
186
193
  }
187
194
  }
188
195
 
196
+ pub struct RbPrepend {}
197
+
198
+ impl RbPrepend {
199
+ pub fn new(prepend: String) -> RbNormalizer {
200
+ Prepend::new(prepend).into()
201
+ }
202
+ }
203
+
189
204
  pub struct RbStrip {}
190
205
 
191
206
  impl RbStrip {
@@ -208,7 +223,7 @@ impl RbSequence {
208
223
  fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
209
224
  let mut sequence = Vec::with_capacity(normalizers.len());
210
225
  for n in normalizers.each() {
211
- let normalizer: &RbNormalizer = n?.try_convert()?;
226
+ let normalizer: &RbNormalizer = TryConvert::try_convert(n?)?;
212
227
  match &normalizer.normalizer {
213
228
  RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
214
229
  RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
@@ -311,77 +326,96 @@ impl Normalizer for RbNormalizerWrapper {
311
326
  }
312
327
 
313
328
  unsafe impl TypedData for RbNormalizer {
314
- fn class() -> RClass {
315
- *memoize!(RClass: {
316
- let class: RClass = crate::normalizers().const_get("Normalizer").unwrap();
317
- class.undef_alloc_func();
318
- class
319
- })
329
+ fn class(ruby: &Ruby) -> RClass {
330
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
331
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Normalizer").unwrap();
332
+ class.undef_default_alloc_func();
333
+ class
334
+ });
335
+ ruby.get_inner(&CLASS)
320
336
  }
321
337
 
322
338
  fn data_type() -> &'static DataType {
323
- memoize!(DataType: DataTypeBuilder::<RbNormalizer>::new("Tokenizers::Normalizers::Normalizer").build())
324
- }
325
-
326
- fn class_for(value: &Self) -> RClass {
339
+ static DATA_TYPE: DataType = data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
340
+ &DATA_TYPE
341
+ }
342
+
343
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
344
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
345
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Sequence").unwrap();
346
+ class.undef_default_alloc_func();
347
+ class
348
+ });
349
+ static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
350
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("BertNormalizer").unwrap();
351
+ class.undef_default_alloc_func();
352
+ class
353
+ });
354
+ static LOWERCASE: Lazy<RClass> = Lazy::new(|ruby| {
355
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Lowercase").unwrap();
356
+ class.undef_default_alloc_func();
357
+ class
358
+ });
359
+ static NFD: Lazy<RClass> = Lazy::new(|ruby| {
360
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFD").unwrap();
361
+ class.undef_default_alloc_func();
362
+ class
363
+ });
364
+ static NFC: Lazy<RClass> = Lazy::new(|ruby| {
365
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFC").unwrap();
366
+ class.undef_default_alloc_func();
367
+ class
368
+ });
369
+ static NFKC: Lazy<RClass> = Lazy::new(|ruby| {
370
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKC").unwrap();
371
+ class.undef_default_alloc_func();
372
+ class
373
+ });
374
+ static NFKD: Lazy<RClass> = Lazy::new(|ruby| {
375
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKD").unwrap();
376
+ class.undef_default_alloc_func();
377
+ class
378
+ });
379
+ static NMT: Lazy<RClass> = Lazy::new(|ruby| {
380
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Nmt").unwrap();
381
+ class.undef_default_alloc_func();
382
+ class
383
+ });
384
+ static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
385
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
386
+ class.undef_default_alloc_func();
387
+ class
388
+ });
389
+ static PREPEND: Lazy<RClass> = Lazy::new(|ruby| {
390
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Prepend").unwrap();
391
+ class.undef_default_alloc_func();
392
+ class
393
+ });
394
+ static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
395
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Strip").unwrap();
396
+ class.undef_default_alloc_func();
397
+ class
398
+ });
399
+ static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
400
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("StripAccents").unwrap();
401
+ class.undef_default_alloc_func();
402
+ class
403
+ });
327
404
  match &value.normalizer {
328
- RbNormalizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
329
- let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
330
- class.undef_alloc_func();
331
- class
332
- }),
405
+ RbNormalizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
333
406
  RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
334
407
  RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
335
- NormalizerWrapper::BertNormalizer(_) => *memoize!(RClass: {
336
- let class: RClass = crate::normalizers().const_get("BertNormalizer").unwrap();
337
- class.undef_alloc_func();
338
- class
339
- }),
340
- NormalizerWrapper::Lowercase(_) => *memoize!(RClass: {
341
- let class: RClass = crate::normalizers().const_get("Lowercase").unwrap();
342
- class.undef_alloc_func();
343
- class
344
- }),
345
- NormalizerWrapper::NFD(_) => *memoize!(RClass: {
346
- let class: RClass = crate::normalizers().const_get("NFD").unwrap();
347
- class.undef_alloc_func();
348
- class
349
- }),
350
- NormalizerWrapper::NFC(_) => *memoize!(RClass: {
351
- let class: RClass = crate::normalizers().const_get("NFC").unwrap();
352
- class.undef_alloc_func();
353
- class
354
- }),
355
- NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
356
- let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
357
- class.undef_alloc_func();
358
- class
359
- }),
360
- NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
361
- let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
362
- class.undef_alloc_func();
363
- class
364
- }),
365
- NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
366
- let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
367
- class.undef_alloc_func();
368
- class
369
- }),
370
- NormalizerWrapper::Replace(_) => *memoize!(RClass: {
371
- let class: RClass = crate::normalizers().const_get("Replace").unwrap();
372
- class.undef_alloc_func();
373
- class
374
- }),
375
- NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
376
- let class: RClass = crate::normalizers().const_get("Strip").unwrap();
377
- class.undef_alloc_func();
378
- class
379
- }),
380
- NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
381
- let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
382
- class.undef_alloc_func();
383
- class
384
- }),
408
+ NormalizerWrapper::BertNormalizer(_) => ruby.get_inner(&BERT_NORMALIZER),
409
+ NormalizerWrapper::Lowercase(_) => ruby.get_inner(&LOWERCASE),
410
+ NormalizerWrapper::NFD(_) => ruby.get_inner(&NFD),
411
+ NormalizerWrapper::NFC(_) => ruby.get_inner(&NFC),
412
+ NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
413
+ NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
414
+ NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
415
+ NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
416
+ NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
417
+ NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
418
+ NormalizerWrapper::StripAccents(_) => ruby.get_inner(&STRIP_ACCENTS),
385
419
  _ => todo!(),
386
420
  },
387
421
  },
@@ -389,8 +423,8 @@ unsafe impl TypedData for RbNormalizer {
389
423
  }
390
424
  }
391
425
 
392
- pub fn normalizers(module: &RModule) -> RbResult<()> {
393
- let normalizer = module.define_class("Normalizer", Default::default())?;
426
+ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
427
+ let normalizer = module.define_class("Normalizer", ruby.class_object())?;
394
428
  normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
395
429
 
396
430
  let class = module.define_class("Sequence", normalizer)?;
@@ -428,6 +462,11 @@ pub fn normalizers(module: &RModule) -> RbResult<()> {
428
462
  let class = module.define_class("Replace", normalizer)?;
429
463
  class.define_singleton_method("new", function!(RbReplace::new, 2))?;
430
464
 
465
+ let class = module.define_class("Prepend", normalizer)?;
466
+ class.define_singleton_method("_new", function!(RbPrepend::new, 1))?;
467
+ class.define_method("prepend", method!(RbNormalizer::prepend_prepend, 0))?;
468
+ class.define_method("prepend=", method!(RbNormalizer::prepend_set_prepend, 1))?;
469
+
431
470
  let class = module.define_class("Strip", normalizer)?;
432
471
  class.define_singleton_method("_new", function!(RbStrip::new, 2))?;
433
472
  class.define_method("left", method!(RbNormalizer::strip_left, 0))?;
@@ -1,9 +1,8 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object,
6
- RArray, RClass, RModule, TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
5
+ RArray, RClass, RModule, Ruby, TryConvert, TypedData,
7
6
  };
8
7
 
9
8
  use serde::ser::SerializeStruct;
@@ -23,7 +22,7 @@ use tk::tokenizer::Offsets;
23
22
  use tk::{PreTokenizedString, PreTokenizer};
24
23
 
25
24
  use super::utils::*;
26
- use super::{RbError, RbResult};
25
+ use super::{PRE_TOKENIZERS, RbError, RbResult};
27
26
 
28
27
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
29
28
  pub struct RbPreTokenizer {
@@ -215,7 +214,7 @@ pub struct RbWhitespace {}
215
214
 
216
215
  impl RbWhitespace {
217
216
  pub fn new() -> RbPreTokenizer {
218
- Whitespace::default().into()
217
+ Whitespace.into()
219
218
  }
220
219
  }
221
220
 
@@ -241,7 +240,7 @@ impl RbSequence {
241
240
  fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
242
241
  let mut sequence = Vec::with_capacity(pre_tokenizers.len());
243
242
  for n in pre_tokenizers.each() {
244
- let pretokenizer: &RbPreTokenizer = n?.try_convert()?;
243
+ let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n?)?;
245
244
  match &pretokenizer.pretok {
246
245
  RbPreTokenizerTypeWrapper::Sequence(inner) => {
247
246
  sequence.extend(inner.iter().cloned())
@@ -346,77 +345,90 @@ impl PreTokenizer for RbPreTokenizerWrapper {
346
345
  }
347
346
 
348
347
  unsafe impl TypedData for RbPreTokenizer {
349
- fn class() -> RClass {
350
- *memoize!(RClass: {
351
- let class: RClass = crate::pre_tokenizers().const_get("PreTokenizer").unwrap();
352
- class.undef_alloc_func();
353
- class
354
- })
348
+ fn class(ruby: &Ruby) -> RClass {
349
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
350
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
351
+ class.undef_default_alloc_func();
352
+ class
353
+ });
354
+ ruby.get_inner(&CLASS)
355
355
  }
356
356
 
357
357
  fn data_type() -> &'static DataType {
358
- memoize!(DataType: DataTypeBuilder::<RbPreTokenizer>::new("Tokenizers::PreTokenizers::PreTokenizer").build())
359
- }
360
-
361
- fn class_for(value: &Self) -> RClass {
358
+ static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
359
+ &DATA_TYPE
360
+ }
361
+
362
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
363
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
364
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
365
+ class.undef_default_alloc_func();
366
+ class
367
+ });
368
+ static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
369
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
370
+ class.undef_default_alloc_func();
371
+ class
372
+ });
373
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
374
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
375
+ class.undef_default_alloc_func();
376
+ class
377
+ });
378
+ static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
379
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
380
+ class.undef_default_alloc_func();
381
+ class
382
+ });
383
+ static DIGITS: Lazy<RClass> = Lazy::new(|ruby| {
384
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Digits").unwrap();
385
+ class.undef_default_alloc_func();
386
+ class
387
+ });
388
+ static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
389
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
390
+ class.undef_default_alloc_func();
391
+ class
392
+ });
393
+ static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
394
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
395
+ class.undef_default_alloc_func();
396
+ class
397
+ });
398
+ static SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
399
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Split").unwrap();
400
+ class.undef_default_alloc_func();
401
+ class
402
+ });
403
+ static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
404
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
405
+ class.undef_default_alloc_func();
406
+ class
407
+ });
408
+ static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
409
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
410
+ class.undef_default_alloc_func();
411
+ class
412
+ });
413
+ static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
414
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
415
+ class.undef_default_alloc_func();
416
+ class
417
+ });
362
418
  match &value.pretok {
363
- RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
364
- let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
365
- class.undef_alloc_func();
366
- class
367
- }),
419
+ RbPreTokenizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
368
420
  RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
369
421
  RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
370
- PreTokenizerWrapper::BertPreTokenizer(_) => *memoize!(RClass: {
371
- let class: RClass = crate::pre_tokenizers().const_get("BertPreTokenizer").unwrap();
372
- class.undef_alloc_func();
373
- class
374
- }),
375
- PreTokenizerWrapper::ByteLevel(_) => *memoize!(RClass: {
376
- let class: RClass = crate::pre_tokenizers().const_get("ByteLevel").unwrap();
377
- class.undef_alloc_func();
378
- class
379
- }),
380
- PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
381
- let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
382
- class.undef_alloc_func();
383
- class
384
- }),
385
- PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
386
- let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
387
- class.undef_alloc_func();
388
- class
389
- }),
390
- PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
391
- let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
392
- class.undef_alloc_func();
393
- class
394
- }),
395
- PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
396
- let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
397
- class.undef_alloc_func();
398
- class
399
- }),
400
- PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
401
- let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
402
- class.undef_alloc_func();
403
- class
404
- }),
405
- PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
406
- let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
407
- class.undef_alloc_func();
408
- class
409
- }),
410
- PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
411
- let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
412
- class.undef_alloc_func();
413
- class
414
- }),
415
- PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
416
- let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
417
- class.undef_alloc_func();
418
- class
419
- }),
422
+ PreTokenizerWrapper::BertPreTokenizer(_) => ruby.get_inner(&BERT_PRE_TOKENIZER),
423
+ PreTokenizerWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
424
+ PreTokenizerWrapper::Delimiter(_) => ruby.get_inner(&CHAR_DELIMITER_SPLIT),
425
+ PreTokenizerWrapper::Digits(_) => ruby.get_inner(&DIGITS),
426
+ PreTokenizerWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
427
+ PreTokenizerWrapper::Punctuation(_) => ruby.get_inner(&PUNCTUATION),
428
+ PreTokenizerWrapper::Split(_) => ruby.get_inner(&SPLIT),
429
+ PreTokenizerWrapper::UnicodeScripts(_) => ruby.get_inner(&UNICODE_SCRIPTS),
430
+ PreTokenizerWrapper::Whitespace(_) => ruby.get_inner(&WHITESPACE),
431
+ PreTokenizerWrapper::WhitespaceSplit(_) => ruby.get_inner(&WHITESPACE_SPLIT),
420
432
  _ => todo!(),
421
433
  },
422
434
  },
@@ -424,8 +436,8 @@ unsafe impl TypedData for RbPreTokenizer {
424
436
  }
425
437
  }
426
438
 
427
- pub fn pre_tokenizers(module: &RModule) -> RbResult<()> {
428
- let pre_tokenizer = module.define_class("PreTokenizer", Default::default())?;
439
+ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
440
+ let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
429
441
  pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
430
442
 
431
443
  let class = module.define_class("Sequence", pre_tokenizer)?;
@@ -1,9 +1,8 @@
1
1
  use std::sync::Arc;
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
- TryConvert, TypedData, Value,
4
+ data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
5
+ Ruby, TryConvert, TypedData, Value,
7
6
  };
8
7
  use serde::{Deserialize, Serialize};
9
8
  use tk::processors::bert::BertProcessing;
@@ -13,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
13
12
  use tk::processors::PostProcessorWrapper;
14
13
  use tk::{Encoding, PostProcessor};
15
14
 
16
- use super::RbResult;
15
+ use super::{PROCESSORS, RbResult};
17
16
 
18
17
  #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
19
18
  pub struct RbPostProcessor {
@@ -53,9 +52,9 @@ impl From<RbSpecialToken> for SpecialToken {
53
52
 
54
53
  impl TryConvert for RbSpecialToken {
55
54
  fn try_convert(ob: Value) -> RbResult<Self> {
56
- if let Ok(v) = ob.try_convert::<(String, u32)>() {
55
+ if let Ok(v) = <(String, u32)>::try_convert(ob) {
57
56
  Ok(Self(v.into()))
58
- } else if let Ok(v) = ob.try_convert::<(u32, String)>() {
57
+ } else if let Ok(v) = <(u32, String)>::try_convert(ob) {
59
58
  Ok(Self(v.into()))
60
59
  } else {
61
60
  todo!()
@@ -74,11 +73,11 @@ impl From<RbTemplate> for Template {
74
73
 
75
74
  impl TryConvert for RbTemplate {
76
75
  fn try_convert(ob: Value) -> RbResult<Self> {
77
- if let Ok(s) = ob.try_convert::<String>() {
76
+ if let Ok(s) = String::try_convert(ob) {
78
77
  Ok(Self(
79
78
  s.try_into().unwrap(), //.map_err(RbError::from)?,
80
79
  ))
81
- } else if let Ok(s) = ob.try_convert::<Vec<String>>() {
80
+ } else if let Ok(s) = <Vec<String>>::try_convert(ob) {
82
81
  Ok(Self(
83
82
  s.try_into().unwrap(), //.map_err(RbError::from)?,
84
83
  ))
@@ -152,47 +151,53 @@ impl RbTemplateProcessing {
152
151
  }
153
152
 
154
153
  unsafe impl TypedData for RbPostProcessor {
155
- fn class() -> RClass {
156
- *memoize!(RClass: {
157
- let class: RClass = crate::processors().const_get("PostProcessor").unwrap();
158
- class.undef_alloc_func();
159
- class
160
- })
154
+ fn class(ruby: &Ruby) -> RClass {
155
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
156
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
157
+ class.undef_default_alloc_func();
158
+ class
159
+ });
160
+ ruby.get_inner(&CLASS)
161
161
  }
162
162
 
163
163
  fn data_type() -> &'static DataType {
164
- memoize!(DataType: DataTypeBuilder::<RbPostProcessor>::new("Tokenizers::Processors::PostProcessor").build())
164
+ static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
165
+ &DATA_TYPE
165
166
  }
166
167
 
167
- fn class_for(value: &Self) -> RClass {
168
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
169
+ static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
170
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
171
+ class.undef_default_alloc_func();
172
+ class
173
+ });
174
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
175
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("ByteLevel").unwrap();
176
+ class.undef_default_alloc_func();
177
+ class
178
+ });
179
+ static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
180
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
181
+ class.undef_default_alloc_func();
182
+ class
183
+ });
184
+ static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
185
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
186
+ class.undef_default_alloc_func();
187
+ class
188
+ });
168
189
  match *value.processor {
169
- PostProcessorWrapper::Bert(_) => *memoize!(RClass: {
170
- let class: RClass = crate::processors().const_get("BertProcessing").unwrap();
171
- class.undef_alloc_func();
172
- class
173
- }),
174
- PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
175
- let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
176
- class.undef_alloc_func();
177
- class
178
- }),
179
- PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
180
- let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
181
- class.undef_alloc_func();
182
- class
183
- }),
184
- PostProcessorWrapper::Template(_) => *memoize!(RClass: {
185
- let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
186
- class.undef_alloc_func();
187
- class
188
- }),
190
+ PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
191
+ PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
192
+ PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
193
+ PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
189
194
  _ => todo!(),
190
195
  }
191
196
  }
192
197
  }
193
198
 
194
- pub fn processors(module: &RModule) -> RbResult<()> {
195
- let post_processor = module.define_class("PostProcessor", Default::default())?;
199
+ pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
200
+ let post_processor = module.define_class("PostProcessor", ruby.class_object())?;
196
201
 
197
202
  let class = module.define_class("BertProcessing", post_processor)?;
198
203
  class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;