tokenizers 0.3.2 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,20 +1,19 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
6
- TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
5
+ Ruby, TryConvert, TypedData,
7
6
  };
8
7
  use serde::ser::SerializeStruct;
9
8
  use serde::{Deserialize, Serialize, Serializer};
10
9
  use tk::normalizers::{
11
- BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Strip, StripAccents,
10
+ BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Prepend, Strip, StripAccents,
12
11
  NFC, NFD, NFKC, NFKD,
13
12
  };
14
13
  use tk::{NormalizedString, Normalizer};
15
14
 
16
15
  use super::utils::*;
17
- use super::{RbError, RbResult};
16
+ use super::{NORMALIZERS, RbError, RbResult};
18
17
 
19
18
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
20
19
  pub struct RbNormalizer {
@@ -44,7 +43,7 @@ macro_rules! getter {
44
43
  ($self: ident, $variant: ident, $name: ident) => {{
45
44
  if let RbNormalizerTypeWrapper::Single(ref norm) = &$self.normalizer {
46
45
  let wrapper = norm.read().unwrap();
47
- if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = *wrapper {
46
+ if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone() {
48
47
  o.$name
49
48
  } else {
50
49
  unreachable!()
@@ -105,6 +104,14 @@ impl RbNormalizer {
105
104
  setter!(self, BertNormalizer, lowercase, lowercase)
106
105
  }
107
106
 
107
+ fn prepend_prepend(&self) -> String {
108
+ getter!(self, Prepend, prepend)
109
+ }
110
+
111
+ fn prepend_set_prepend(&self, prepend: String) {
112
+ setter!(self, Prepend, prepend, prepend)
113
+ }
114
+
108
115
  fn strip_left(&self) -> bool {
109
116
  getter!(self, StripNormalizer, strip_left)
110
117
  }
@@ -186,6 +193,14 @@ impl RbReplace {
186
193
  }
187
194
  }
188
195
 
196
+ pub struct RbPrepend {}
197
+
198
+ impl RbPrepend {
199
+ pub fn new(prepend: String) -> RbNormalizer {
200
+ Prepend::new(prepend).into()
201
+ }
202
+ }
203
+
189
204
  pub struct RbStrip {}
190
205
 
191
206
  impl RbStrip {
@@ -208,7 +223,7 @@ impl RbSequence {
208
223
  fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
209
224
  let mut sequence = Vec::with_capacity(normalizers.len());
210
225
  for n in normalizers.each() {
211
- let normalizer: &RbNormalizer = n?.try_convert()?;
226
+ let normalizer: &RbNormalizer = TryConvert::try_convert(n?)?;
212
227
  match &normalizer.normalizer {
213
228
  RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
214
229
  RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
@@ -311,77 +326,96 @@ impl Normalizer for RbNormalizerWrapper {
311
326
  }
312
327
 
313
328
  unsafe impl TypedData for RbNormalizer {
314
- fn class() -> RClass {
315
- *memoize!(RClass: {
316
- let class: RClass = crate::normalizers().const_get("Normalizer").unwrap();
317
- class.undef_alloc_func();
318
- class
319
- })
329
+ fn class(ruby: &Ruby) -> RClass {
330
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
331
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Normalizer").unwrap();
332
+ class.undef_default_alloc_func();
333
+ class
334
+ });
335
+ ruby.get_inner(&CLASS)
320
336
  }
321
337
 
322
338
  fn data_type() -> &'static DataType {
323
- memoize!(DataType: DataTypeBuilder::<RbNormalizer>::new("Tokenizers::Normalizers::Normalizer").build())
324
- }
325
-
326
- fn class_for(value: &Self) -> RClass {
339
+ static DATA_TYPE: DataType = data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
340
+ &DATA_TYPE
341
+ }
342
+
343
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
344
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
345
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Sequence").unwrap();
346
+ class.undef_default_alloc_func();
347
+ class
348
+ });
349
+ static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
350
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("BertNormalizer").unwrap();
351
+ class.undef_default_alloc_func();
352
+ class
353
+ });
354
+ static LOWERCASE: Lazy<RClass> = Lazy::new(|ruby| {
355
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Lowercase").unwrap();
356
+ class.undef_default_alloc_func();
357
+ class
358
+ });
359
+ static NFD: Lazy<RClass> = Lazy::new(|ruby| {
360
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFD").unwrap();
361
+ class.undef_default_alloc_func();
362
+ class
363
+ });
364
+ static NFC: Lazy<RClass> = Lazy::new(|ruby| {
365
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFC").unwrap();
366
+ class.undef_default_alloc_func();
367
+ class
368
+ });
369
+ static NFKC: Lazy<RClass> = Lazy::new(|ruby| {
370
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKC").unwrap();
371
+ class.undef_default_alloc_func();
372
+ class
373
+ });
374
+ static NFKD: Lazy<RClass> = Lazy::new(|ruby| {
375
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKD").unwrap();
376
+ class.undef_default_alloc_func();
377
+ class
378
+ });
379
+ static NMT: Lazy<RClass> = Lazy::new(|ruby| {
380
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Nmt").unwrap();
381
+ class.undef_default_alloc_func();
382
+ class
383
+ });
384
+ static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
385
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
386
+ class.undef_default_alloc_func();
387
+ class
388
+ });
389
+ static PREPEND: Lazy<RClass> = Lazy::new(|ruby| {
390
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Prepend").unwrap();
391
+ class.undef_default_alloc_func();
392
+ class
393
+ });
394
+ static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
395
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Strip").unwrap();
396
+ class.undef_default_alloc_func();
397
+ class
398
+ });
399
+ static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
400
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("StripAccents").unwrap();
401
+ class.undef_default_alloc_func();
402
+ class
403
+ });
327
404
  match &value.normalizer {
328
- RbNormalizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
329
- let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
330
- class.undef_alloc_func();
331
- class
332
- }),
405
+ RbNormalizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
333
406
  RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
