tokenizers 0.3.3 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,10 +3,10 @@ use std::path::{Path, PathBuf};
3
3
  use std::sync::{Arc, RwLock};
4
4
 
5
5
  use crate::trainers::RbTrainer;
6
- use magnus::typed_data::DataTypeBuilder;
6
+ use magnus::prelude::*;
7
7
  use magnus::{
8
- exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
- RClass, RHash, RModule, Symbol, TypedData, Value,
8
+ data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
+ RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
10
10
  };
11
11
  use serde::{Deserialize, Serialize};
12
12
  use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
@@ -16,7 +16,7 @@ use tk::models::wordlevel::WordLevel;
16
16
  use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
17
17
  use tk::{Model, Token};
18
18
 
19
- use super::{RbError, RbResult};
19
+ use super::{MODELS, RbError, RbResult};
20
20
 
21
21
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
22
22
  pub struct RbModel {
@@ -73,37 +73,37 @@ impl RbBPE {
73
73
  fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
74
74
  let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
75
75
  if !value.is_nil() {
76
- builder = builder.cache_capacity(value.try_convert()?);
76
+ builder = builder.cache_capacity(TryConvert::try_convert(value)?);
77
77
  }
78
78
 
79
79
  let value: Value = kwargs.delete(Symbol::new("dropout"))?;
80
80
  if !value.is_nil() {
81
- builder = builder.dropout(value.try_convert()?);
81
+ builder = builder.dropout(TryConvert::try_convert(value)?);
82
82
  }
83
83
 
84
84
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
85
85
  if !value.is_nil() {
86
- builder = builder.unk_token(value.try_convert()?);
86
+ builder = builder.unk_token(TryConvert::try_convert(value)?);
87
87
  }
88
88
 
89
89
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
90
90
  if !value.is_nil() {
91
- builder = builder.continuing_subword_prefix(value.try_convert()?);
91
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
92
92
  }
93
93
 
94
94
  let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
95
95
  if !value.is_nil() {
96
- builder = builder.end_of_word_suffix(value.try_convert()?);
96
+ builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
97
97
  }
98
98
 
99
99
  let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
100
100
  if !value.is_nil() {
101
- builder = builder.fuse_unk(value.try_convert()?);
101
+ builder = builder.fuse_unk(TryConvert::try_convert(value)?);
102
102
  }
103
103
 
104
104
  let value: Value = kwargs.delete(Symbol::new("byte_fallback"))?;
105
105
  if !value.is_nil() {
106
- builder = builder.byte_fallback(value.try_convert()?);
106
+ builder = builder.byte_fallback(TryConvert::try_convert(value)?);
107
107
  }
108
108
 
109
109
  if !kwargs.is_empty() {
@@ -234,13 +234,13 @@ impl RbModel {
234
234
  pub struct RbUnigram {}
235
235
 
236
236
  impl RbUnigram {
237
- fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
238
- match (vocab, unk_id) {
239
- (Some(vocab), unk_id) => {
240
- let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
237
+ fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>, byte_fallback: Option<bool>) -> RbResult<RbModel> {
238
+ match (vocab, unk_id, byte_fallback) {
239
+ (Some(vocab), unk_id, byte_fallback) => {
240
+ let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(RbError::from)?;
241
241
  Ok(model.into())
242
242
  }
243
- (None, None) => Ok(Unigram::default().into()),
243
+ (None, None, _) => Ok(Unigram::default().into()),
244
244
  _ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
245
245
  }
246
246
  }
@@ -277,17 +277,17 @@ impl RbWordPiece {
277
277
  fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
278
278
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
279
279
  if !value.is_nil() {
280
- builder = builder.unk_token(value.try_convert()?);
280
+ builder = builder.unk_token(TryConvert::try_convert(value)?);
281
281
  }
282
282
 
283
283
  let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
284
284
  if !value.is_nil() {
285
- builder = builder.max_input_chars_per_word(value.try_convert()?);
285
+ builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
286
286
  }
287
287
 
288
288
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
289
289
  if !value.is_nil() {
290
- builder = builder.continuing_subword_prefix(value.try_convert()?);
290
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
291
291
  }
292
292
 
293
293
  if !kwargs.is_empty() {
@@ -314,46 +314,52 @@ impl RbWordPiece {
314
314
  }
315
315
 
316
316
  unsafe impl TypedData for RbModel {
317
- fn class() -> RClass {
318
- *memoize!(RClass: {
319
- let class: RClass = crate::models().const_get("Model").unwrap();
320
- class.undef_alloc_func();
317
+ fn class(ruby: &Ruby) -> RClass {
318
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
319
+ let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
320
+ class.undef_default_alloc_func();
321
321
  class
322
- })
322
+ });
323
+ ruby.get_inner(&CLASS)
323
324
  }
324
325
 
325
326
  fn data_type() -> &'static DataType {
326
- memoize!(DataType: DataTypeBuilder::<RbModel>::new("Tokenizers::Models::Model").build())
327
- }
328
-
329
- fn class_for(value: &Self) -> RClass {
327
+ static DATA_TYPE: DataType = data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
328
+ &DATA_TYPE
329
+ }
330
+
331
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
332
+ static BPE: Lazy<RClass> = Lazy::new(|ruby| {
333
+ let class: RClass = ruby.get_inner(&MODELS).const_get("BPE").unwrap();
334
+ class.undef_default_alloc_func();
335
+ class
336
+ });
337
+ static UNIGRAM: Lazy<RClass> = Lazy::new(|ruby| {
338
+ let class: RClass = ruby.get_inner(&MODELS).const_get("Unigram").unwrap();
339
+ class.undef_default_alloc_func();
340
+ class
341
+ });
342
+ static WORD_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
343
+ let class: RClass = ruby.get_inner(&MODELS).const_get("WordLevel").unwrap();
344
+ class.undef_default_alloc_func();
345
+ class
346
+ });
347
+ static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
348
+ let class: RClass = ruby.get_inner(&MODELS).const_get("WordPiece").unwrap();
349
+ class.undef_default_alloc_func();
350
+ class
351
+ });
330
352
  match *value.model.read().unwrap() {
331
- ModelWrapper::BPE(_) => *memoize!(RClass: {
332
- let class: RClass = crate::models().const_get("BPE").unwrap();
333
- class.undef_alloc_func();
334
- class
335
- }),
336
- ModelWrapper::Unigram(_) => *memoize!(RClass: {
337
- let class: RClass = crate::models().const_get("Unigram").unwrap();
338
- class.undef_alloc_func();
339
- class
340
- }),
341
- ModelWrapper::WordLevel(_) => *memoize!(RClass: {
342
- let class: RClass = crate::models().const_get("WordLevel").unwrap();
343
- class.undef_alloc_func();
344
- class
345
- }),
346
- ModelWrapper::WordPiece(_) => *memoize!(RClass: {
347
- let class: RClass = crate::models().const_get("WordPiece").unwrap();
348
- class.undef_alloc_func();
349
- class
350
- }),
353
+ ModelWrapper::BPE(_) => ruby.get_inner(&BPE),
354
+ ModelWrapper::Unigram(_) => ruby.get_inner(&UNIGRAM),
355
+ ModelWrapper::WordLevel(_) => ruby.get_inner(&WORD_LEVEL),
356
+ ModelWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
351
357
  }
352
358
  }
353
359
  }
354
360
 
355
- pub fn models(module: &RModule) -> RbResult<()> {
356
- let model = module.define_class("Model", Default::default())?;
361
+ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
362
+ let model = module.define_class("Model", ruby.class_object())?;
357
363
 
358
364
  let class = module.define_class("BPE", model)?;
359
365
  class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
@@ -372,7 +378,7 @@ pub fn models(module: &RModule) -> RbResult<()> {
372
378
  class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
373
379
 
374
380
  let class = module.define_class("Unigram", model)?;
375
- class.define_singleton_method("_new", function!(RbUnigram::new, 2))?;
381
+ class.define_singleton_method("_new", function!(RbUnigram::new, 3))?;
376
382
 
377
383
  let class = module.define_class("WordLevel", model)?;
378
384
  class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
@@ -1,9 +1,8 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
6
- TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
5
+ Ruby, TryConvert, TypedData,
7
6
  };
8
7
  use serde::ser::SerializeStruct;
9
8
  use serde::{Deserialize, Serialize, Serializer};
@@ -14,7 +13,7 @@ use tk::normalizers::{
14
13
  use tk::{NormalizedString, Normalizer};
15
14
 
16
15
  use super::utils::*;
17
- use super::{RbError, RbResult};
16
+ use super::{NORMALIZERS, RbError, RbResult};
18
17
 
19
18
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
20
19
  pub struct RbNormalizer {
@@ -224,7 +223,7 @@ impl RbSequence {
224
223
  fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
225
224
  let mut sequence = Vec::with_capacity(normalizers.len());
226
225
  for n in normalizers.each() {
227
- let normalizer: &RbNormalizer = n?.try_convert()?;
226
+ let normalizer: &RbNormalizer = TryConvert::try_convert(n?)?;
228
227
  match &normalizer.normalizer {
229
228
  RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
230
229
  RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
@@ -327,82 +326,96 @@ impl Normalizer for RbNormalizerWrapper {
327
326
  }
328
327
 
329
328
  unsafe impl TypedData for RbNormalizer {
330
- fn class() -> RClass {
331
- *memoize!(RClass: {
332
- let class: RClass = crate::normalizers().const_get("Normalizer").unwrap();
333
- class.undef_alloc_func();
334
- class
335
- })
329
+ fn class(ruby: &Ruby) -> RClass {
330
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
331
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Normalizer").unwrap();
332
+ class.undef_default_alloc_func();
333
+ class
334
+ });
335
+ ruby.get_inner(&CLASS)
336
336
  }
337
337
 
338
338
  fn data_type() -> &'static DataType {
339
- memoize!(DataType: DataTypeBuilder::<RbNormalizer>::new("Tokenizers::Normalizers::Normalizer").build())
340
- }
341
-
342
- fn class_for(value: &Self) -> RClass {
339
+ static DATA_TYPE: DataType = data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
340
+ &DATA_TYPE
341
+ }
342
+
343
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
344
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
345
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Sequence").unwrap();
346
+ class.undef_default_alloc_func();
347
+ class
348
+ });
349
+ static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
350
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("BertNormalizer").unwrap();
351
+ class.undef_default_alloc_func();
352
+ class
353
+ });
354
+ static LOWERCASE: Lazy<RClass> = Lazy::new(|ruby| {
355
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Lowercase").unwrap();
356
+ class.undef_default_alloc_func();
357
+ class
358
+ });
359
+ static NFD: Lazy<RClass> = Lazy::new(|ruby| {
360
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFD").unwrap();
361
+ class.undef_default_alloc_func();
362
+ class
363
+ });
364
+ static NFC: Lazy<RClass> = Lazy::new(|ruby| {
365
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFC").unwrap();
366
+ class.undef_default_alloc_func();
367
+ class
368
+ });
369
+ static NFKC: Lazy<RClass> = Lazy::new(|ruby| {
370
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKC").unwrap();
371
+ class.undef_default_alloc_func();
372
+ class
373
+ });
374
+ static NFKD: Lazy<RClass> = Lazy::new(|ruby| {
375
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKD").unwrap();
376
+ class.undef_default_alloc_func();
377
+ class
378
+ });
379
+ static NMT: Lazy<RClass> = Lazy::new(|ruby| {
380
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Nmt").unwrap();
381
+ class.undef_default_alloc_func();
382
+ class
383
+ });
384
+ static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
385
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
386
+ class.undef_default_alloc_func();
387
+ class
388
+ });
389
+ static PREPEND: Lazy<RClass> = Lazy::new(|ruby| {
390
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Prepend").unwrap();
391
+ class.undef_default_alloc_func();
392
+ class
393
+ });
394
+ static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
395
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Strip").unwrap();
396
+ class.undef_default_alloc_func();
397
+ class
398
+ });
399
+ static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
400
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("StripAccents").unwrap();
401
+ class.undef_default_alloc_func();
402
+ class
403
+ });
343
404
  match &value.normalizer {
344
- RbNormalizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
345
- let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
346
- class.undef_alloc_func();
347
- class
348
- }),
405
+ RbNormalizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
349
406
  RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
