tokenizers 0.3.3 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -3,10 +3,10 @@ use std::path::{Path, PathBuf};
3
3
  use std::sync::{Arc, RwLock};
4
4
 
5
5
  use crate::trainers::RbTrainer;
6
- use magnus::typed_data::DataTypeBuilder;
6
+ use magnus::prelude::*;
7
7
  use magnus::{
8
- exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
- RClass, RHash, RModule, Symbol, TypedData, Value,
8
+ data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
+ RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
10
10
  };
11
11
  use serde::{Deserialize, Serialize};
12
12
  use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
@@ -16,7 +16,7 @@ use tk::models::wordlevel::WordLevel;
16
16
  use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
17
17
  use tk::{Model, Token};
18
18
 
19
- use super::{RbError, RbResult};
19
+ use super::{MODELS, RbError, RbResult};
20
20
 
21
21
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
22
22
  pub struct RbModel {
@@ -73,37 +73,37 @@ impl RbBPE {
73
73
  fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
74
74
  let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
75
75
  if !value.is_nil() {
76
- builder = builder.cache_capacity(value.try_convert()?);
76
+ builder = builder.cache_capacity(TryConvert::try_convert(value)?);
77
77
  }
78
78
 
79
79
  let value: Value = kwargs.delete(Symbol::new("dropout"))?;
80
80
  if !value.is_nil() {
81
- builder = builder.dropout(value.try_convert()?);
81
+ builder = builder.dropout(TryConvert::try_convert(value)?);
82
82
  }
83
83
 
84
84
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
85
85
  if !value.is_nil() {
86
- builder = builder.unk_token(value.try_convert()?);
86
+ builder = builder.unk_token(TryConvert::try_convert(value)?);
87
87
  }
88
88
 
89
89
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
90
90
  if !value.is_nil() {
91
- builder = builder.continuing_subword_prefix(value.try_convert()?);
91
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
92
92
  }
93
93
 
94
94
  let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
95
95
  if !value.is_nil() {
96
- builder = builder.end_of_word_suffix(value.try_convert()?);
96
+ builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
97
97
  }
98
98
 
99
99
  let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
100
100
  if !value.is_nil() {
101
- builder = builder.fuse_unk(value.try_convert()?);
101
+ builder = builder.fuse_unk(TryConvert::try_convert(value)?);
102
102
  }
103
103
 
104
104
  let value: Value = kwargs.delete(Symbol::new("byte_fallback"))?;
105
105
  if !value.is_nil() {
106
- builder = builder.byte_fallback(value.try_convert()?);
106
+ builder = builder.byte_fallback(TryConvert::try_convert(value)?);
107
107
  }
108
108
 
109
109
  if !kwargs.is_empty() {
@@ -234,13 +234,13 @@ impl RbModel {
234
234
  pub struct RbUnigram {}
235
235
 
236
236
  impl RbUnigram {
237
- fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
238
- match (vocab, unk_id) {
239
- (Some(vocab), unk_id) => {
240
- let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
237
+ fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>, byte_fallback: Option<bool>) -> RbResult<RbModel> {
238
+ match (vocab, unk_id, byte_fallback) {
239
+ (Some(vocab), unk_id, byte_fallback) => {
240
+ let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(RbError::from)?;
241
241
  Ok(model.into())
242
242
  }
243
- (None, None) => Ok(Unigram::default().into()),
243
+ (None, None, _) => Ok(Unigram::default().into()),
244
244
  _ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
245
245
  }
246
246
  }
@@ -277,17 +277,17 @@ impl RbWordPiece {
277
277
  fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
278
278
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
279
279
  if !value.is_nil() {
280
- builder = builder.unk_token(value.try_convert()?);
280
+ builder = builder.unk_token(TryConvert::try_convert(value)?);
281
281
  }
282
282
 
283
283
  let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
284
284
  if !value.is_nil() {
285
- builder = builder.max_input_chars_per_word(value.try_convert()?);
285
+ builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
286
286
  }
287
287
 
288
288
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
289
289
  if !value.is_nil() {
290
- builder = builder.continuing_subword_prefix(value.try_convert()?);
290
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
291
291
  }
292
292
 
293
293
  if !kwargs.is_empty() {
@@ -314,46 +314,52 @@ impl RbWordPiece {
314
314
  }
315
315
 
316
316
  unsafe impl TypedData for RbModel {
317
- fn class() -> RClass {
318
- *memoize!(RClass: {
319
- let class: RClass = crate::models().const_get("Model").unwrap();
320
- class.undef_alloc_func();
317
+ fn class(ruby: &Ruby) -> RClass {
318
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
319
+ let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
320
+ class.undef_default_alloc_func();
321
321
  class
322
- })
322
+ });
323
+ ruby.get_inner(&CLASS)
323
324
  }
324
325
 
325
326
  fn data_type() -> &'static DataType {
326
- memoize!(DataType: DataTypeBuilder::<RbModel>::new("Tokenizers::Models::Model").build())
327
- }
328
-
329
- fn class_for(value: &Self) -> RClass {
327
+ static DATA_TYPE: DataType = data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
328
+ &DATA_TYPE
329
+ }
330
+
331
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
332
+ static BPE: Lazy<RClass> = Lazy::new(|ruby| {
333
+ let class: RClass = ruby.get_inner(&MODELS).const_get("BPE").unwrap();
334
+ class.undef_default_alloc_func();
335
+ class
336
+ });
337
+ static UNIGRAM: Lazy<RClass> = Lazy::new(|ruby| {
338
+ let class: RClass = ruby.get_inner(&MODELS).const_get("Unigram").unwrap();
339
+ class.undef_default_alloc_func();
340
+ class
341
+ });
342
+ static WORD_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
343
+ let class: RClass = ruby.get_inner(&MODELS).const_get("WordLevel").unwrap();
344
+ class.undef_default_alloc_func();
345
+ class
346
+ });
347
+ static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
348
+ let class: RClass = ruby.get_inner(&MODELS).const_get("WordPiece").unwrap();
349
+ class.undef_default_alloc_func();
350
+ class
351
+ });
330
352
  match *value.model.read().unwrap() {
331
- ModelWrapper::BPE(_) => *memoize!(RClass: {
332
- let class: RClass = crate::models().const_get("BPE").unwrap();
333
- class.undef_alloc_func();
334
- class
335
- }),
336
- ModelWrapper::Unigram(_) => *memoize!(RClass: {
337
- let class: RClass = crate::models().const_get("Unigram").unwrap();
338
- class.undef_alloc_func();
339
- class
340
- }),
341
- ModelWrapper::WordLevel(_) => *memoize!(RClass: {
342
- let class: RClass = crate::models().const_get("WordLevel").unwrap();
343
- class.undef_alloc_func();
344
- class
345
- }),
346
- ModelWrapper::WordPiece(_) => *memoize!(RClass: {
347
- let class: RClass = crate::models().const_get("WordPiece").unwrap();
348
- class.undef_alloc_func();
349
- class
350
- }),
353
+ ModelWrapper::BPE(_) => ruby.get_inner(&BPE),
354
+ ModelWrapper::Unigram(_) => ruby.get_inner(&UNIGRAM),
355
+ ModelWrapper::WordLevel(_) => ruby.get_inner(&WORD_LEVEL),
356
+ ModelWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
351
357
  }
352
358
  }
353
359
  }
354
360
 
355
- pub fn models(module: &RModule) -> RbResult<()> {
356
- let model = module.define_class("Model", Default::default())?;
361
+ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
362
+ let model = module.define_class("Model", ruby.class_object())?;
357
363
 
358
364
  let class = module.define_class("BPE", model)?;
359
365
  class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
@@ -372,7 +378,7 @@ pub fn models(module: &RModule) -> RbResult<()> {
372
378
  class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
373
379
 
374
380
  let class = module.define_class("Unigram", model)?;
375
- class.define_singleton_method("_new", function!(RbUnigram::new, 2))?;
381
+ class.define_singleton_method("_new", function!(RbUnigram::new, 3))?;
376
382
 
377
383
  let class = module.define_class("WordLevel", model)?;
378
384
  class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
@@ -1,9 +1,8 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
6
- TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
5
+ Ruby, TryConvert, TypedData,
7
6
  };
8
7
  use serde::ser::SerializeStruct;
9
8
  use serde::{Deserialize, Serialize, Serializer};
@@ -14,7 +13,7 @@ use tk::normalizers::{
14
13
  use tk::{NormalizedString, Normalizer};
15
14
 
16
15
  use super::utils::*;
17
- use super::{RbError, RbResult};
16
+ use super::{NORMALIZERS, RbError, RbResult};
18
17
 
19
18
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
20
19
  pub struct RbNormalizer {
@@ -224,7 +223,7 @@ impl RbSequence {
224
223
  fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
225
224
  let mut sequence = Vec::with_capacity(normalizers.len());
226
225
  for n in normalizers.each() {
227
- let normalizer: &RbNormalizer = n?.try_convert()?;
226
+ let normalizer: &RbNormalizer = TryConvert::try_convert(n?)?;
228
227
  match &normalizer.normalizer {
229
228
  RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
230
229
  RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
@@ -327,82 +326,96 @@ impl Normalizer for RbNormalizerWrapper {
327
326
  }
328
327
 
329
328
  unsafe impl TypedData for RbNormalizer {
330
- fn class() -> RClass {
331
- *memoize!(RClass: {
332
- let class: RClass = crate::normalizers().const_get("Normalizer").unwrap();
333
- class.undef_alloc_func();
334
- class
335
- })
329
+ fn class(ruby: &Ruby) -> RClass {
330
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
331
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Normalizer").unwrap();
332
+ class.undef_default_alloc_func();
333
+ class
334
+ });
335
+ ruby.get_inner(&CLASS)
336
336
  }
337
337
 
338
338
  fn data_type() -> &'static DataType {
339
- memoize!(DataType: DataTypeBuilder::<RbNormalizer>::new("Tokenizers::Normalizers::Normalizer").build())
340
- }
341
-
342
- fn class_for(value: &Self) -> RClass {
339
+ static DATA_TYPE: DataType = data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
340
+ &DATA_TYPE
341
+ }
342
+
343
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
344
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
345
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Sequence").unwrap();
346
+ class.undef_default_alloc_func();
347
+ class
348
+ });
349
+ static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
350
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("BertNormalizer").unwrap();
351
+ class.undef_default_alloc_func();
352
+ class
353
+ });
354
+ static LOWERCASE: Lazy<RClass> = Lazy::new(|ruby| {
355
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Lowercase").unwrap();
356
+ class.undef_default_alloc_func();
357
+ class
358
+ });
359
+ static NFD: Lazy<RClass> = Lazy::new(|ruby| {
360
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFD").unwrap();
361
+ class.undef_default_alloc_func();
362
+ class
363
+ });
364
+ static NFC: Lazy<RClass> = Lazy::new(|ruby| {
365
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFC").unwrap();
366
+ class.undef_default_alloc_func();
367
+ class
368
+ });
369
+ static NFKC: Lazy<RClass> = Lazy::new(|ruby| {
370
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKC").unwrap();
371
+ class.undef_default_alloc_func();
372
+ class
373
+ });
374
+ static NFKD: Lazy<RClass> = Lazy::new(|ruby| {
375
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKD").unwrap();
376
+ class.undef_default_alloc_func();
377
+ class
378
+ });
379
+ static NMT: Lazy<RClass> = Lazy::new(|ruby| {
380
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Nmt").unwrap();
381
+ class.undef_default_alloc_func();
382
+ class
383
+ });
384
+ static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
385
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
386
+ class.undef_default_alloc_func();
387
+ class
388
+ });
389
+ static PREPEND: Lazy<RClass> = Lazy::new(|ruby| {
390
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Prepend").unwrap();
391
+ class.undef_default_alloc_func();
392
+ class
393
+ });
394
+ static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
395
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Strip").unwrap();
396
+ class.undef_default_alloc_func();
397
+ class
398
+ });
399
+ static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
400
+ let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("StripAccents").unwrap();
401
+ class.undef_default_alloc_func();
402
+ class
403
+ });
343
404
  match &value.normalizer {
344
- RbNormalizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
345
- let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
346
- class.undef_alloc_func();
347
- class
348
- }),
405
+ RbNormalizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
349
406
  RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