334
407
  RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
335
- NormalizerWrapper::BertNormalizer(_) => *memoize!(RClass: {
336
- let class: RClass = crate::normalizers().const_get("BertNormalizer").unwrap();
337
- class.undef_alloc_func();
338
- class
339
- }),
340
- NormalizerWrapper::Lowercase(_) => *memoize!(RClass: {
341
- let class: RClass = crate::normalizers().const_get("Lowercase").unwrap();
342
- class.undef_alloc_func();
343
- class
344
- }),
345
- NormalizerWrapper::NFD(_) => *memoize!(RClass: {
346
- let class: RClass = crate::normalizers().const_get("NFD").unwrap();
347
- class.undef_alloc_func();
348
- class
349
- }),
350
- NormalizerWrapper::NFC(_) => *memoize!(RClass: {
351
- let class: RClass = crate::normalizers().const_get("NFC").unwrap();
352
- class.undef_alloc_func();
353
- class
354
- }),
355
- NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
356
- let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
357
- class.undef_alloc_func();
358
- class
359
- }),
360
- NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
361
- let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
362
- class.undef_alloc_func();
363
- class
364
- }),
365
- NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
366
- let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
367
- class.undef_alloc_func();
368
- class
369
- }),
370
- NormalizerWrapper::Replace(_) => *memoize!(RClass: {
371
- let class: RClass = crate::normalizers().const_get("Replace").unwrap();
372
- class.undef_alloc_func();
373
- class
374
- }),
375
- NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
376
- let class: RClass = crate::normalizers().const_get("Strip").unwrap();
377
- class.undef_alloc_func();
378
- class
379
- }),
380
- NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
381
- let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
382
- class.undef_alloc_func();
383
- class
384
- }),
408
+ NormalizerWrapper::BertNormalizer(_) => ruby.get_inner(&BERT_NORMALIZER),
409
+ NormalizerWrapper::Lowercase(_) => ruby.get_inner(&LOWERCASE),
410
+ NormalizerWrapper::NFD(_) => ruby.get_inner(&NFD),
411
+ NormalizerWrapper::NFC(_) => ruby.get_inner(&NFC),
412
+ NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
413
+ NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
414
+ NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
415
+ NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
416
+ NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
417
+ NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
418
+ NormalizerWrapper::StripAccents(_) => ruby.get_inner(&STRIP_ACCENTS),
385
419
  _ => todo!(),
386
420
  },
387
421
  },
@@ -389,8 +423,8 @@ unsafe impl TypedData for RbNormalizer {
389
423
  }
390
424
  }
391
425
 