350
407
  RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
351
- NormalizerWrapper::BertNormalizer(_) => *memoize!(RClass: {
352
- let class: RClass = crate::normalizers().const_get("BertNormalizer").unwrap();
353
- class.undef_alloc_func();
354
- class
355
- }),
356
- NormalizerWrapper::Lowercase(_) => *memoize!(RClass: {
357
- let class: RClass = crate::normalizers().const_get("Lowercase").unwrap();
358
- class.undef_alloc_func();
359
- class
360
- }),
361
- NormalizerWrapper::NFD(_) => *memoize!(RClass: {
362
- let class: RClass = crate::normalizers().const_get("NFD").unwrap();
363
- class.undef_alloc_func();
364
- class
365
- }),
366
- NormalizerWrapper::NFC(_) => *memoize!(RClass: {
367
- let class: RClass = crate::normalizers().const_get("NFC").unwrap();
368
- class.undef_alloc_func();
369
- class
370
- }),
371
- NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
372
- let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
373
- class.undef_alloc_func();
374
- class
375
- }),
376
- NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
377
- let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
378
- class.undef_alloc_func();
379
- class
380
- }),
381
- NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
382
- let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
383
- class.undef_alloc_func();
384
- class
385
- }),
386
- NormalizerWrapper::Replace(_) => *memoize!(RClass: {
387
- let class: RClass = crate::normalizers().const_get("Replace").unwrap();
388
- class.undef_alloc_func();
389
- class
390
- }),
391
- NormalizerWrapper::Prepend(_) => *memoize!(RClass: {
392
- let class: RClass = crate::normalizers().const_get("Prepend").unwrap();
393
- class.undef_alloc_func();
394
- class
395
- }),
396
- NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
397
- let class: RClass = crate::normalizers().const_get("Strip").unwrap();
398
- class.undef_alloc_func();
399
- class
400
- }),
401
- NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
402
- let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
403
- class.undef_alloc_func();
404
- class
405
- }),
408
+ NormalizerWrapper::BertNormalizer(_) => ruby.get_inner(&BERT_NORMALIZER),
409
+ NormalizerWrapper::Lowercase(_) => ruby.get_inner(&LOWERCASE),
410
+ NormalizerWrapper::NFD(_) => ruby.get_inner(&NFD),
411
+ NormalizerWrapper::NFC(_) => ruby.get_inner(&NFC),
412
+ NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
413
+ NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
414
+ NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
415
+ NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
416
+ NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
417
+ NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
418
+ NormalizerWrapper::StripAccents(_) => ruby.get_inner(&STRIP_ACCENTS),
406
419
  _ => todo!(),