350
407
  RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
351
- NormalizerWrapper::BertNormalizer(_) => *memoize!(RClass: {
352
- let class: RClass = crate::normalizers().const_get("BertNormalizer").unwrap();
353
- class.undef_alloc_func();
354
- class
355
- }),
356
- NormalizerWrapper::Lowercase(_) => *memoize!(RClass: {
357
- let class: RClass = crate::normalizers().const_get("Lowercase").unwrap();
358
- class.undef_alloc_func();
359
- class
360
- }),
361
- NormalizerWrapper::NFD(_) => *memoize!(RClass: {
362
- let class: RClass = crate::normalizers().const_get("NFD").unwrap();
363
- class.undef_alloc_func();
364
- class
365
- }),
366
- NormalizerWrapper::NFC(_) => *memoize!(RClass: {
367
- let class: RClass = crate::normalizers().const_get("NFC").unwrap();
368
- class.undef_alloc_func();
369
- class
370
- }),
371
- NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
372
- let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
373
- class.undef_alloc_func();
374
- class
375
- }),
376
- NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
377
- let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
378
- class.undef_alloc_func();
379
- class
380
- }),
381
- NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
382
- let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
383
- class.undef_alloc_func();
384
- class
385
- }),
386
- NormalizerWrapper::Replace(_) => *memoize!(RClass: {
387
- let class: RClass = crate::normalizers().const_get("Replace").unwrap();
388
- class.undef_alloc_func();
389
- class
390
- }),
391
- NormalizerWrapper::Prepend(_) => *memoize!(RClass: {
392
- let class: RClass = crate::normalizers().const_get("Prepend").unwrap();
393
- class.undef_alloc_func();
394
- class
395
- }),
396
- NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
397
- let class: RClass = crate::normalizers().const_get("Strip").unwrap();
398
- class.undef_alloc_func();
399
- class
400
- }),
401
- NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
402
- let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
403
- class.undef_alloc_func();
404
- class
405
- }),
408
+ NormalizerWrapper::BertNormalizer(_) => ruby.get_inner(&BERT_NORMALIZER),
409
+ NormalizerWrapper::Lowercase(_) => ruby.get_inner(&LOWERCASE),
410
+ NormalizerWrapper::NFD(_) => ruby.get_inner(&NFD),
411
+ NormalizerWrapper::NFC(_) => ruby.get_inner(&NFC),
412
+ NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
413
+ NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
414
+ NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
415
+ NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
416
+ NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
417
+ NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
418
+ NormalizerWrapper::StripAccents(_) => ruby.get_inner(&STRIP_ACCENTS),
406
419
  _ => todo!(),