392
- pub fn normalizers(module: &RModule) -> RbResult<()> {
393
- let normalizer = module.define_class("Normalizer", Default::default())?;
426
+ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
427
+ let normalizer = module.define_class("Normalizer", ruby.class_object())?;
394
428
  normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
395
429
 
396
430
  let class = module.define_class("Sequence", normalizer)?;
@@ -428,6 +462,11 @@ pub fn normalizers(module: &RModule) -> RbResult<()> {
428
462
  let class = module.define_class("Replace", normalizer)?;
429
463
  class.define_singleton_method("new", function!(RbReplace::new, 2))?;
430
464
 
465
+ let class = module.define_class("Prepend", normalizer)?;
466
+ class.define_singleton_method("_new", function!(RbPrepend::new, 1))?;
467
+ class.define_method("prepend", method!(RbNormalizer::prepend_prepend, 0))?;
468
+ class.define_method("prepend=", method!(RbNormalizer::prepend_set_prepend, 1))?;
469
+
431
470
  let class = module.define_class("Strip", normalizer)?;
432
471
  class.define_singleton_method("_new", function!(RbStrip::new, 2))?;
433
472
  class.define_method("left", method!(RbNormalizer::strip_left, 0))?;
@@ -1,9 +1,8 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object,
6
- RArray, RClass, RModule, TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
5
+ RArray, RClass, RModule, Ruby, TryConvert, TypedData,
7
6
  };
8
7
 
9
8
  use serde::ser::SerializeStruct;
@@ -23,7 +22,7 @@ use tk::tokenizer::Offsets;
23
22
  use tk::{PreTokenizedString, PreTokenizer};
24
23
 
25
24
  use super::utils::*;
26
- use super::{RbError, RbResult};
25
+ use super::{PRE_TOKENIZERS, RbError, RbResult};
27
26
 
28
27
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
29
28
  pub struct RbPreTokenizer {
@@ -215,7 +214,7 @@ pub struct RbWhitespace {}
215
214
 
216
215
  impl RbWhitespace {
217
216
  pub fn new() -> RbPreTokenizer {
218
- Whitespace::default().into()
217
+ Whitespace.into()
219
218
  }
220
219
  }
221
220
 
@@ -241,7 +240,7 @@ impl RbSequence {
241
240
  fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
242
241
  let mut sequence = Vec::with_capacity(pre_tokenizers.len());
243
242
  for n in pre_tokenizers.each() {
244
- let pretokenizer: &RbPreTokenizer = n?.try_convert()?;
243
+ let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n?)?;
245
244
  match &pretokenizer.pretok {
246
245
  RbPreTokenizerTypeWrapper::Sequence(inner) => {
247
246
  sequence.extend(inner.iter().cloned())
@@ -346,77 +345,90 @@ impl PreTokenizer for RbPreTokenizerWrapper {
346
345
  }
347
346
 
348
347
  unsafe impl TypedData for RbPreTokenizer {
349
- fn class() -> RClass {
350
- *memoize!(RClass: {
351
- let class: RClass = crate::pre_tokenizers().const_get("PreTokenizer").unwrap();
352
- class.undef_alloc_func();
353
- class
354
- })
348
+ fn class(ruby: &Ruby) -> RClass {
349
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
350
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
351
+ class.undef_default_alloc_func();
352
+ class
353
+ });
354
+ ruby.get_inner(&CLASS)
355
355
  }
356
356
 
357
357
  fn data_type() -> &'static DataType {
358
- memoize!(DataType: DataTypeBuilder::<RbPreTokenizer>::new("Tokenizers::PreTokenizers::PreTokenizer").build())
359
- }
360
-
361
- fn class_for(value: &Self) -> RClass {
358
+ static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
359
+ &DATA_TYPE
360
+ }
361
+
362
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
363
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
364
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
365
+ class.undef_default_alloc_func();
366
+ class
367
+ });
368
+ static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
369
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
370
+ class.undef_default_alloc_func();
371
+ class
372
+ });
373
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
374
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
375
+ class.undef_default_alloc_func();
376
+ class
377
+ });
378
+ static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
379
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
380
+ class.undef_default_alloc_func();
381
+ class
382
+ });
383
+ static DIGITS: Lazy<RClass> = Lazy::new(|ruby| {
384
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Digits").unwrap();
385
+ class.undef_default_alloc_func();
386
+ class
387
+ });
388
+ static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
389
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
390
+ class.undef_default_alloc_func();
391
+ class
392
+ });
393
+ static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
394
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
395
+ class.undef_default_alloc_func();
396
+ class
397
+ });
398
+ static SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
399
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Split").unwrap();
400
+ class.undef_default_alloc_func();
401
+ class
402
+ });
403
+ static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
404
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
405
+ class.undef_default_alloc_func();
406
+ class
407
+ });
408
+ static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
409
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
410
+ class.undef_default_alloc_func();
411
+ class
412
+ });
413
+ static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
414
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
415
+ class.undef_default_alloc_func();
416
+ class
417
+ });
362
418
  match &value.pretok {
363
- RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
364
- let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
365
- class.undef_alloc_func();
366
- class
367
- }),
419
+ RbPreTokenizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
368
420
  RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