407
420
  },
408
421
  },
@@ -410,8 +423,8 @@ unsafe impl TypedData for RbNormalizer {
410
423
  }
411
424
  }
412
425
 
413
- pub fn normalizers(module: &RModule) -> RbResult<()> {
414
- let normalizer = module.define_class("Normalizer", Default::default())?;
426
+ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
427
+ let normalizer = module.define_class("Normalizer", ruby.class_object())?;
415
428
  normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
416
429
 
417
430
  let class = module.define_class("Sequence", normalizer)?;
@@ -1,9 +1,8 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object,
6
- RArray, RClass, RModule, TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
5
+ RArray, RClass, RModule, Ruby, TryConvert, TypedData,
7
6
  };
8
7
 
9
8
  use serde::ser::SerializeStruct;
@@ -23,7 +22,7 @@ use tk::tokenizer::Offsets;
23
22
  use tk::{PreTokenizedString, PreTokenizer};
24
23
 
25
24
  use super::utils::*;
26
- use super::{RbError, RbResult};
25
+ use super::{PRE_TOKENIZERS, RbError, RbResult};
27
26
 
28
27
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
29
28
  pub struct RbPreTokenizer {
@@ -215,7 +214,7 @@ pub struct RbWhitespace {}
215
214
 
216
215
  impl RbWhitespace {
217
216
  pub fn new() -> RbPreTokenizer {
218
- Whitespace::default().into()
217
+ Whitespace.into()
219
218
  }
220
219
  }
221
220
 
@@ -241,7 +240,7 @@ impl RbSequence {
241
240
  fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
242
241
  let mut sequence = Vec::with_capacity(pre_tokenizers.len());
243
242
  for n in pre_tokenizers.each() {
244
- let pretokenizer: &RbPreTokenizer = n?.try_convert()?;
243
+ let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n?)?;
245
244
  match &pretokenizer.pretok {
246
245
  RbPreTokenizerTypeWrapper::Sequence(inner) => {
247
246
  sequence.extend(inner.iter().cloned())
@@ -346,77 +345,90 @@ impl PreTokenizer for RbPreTokenizerWrapper {
346
345
  }
347
346
 
348
347
  unsafe impl TypedData for RbPreTokenizer {
349
- fn class() -> RClass {
350
- *memoize!(RClass: {
351
- let class: RClass = crate::pre_tokenizers().const_get("PreTokenizer").unwrap();
352
- class.undef_alloc_func();
353
- class
354
- })
348
+ fn class(ruby: &Ruby) -> RClass {
349
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
350
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
351
+ class.undef_default_alloc_func();
352
+ class
353
+ });
354
+ ruby.get_inner(&CLASS)
355
355
  }
356
356
 
357
357
  fn data_type() -> &'static DataType {
358
- memoize!(DataType: DataTypeBuilder::<RbPreTokenizer>::new("Tokenizers::PreTokenizers::PreTokenizer").build())
359
- }
360
-
361
- fn class_for(value: &Self) -> RClass {
358
+ static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
359
+ &DATA_TYPE
360
+ }
361
+
362
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
363
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
364
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
365
+ class.undef_default_alloc_func();
366
+ class
367
+ });
368
+ static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
369
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
370
+ class.undef_default_alloc_func();
371
+ class
372
+ });
373
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
374
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
375
+ class.undef_default_alloc_func();
376
+ class
377
+ });
378
+ static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
379
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
380
+ class.undef_default_alloc_func();
381
+ class
382
+ });
383
+ static DIGITS: Lazy<RClass> = Lazy::new(|ruby| {
384
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Digits").unwrap();
385
+ class.undef_default_alloc_func();
386
+ class
387
+ });
388
+ static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
389
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
390
+ class.undef_default_alloc_func();
391
+ class
392
+ });
393
+ static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
394
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
395
+ class.undef_default_alloc_func();
396
+ class
397
+ });
398
+ static SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
399
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Split").unwrap();
400
+ class.undef_default_alloc_func();
401
+ class
402
+ });
403
+ static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
404
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
405
+ class.undef_default_alloc_func();
406
+ class
407
+ });
408
+ static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
409
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
410
+ class.undef_default_alloc_func();
411
+ class
412
+ });
413
+ static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
414
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
415
+ class.undef_default_alloc_func();
416
+ class
417
+ });
362
418
  match &value.pretok {
363
- RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
364
- let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
365
- class.undef_alloc_func();
366
- class
367
- }),
419
+ RbPreTokenizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
368
420
  RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