407
420
  },
408
421
  },
@@ -410,8 +423,8 @@ unsafe impl TypedData for RbNormalizer {
410
423
  }
411
424
  }
412
425
 
413
- pub fn normalizers(module: &RModule) -> RbResult<()> {
414
- let normalizer = module.define_class("Normalizer", Default::default())?;
426
+ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
427
+ let normalizer = module.define_class("Normalizer", ruby.class_object())?;
415
428
  normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
416
429
 
417
430
  let class = module.define_class("Sequence", normalizer)?;
@@ -1,9 +1,8 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object,
6
- RArray, RClass, RModule, TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
5
+ RArray, RClass, RModule, Ruby, TryConvert, TypedData,
7
6
  };
8
7
 
9
8
  use serde::ser::SerializeStruct;
@@ -23,7 +22,7 @@ use tk::tokenizer::Offsets;
23
22
  use tk::{PreTokenizedString, PreTokenizer};
24
23
 
25
24
  use super::utils::*;
26
- use super::{RbError, RbResult};
25
+ use super::{PRE_TOKENIZERS, RbError, RbResult};
27
26
 
28
27
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
29
28
  pub struct RbPreTokenizer {
@@ -215,7 +214,7 @@ pub struct RbWhitespace {}
215
214
 
216
215
  impl RbWhitespace {
217
216
  pub fn new() -> RbPreTokenizer {
218
- Whitespace::default().into()
217
+ Whitespace.into()
219
218
  }
220
219
  }
221
220
 
@@ -241,7 +240,7 @@ impl RbSequence {
241
240
  fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
242
241
  let mut sequence = Vec::with_capacity(pre_tokenizers.len());
243
242
  for n in pre_tokenizers.each() {
244
- let pretokenizer: &RbPreTokenizer = n?.try_convert()?;
243
+ let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n?)?;
245
244
  match &pretokenizer.pretok {
246
245
  RbPreTokenizerTypeWrapper::Sequence(inner) => {
247
246
  sequence.extend(inner.iter().cloned())
@@ -346,77 +345,90 @@ impl PreTokenizer for RbPreTokenizerWrapper {
346
345
  }
347
346
 
348
347
  unsafe impl TypedData for RbPreTokenizer {
349
- fn class() -> RClass {
350
- *memoize!(RClass: {
351
- let class: RClass = crate::pre_tokenizers().const_get("PreTokenizer").unwrap();
352
- class.undef_alloc_func();
353
- class
354
- })
348
+ fn class(ruby: &Ruby) -> RClass {
349
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
350
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
351
+ class.undef_default_alloc_func();
352
+ class
353
+ });
354
+ ruby.get_inner(&CLASS)
355
355
  }
356
356
 
357
357
  fn data_type() -> &'static DataType {
358
- memoize!(DataType: DataTypeBuilder::<RbPreTokenizer>::new("Tokenizers::PreTokenizers::PreTokenizer").build())
359
- }
360
-
361
- fn class_for(value: &Self) -> RClass {
358
+ static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
359
+ &DATA_TYPE
360
+ }
361
+
362
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
363
+ static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
364
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
365
+ class.undef_default_alloc_func();
366
+ class
367
+ });
368
+ static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
369
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
370
+ class.undef_default_alloc_func();
371
+ class
372
+ });
373
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
374
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
375
+ class.undef_default_alloc_func();
376
+ class
377
+ });
378
+ static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
379
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
380
+ class.undef_default_alloc_func();
381
+ class
382
+ });
383
+ static DIGITS: Lazy<RClass> = Lazy::new(|ruby| {
384
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Digits").unwrap();
385
+ class.undef_default_alloc_func();
386
+ class
387
+ });
388
+ static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
389
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
390
+ class.undef_default_alloc_func();
391
+ class
392
+ });
393
+ static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
394
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
395
+ class.undef_default_alloc_func();
396
+ class
397
+ });
398
+ static SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
399
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Split").unwrap();
400
+ class.undef_default_alloc_func();
401
+ class
402
+ });
403
+ static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
404
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
405
+ class.undef_default_alloc_func();
406
+ class
407
+ });
408
+ static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
409
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
410
+ class.undef_default_alloc_func();
411
+ class
412
+ });
413
+ static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
414
+ let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
415
+ class.undef_default_alloc_func();
416
+ class
417
+ });
362
418
  match &value.pretok {
363
- RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
364
- let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
365
- class.undef_alloc_func();
366
- class
367
- }),
419
+ RbPreTokenizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
368
420
  RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