369
421
  RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
370
- PreTokenizerWrapper::BertPreTokenizer(_) => *memoize!(RClass: {
371
- let class: RClass = crate::pre_tokenizers().const_get("BertPreTokenizer").unwrap();
372
- class.undef_alloc_func();
373
- class
374
- }),
375
- PreTokenizerWrapper::ByteLevel(_) => *memoize!(RClass: {
376
- let class: RClass = crate::pre_tokenizers().const_get("ByteLevel").unwrap();
377
- class.undef_alloc_func();
378
- class
379
- }),
380
- PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
381
- let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
382
- class.undef_alloc_func();
383
- class
384
- }),
385
- PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
386
- let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
387
- class.undef_alloc_func();
388
- class
389
- }),
390
- PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
391
- let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
392
- class.undef_alloc_func();
393
- class
394
- }),
395
- PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
396
- let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
397
- class.undef_alloc_func();
398
- class
399
- }),
400
- PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
401
- let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
402
- class.undef_alloc_func();
403
- class
404
- }),
405
- PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
406
- let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
407
- class.undef_alloc_func();
408
- class
409
- }),
410
- PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
411
- let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
412
- class.undef_alloc_func();
413
- class
414
- }),
415
- PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
416
- let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
417
- class.undef_alloc_func();
418
- class
419
- }),
422
+ PreTokenizerWrapper::BertPreTokenizer(_) => ruby.get_inner(&BERT_PRE_TOKENIZER),
423
+ PreTokenizerWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
424
+ PreTokenizerWrapper::Delimiter(_) => ruby.get_inner(&CHAR_DELIMITER_SPLIT),
425
+ PreTokenizerWrapper::Digits(_) => ruby.get_inner(&DIGITS),
426
+ PreTokenizerWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
427
+ PreTokenizerWrapper::Punctuation(_) => ruby.get_inner(&PUNCTUATION),
428
+ PreTokenizerWrapper::Split(_) => ruby.get_inner(&SPLIT),
429
+ PreTokenizerWrapper::UnicodeScripts(_) => ruby.get_inner(&UNICODE_SCRIPTS),
430
+ PreTokenizerWrapper::Whitespace(_) => ruby.get_inner(&WHITESPACE),
431
+ PreTokenizerWrapper::WhitespaceSplit(_) => ruby.get_inner(&WHITESPACE_SPLIT),
420
432
  _ => todo!(),
421
433
  },
422
434
  },
@@ -424,8 +436,8 @@ unsafe impl TypedData for RbPreTokenizer {
424
436
  }
425
437
  }
426
438
 
427
- pub fn pre_tokenizers(module: &RModule) -> RbResult<()> {
428
- let pre_tokenizer = module.define_class("PreTokenizer", Default::default())?;
439
+ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
440
+ let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
429
441
  pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
430
442
 
431
443
  let class = module.define_class("Sequence", pre_tokenizer)?;
@@ -1,9 +1,8 @@
1
1
  use std::sync::Arc;
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
- TryConvert, TypedData, Value,
4
+ data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
5
+ Ruby, TryConvert, TypedData, Value,
7
6
  };
8
7
  use serde::{Deserialize, Serialize};
9
8
  use tk::processors::bert::BertProcessing;
@@ -13,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
13
12
  use tk::processors::PostProcessorWrapper;
14
13
  use tk::{Encoding, PostProcessor};