369
421
  RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
370
- PreTokenizerWrapper::BertPreTokenizer(_) => *memoize!(RClass: {
371
- let class: RClass = crate::pre_tokenizers().const_get("BertPreTokenizer").unwrap();
372
- class.undef_alloc_func();
373
- class
374
- }),
375
- PreTokenizerWrapper::ByteLevel(_) => *memoize!(RClass: {
376
- let class: RClass = crate::pre_tokenizers().const_get("ByteLevel").unwrap();
377
- class.undef_alloc_func();
378
- class
379
- }),
380
- PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
381
- let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
382
- class.undef_alloc_func();
383
- class
384
- }),
385
- PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
386
- let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
387
- class.undef_alloc_func();
388
- class
389
- }),
390
- PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
391
- let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
392
- class.undef_alloc_func();
393
- class
394
- }),
395
- PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
396
- let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
397
- class.undef_alloc_func();
398
- class
399
- }),
400
- PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
401
- let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
402
- class.undef_alloc_func();
403
- class
404
- }),
405
- PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
406
- let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
407
- class.undef_alloc_func();
408
- class
409
- }),
410
- PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
411
- let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
412
- class.undef_alloc_func();
413
- class
414
- }),
415
- PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
416
- let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
417
- class.undef_alloc_func();
418
- class
419
- }),
422
+ PreTokenizerWrapper::BertPreTokenizer(_) => ruby.get_inner(&BERT_PRE_TOKENIZER),
423
+ PreTokenizerWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
424
+ PreTokenizerWrapper::Delimiter(_) => ruby.get_inner(&CHAR_DELIMITER_SPLIT),
425
+ PreTokenizerWrapper::Digits(_) => ruby.get_inner(&DIGITS),
426
+ PreTokenizerWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
427
+ PreTokenizerWrapper::Punctuation(_) => ruby.get_inner(&PUNCTUATION),
428
+ PreTokenizerWrapper::Split(_) => ruby.get_inner(&SPLIT),
429
+ PreTokenizerWrapper::UnicodeScripts(_) => ruby.get_inner(&UNICODE_SCRIPTS),
430
+ PreTokenizerWrapper::Whitespace(_) => ruby.get_inner(&WHITESPACE),
431
+ PreTokenizerWrapper::WhitespaceSplit(_) => ruby.get_inner(&WHITESPACE_SPLIT),
420
432
  _ => todo!(),
421
433
  },
422
434
  },
@@ -424,8 +436,8 @@ unsafe impl TypedData for RbPreTokenizer {
424
436
  }
425
437
  }
426
438
 
427
- pub fn pre_tokenizers(module: &RModule) -> RbResult<()> {
428
- let pre_tokenizer = module.define_class("PreTokenizer", Default::default())?;
439
+ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
440
+ let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
429
441
  pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
430
442
 
431
443
  let class = module.define_class("Sequence", pre_tokenizer)?;