369
421
  RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
370
- PreTokenizerWrapper::BertPreTokenizer(_) => *memoize!(RClass: {
371
- let class: RClass = crate::pre_tokenizers().const_get("BertPreTokenizer").unwrap();
372
- class.undef_alloc_func();
373
- class
374
- }),
375
- PreTokenizerWrapper::ByteLevel(_) => *memoize!(RClass: {
376
- let class: RClass = crate::pre_tokenizers().const_get("ByteLevel").unwrap();
377
- class.undef_alloc_func();
378
- class
379
- }),
380
- PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
381
- let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
382
- class.undef_alloc_func();
383
- class
384
- }),
385
- PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
386
- let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
387
- class.undef_alloc_func();
388
- class
389
- }),
390
- PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
391
- let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
392
- class.undef_alloc_func();
393
- class
394
- }),
395
- PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
396
- let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
397
- class.undef_alloc_func();
398
- class
399
- }),
400
- PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
401
- let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
402
- class.undef_alloc_func();
403
- class
404
- }),
405
- PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
406
- let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
407
- class.undef_alloc_func();
408
- class
409
- }),
410
- PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
411
- let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
412
- class.undef_alloc_func();
413
- class
414
- }),
415
- PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
416
- let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
417
- class.undef_alloc_func();
418
- class
419
- }),
422
+ PreTokenizerWrapper::BertPreTokenizer(_) => ruby.get_inner(&BERT_PRE_TOKENIZER),
423
+ PreTokenizerWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
424
+ PreTokenizerWrapper::Delimiter(_) => ruby.get_inner(&CHAR_DELIMITER_SPLIT),
425
+ PreTokenizerWrapper::Digits(_) => ruby.get_inner(&DIGITS),
426
+ PreTokenizerWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
427
+ PreTokenizerWrapper::Punctuation(_) => ruby.get_inner(&PUNCTUATION),
428
+ PreTokenizerWrapper::Split(_) => ruby.get_inner(&SPLIT),
429
+ PreTokenizerWrapper::UnicodeScripts(_) => ruby.get_inner(&UNICODE_SCRIPTS),
430
+ PreTokenizerWrapper::Whitespace(_) => ruby.get_inner(&WHITESPACE),
431
+ PreTokenizerWrapper::WhitespaceSplit(_) => ruby.get_inner(&WHITESPACE_SPLIT),
420
432
  _ => todo!(),
421
433
  },
422
434
  },
@@ -424,8 +436,8 @@ unsafe impl TypedData for RbPreTokenizer {
424
436
  }
425
437
  }
426
438
 
427
- pub fn pre_tokenizers(module: &RModule) -> RbResult<()> {
428
- let pre_tokenizer = module.define_class("PreTokenizer", Default::default())?;
439
+ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
440
+ let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
429
441
  pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
430
442
 
431
443
  let class = module.define_class("Sequence", pre_tokenizer)?;