15
14
 
16
- use super::RbResult;
15
+ use super::{PROCESSORS, RbResult};
17
16
 
18
17
  #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
19
18
  pub struct RbPostProcessor {
@@ -53,9 +52,9 @@ impl From<RbSpecialToken> for SpecialToken {
53
52
 
54
53
  impl TryConvert for RbSpecialToken {
55
54
  fn try_convert(ob: Value) -> RbResult<Self> {
56
- if let Ok(v) = ob.try_convert::<(String, u32)>() {
55
+ if let Ok(v) = <(String, u32)>::try_convert(ob) {
57
56
  Ok(Self(v.into()))
58
- } else if let Ok(v) = ob.try_convert::<(u32, String)>() {
57
+ } else if let Ok(v) = <(u32, String)>::try_convert(ob) {
59
58
  Ok(Self(v.into()))
60
59
  } else {
61
60
  todo!()
@@ -74,11 +73,11 @@ impl From<RbTemplate> for Template {
74
73
 
75
74
  impl TryConvert for RbTemplate {
76
75
  fn try_convert(ob: Value) -> RbResult<Self> {
77
- if let Ok(s) = ob.try_convert::<String>() {
76
+ if let Ok(s) = String::try_convert(ob) {
78
77
  Ok(Self(
79
78
  s.try_into().unwrap(), //.map_err(RbError::from)?,
80
79
  ))
81
- } else if let Ok(s) = ob.try_convert::<Vec<String>>() {
80
+ } else if let Ok(s) = <Vec<String>>::try_convert(ob) {
82
81
  Ok(Self(
83
82
  s.try_into().unwrap(), //.map_err(RbError::from)?,
84
83
  ))
@@ -152,47 +151,53 @@ impl RbTemplateProcessing {
152
151
  }
153
152
 
154
153
  unsafe impl TypedData for RbPostProcessor {
155
- fn class() -> RClass {
156
- *memoize!(RClass: {
157
- let class: RClass = crate::processors().const_get("PostProcessor").unwrap();
158
- class.undef_alloc_func();
159
- class
160
- })
154
+ fn class(ruby: &Ruby) -> RClass {
155
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
156
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
157
+ class.undef_default_alloc_func();
158
+ class
159
+ });
160
+ ruby.get_inner(&CLASS)
161
161
  }
162
162
 
163
163
  fn data_type() -> &'static DataType {
164
- memoize!(DataType: DataTypeBuilder::<RbPostProcessor>::new("Tokenizers::Processors::PostProcessor").build())
164
+ static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
165
+ &DATA_TYPE
165
166
  }
166
167
 
167
- fn class_for(value: &Self) -> RClass {
168
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
169
+ static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
170
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
171
+ class.undef_default_alloc_func();
172
+ class
173
+ });
174
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
175
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("ByteLevel").unwrap();
176
+ class.undef_default_alloc_func();
177
+ class
178
+ });
179
+ static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
180
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
181
+ class.undef_default_alloc_func();
182
+ class
183
+ });
184
+ static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
185
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
186
+ class.undef_default_alloc_func();
187
+ class
188
+ });
168
189
  match *value.processor {
169
- PostProcessorWrapper::Bert(_) => *memoize!(RClass: {
170
- let class: RClass = crate::processors().const_get("BertProcessing").unwrap();
171
- class.undef_alloc_func();
172
- class
173
- }),
174
- PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
175
- let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
176
- class.undef_alloc_func();
177
- class
178
- }),
179
- PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
180
- let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
181
- class.undef_alloc_func();
182
- class
183
- }),
184
- PostProcessorWrapper::Template(_) => *memoize!(RClass: {
185
- let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
186
- class.undef_alloc_func();
187
- class
188
- }),
190
+ PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
191
+ PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
192
+ PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
193
+ PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
189
194
  _ => todo!(),
190
195
  }
191
196
  }
192
197
  }
193
198
 
194
- pub fn processors(module: &RModule) -> RbResult<()> {
195
- let post_processor = module.define_class("PostProcessor", Default::default())?;
199
+ pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
200
+ let post_processor = module.define_class("PostProcessor", ruby.class_object())?;
196
201
 
197
202
  let class = module.define_class("BertProcessing", post_processor)?;
198
203
  class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;