tokenizers 0.3.2 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,20 +1,25 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
3
+ use magnus::value::Lazy;
4
4
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
- TypedData,
5
+ data_type_builder, function, method, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
+ Ruby, TypedData,
7
7
  };
8
8
  use serde::{Deserialize, Serialize};
9
9
  use tk::decoders::bpe::BPEDecoder;
10
+ use tk::decoders::byte_fallback::ByteFallback;
10
11
  use tk::decoders::byte_level::ByteLevel;
11
12
  use tk::decoders::ctc::CTC;
13
+ use tk::decoders::fuse::Fuse;
12
14
  use tk::decoders::metaspace::Metaspace;
15
+ use tk::decoders::strip::Strip;
13
16
  use tk::decoders::wordpiece::WordPiece;
14
17
  use tk::decoders::DecoderWrapper;
15
18
  use tk::Decoder;
19
+ use tk::normalizers::replace::Replace;
16
20
 
17
- use super::RbResult;
21
+ use super::utils::*;
22
+ use super::{DECODERS, RbError, RbResult};
18
23
 
19
24
  #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
20
25
  pub struct RbDecoder {
@@ -89,6 +94,30 @@ impl RbDecoder {
89
94
  setter!(self, CTC, word_delimiter_token, word_delimiter_token);
90
95
  }
91
96
 
97
+ fn strip_content(&self) -> char {
98
+ getter!(self, Strip, content)
99
+ }
100
+
101
+ fn strip_set_content(&self, content: char) {
102
+ setter!(self, Strip, content, content)
103
+ }
104
+
105
+ fn strip_start(&self) -> usize {
106
+ getter!(self, Strip, start)
107
+ }
108
+
109
+ fn strip_set_start(&self, start: usize) {
110
+ setter!(self, Strip, start, start)
111
+ }
112
+
113
+ fn strip_stop(&self) -> usize {
114
+ getter!(self, Strip, stop)
115
+ }
116
+
117
+ fn strip_set_stop(&self, stop: usize) {
118
+ setter!(self, Strip, stop, stop)
119
+ }
120
+
92
121
  pub fn metaspace_replacement(&self) -> char {
93
122
  getter!(self, Metaspace, get_replacement().clone())
94
123
  }
@@ -130,6 +159,14 @@ impl RbBPEDecoder {
130
159
  }
131
160
  }
132
161
 
162
+ pub struct RbByteFallbackDecoder {}
163
+
164
+ impl RbByteFallbackDecoder {
165
+ pub fn new() -> RbDecoder {
166
+ ByteFallback::default().into()
167
+ }
168
+ }
169
+
133
170
  pub struct RbByteLevelDecoder {}
134
171
 
135
172
  impl RbByteLevelDecoder {
@@ -146,6 +183,14 @@ impl RbCTC {
146
183
  }
147
184
  }
148
185
 
186
+ pub struct RbFuse {}
187
+
188
+ impl RbFuse {
189
+ pub fn new() -> RbDecoder {
190
+ Fuse::default().into()
191
+ }
192
+ }
193
+
149
194
  pub struct RbMetaspaceDecoder {}
150
195
 
151
196
  impl RbMetaspaceDecoder {
@@ -154,6 +199,22 @@ impl RbMetaspaceDecoder {
154
199
  }
155
200
  }
156
201
 
202
+ pub struct RbReplaceDecoder {}
203
+
204
+ impl RbReplaceDecoder {
205
+ pub fn new(pattern: RbPattern, content: String) -> RbResult<RbDecoder> {
206
+ Replace::new(pattern, content).map(|v| v.into()).map_err(RbError::from)
207
+ }
208
+ }
209
+
210
+ pub struct RbStripDecoder {}
211
+
212
+ impl RbStripDecoder {
213
+ pub fn new(content: char, start: usize, stop: usize) -> RbDecoder {
214
+ Strip::new(content, start, stop).into()
215
+ }
216
+ }
217
+
157
218
  pub struct RbWordPieceDecoder {}
158
219
 
159
220
  impl RbWordPieceDecoder {
@@ -199,60 +260,94 @@ impl Decoder for RbDecoderWrapper {
199
260
  }
200
261
 
201
262
  unsafe impl TypedData for RbDecoder {
202
- fn class() -> RClass {
203
- *memoize!(RClass: {
204
- let class: RClass = crate::decoders().const_get("Decoder").unwrap();
205
- class.undef_alloc_func();
206
- class
207
- })
263
+ fn class(ruby: &Ruby) -> RClass {
264
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
265
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Decoder").unwrap();
266
+ class.undef_default_alloc_func();
267
+ class
268
+ });
269
+ ruby.get_inner(&CLASS)
208
270
  }
209
271
 
210
272
  fn data_type() -> &'static DataType {
211
- memoize!(DataType: DataTypeBuilder::<RbDecoder>::new("Tokenizers::Decoders::Decoder").build())
273
+ static DATA_TYPE: DataType = data_type_builder!(RbDecoder, "Tokenizers::Decoders::Decoder").build();
274
+ &DATA_TYPE
212
275
  }
213
276
 
214
- fn class_for(value: &Self) -> RClass {
277
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
278
+ static BPE_DECODER: Lazy<RClass> = Lazy::new(|ruby| {
279
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("BPEDecoder").unwrap();
280
+ class.undef_default_alloc_func();
281
+ class
282
+ });
283
+ static BYTE_FALLBACK: Lazy<RClass> = Lazy::new(|ruby| {
284
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("ByteFallback").unwrap();
285
+ class.undef_default_alloc_func();
286
+ class
287
+ });
288
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
289
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("ByteLevel").unwrap();
290
+ class.undef_default_alloc_func();
291
+ class
292
+ });
293
+ static CTC: Lazy<RClass> = Lazy::new(|ruby| {
294
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("CTC").unwrap();
295
+ class.undef_default_alloc_func();
296
+ class
297
+ });
298
+ static FUSE: Lazy<RClass> = Lazy::new(|ruby| {
299
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Fuse").unwrap();
300
+ class.undef_default_alloc_func();
301
+ class
302
+ });
303
+ static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
304
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Metaspace").unwrap();
305
+ class.undef_default_alloc_func();
306
+ class
307
+ });
308
+ static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
309
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Replace").unwrap();
310
+ class.undef_default_alloc_func();
311
+ class
312
+ });
313
+ static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
314
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Strip").unwrap();
315
+ class.undef_default_alloc_func();
316
+ class
317
+ });
318
+ static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
319
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("WordPiece").unwrap();
320
+ class.undef_default_alloc_func();
321
+ class
322
+ });
215
323
  match &value.decoder {
216
324
  RbDecoderWrapper::Wrapped(inner) => match *inner.read().unwrap() {
217
- DecoderWrapper::BPE(_) => *memoize!(RClass: {
218
- let class: RClass = crate::decoders().const_get("BPEDecoder").unwrap();
219
- class.undef_alloc_func();
220
- class
221
- }),
222
- DecoderWrapper::ByteLevel(_) => *memoize!(RClass: {
223
- let class: RClass = crate::decoders().const_get("ByteLevel").unwrap();
224
- class.undef_alloc_func();
225
- class
226
- }),
227
- DecoderWrapper::CTC(_) => *memoize!(RClass: {
228
- let class: RClass = crate::decoders().const_get("CTC").unwrap();
229
- class.undef_alloc_func();
230
- class
231
- }),
232
- DecoderWrapper::Metaspace(_) => *memoize!(RClass: {
233
- let class: RClass = crate::decoders().const_get("Metaspace").unwrap();
234
- class.undef_alloc_func();
235
- class
236
- }),
237
- DecoderWrapper::WordPiece(_) => *memoize!(RClass: {
238
- let class: RClass = crate::decoders().const_get("WordPiece").unwrap();
239
- class.undef_alloc_func();
240
- class
241
- }),
325
+ DecoderWrapper::BPE(_) => ruby.get_inner(&BPE_DECODER),
326
+ DecoderWrapper::ByteFallback(_) => ruby.get_inner(&BYTE_FALLBACK),
327
+ DecoderWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
328
+ DecoderWrapper::CTC(_) => ruby.get_inner(&CTC),
329
+ DecoderWrapper::Fuse(_) => ruby.get_inner(&FUSE),
330
+ DecoderWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
331
+ DecoderWrapper::Replace(_) => ruby.get_inner(&REPLACE),
332
+ DecoderWrapper::Strip(_) => ruby.get_inner(&STRIP),
333
+ DecoderWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
242
334
  _ => todo!(),
243
335
  },
244
336
  }
245
337
  }
246
338
  }
247
339
 
248
- pub fn decoders(module: &RModule) -> RbResult<()> {
249
- let decoder = module.define_class("Decoder", Default::default())?;
340
+ pub fn init_decoders(ruby: &Ruby, module: &RModule) -> RbResult<()> {
341
+ let decoder = module.define_class("Decoder", ruby.class_object())?;
250
342
 
251
343
  let class = module.define_class("BPEDecoder", decoder)?;
252
344
  class.define_singleton_method("_new", function!(RbBPEDecoder::new, 1))?;
253
345
  class.define_method("suffix", method!(RbDecoder::bpe_suffix, 0))?;
254
346
  class.define_method("suffix=", method!(RbDecoder::bpe_set_suffix, 1))?;
255
347
 
348
+ let class = module.define_class("ByteFallback", decoder)?;
349
+ class.define_singleton_method("new", function!(RbByteFallbackDecoder::new, 0))?;
350
+
256
351
  let class = module.define_class("ByteLevel", decoder)?;
257
352
  class.define_singleton_method("new", function!(RbByteLevelDecoder::new, 0))?;
258
353
 
@@ -265,6 +360,9 @@ pub fn decoders(module: &RModule) -> RbResult<()> {
265
360
  class.define_method("word_delimiter_token", method!(RbDecoder::ctc_word_delimiter_token, 0))?;
266
361
  class.define_method("word_delimiter_token=", method!(RbDecoder::ctc_set_word_delimiter_token, 1))?;
267
362
 
363
+ let class = module.define_class("Fuse", decoder)?;
364
+ class.define_singleton_method("new", function!(RbFuse::new, 0))?;
365
+
268
366
  let class = module.define_class("Metaspace", decoder)?;
269
367
  class.define_singleton_method("_new", function!(RbMetaspaceDecoder::new, 2))?;
270
368
  class.define_method("add_prefix_space", method!(RbDecoder::metaspace_add_prefix_space, 0))?;
@@ -272,6 +370,18 @@ pub fn decoders(module: &RModule) -> RbResult<()> {
272
370
  class.define_method("replacement", method!(RbDecoder::metaspace_replacement, 0))?;
273
371
  class.define_method("replacement=", method!(RbDecoder::metaspace_set_replacement, 1))?;
274
372
 
373
+ let class = module.define_class("Replace", decoder)?;
374
+ class.define_singleton_method("new", function!(RbReplaceDecoder::new, 2))?;
375
+
376
+ let class = module.define_class("Strip", decoder)?;
377
+ class.define_singleton_method("_new", function!(RbStripDecoder::new, 3))?;
378
+ class.define_method("content", method!(RbDecoder::strip_content, 0))?;
379
+ class.define_method("content=", method!(RbDecoder::strip_set_content, 1))?;
380
+ class.define_method("start", method!(RbDecoder::strip_start, 0))?;
381
+ class.define_method("start=", method!(RbDecoder::strip_set_start, 1))?;
382
+ class.define_method("stop", method!(RbDecoder::strip_stop, 0))?;
383
+ class.define_method("stop=", method!(RbDecoder::strip_set_stop, 1))?;
384
+
275
385
  let class = module.define_class("WordPiece", decoder)?;
276
386
  class.define_singleton_method("_new", function!(RbWordPieceDecoder::new, 2))?;
277
387
  class.define_method("cleanup", method!(RbDecoder::word_piece_cleanup, 0))?;
@@ -1,6 +1,6 @@
1
- use magnus::{memoize, Error, ExceptionClass, Module};
1
+ use magnus::{prelude::*, value::Lazy, Error, ExceptionClass, Ruby};
2
2
 
3
- use super::module;
3
+ use super::TOKENIZERS;
4
4
 
5
5
  pub struct RbError {}
6
6
 
@@ -11,6 +11,8 @@ impl RbError {
11
11
  }
12
12
  }
13
13
 
14
+ static ERROR: Lazy<ExceptionClass> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Error").unwrap());
15
+
14
16
  fn error() -> ExceptionClass {
15
- *memoize!(ExceptionClass: module().const_get("Error").unwrap())
17
+ Ruby::get().unwrap().get_inner(&ERROR)
16
18
  }
@@ -1,3 +1,5 @@
1
+ #![allow(clippy::new_ret_no_self)]
2
+
1
3
  extern crate tokenizers as tk;
2
4
 
3
5
  mod decoders;
@@ -16,43 +18,29 @@ use error::RbError;
16
18
  use tokenizer::RbTokenizer;
17
19
  use utils::RbRegex;
18
20
 
19
- use magnus::{define_module, function, memoize, method, prelude::*, Error, RModule};
21
+ use magnus::{function, method, prelude::*, value::Lazy, Error, RModule, Ruby};
20
22
 
21
23
  type RbResult<T> = Result<T, Error>;
22
24
 
23
- fn module() -> RModule {
24
- *memoize!(RModule: define_module("Tokenizers").unwrap())
25
- }
25
+ static TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.class_object().const_get("Tokenizers").unwrap());
26
26
 
27
- fn decoders() -> RModule {
28
- *memoize!(RModule: module().const_get("Decoders").unwrap())
29
- }
27
+ static DECODERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Decoders").unwrap());
30
28
 
31
- fn models() -> RModule {
32
- *memoize!(RModule: module().const_get("Models").unwrap())
33
- }
29
+ static MODELS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Models").unwrap());
34
30
 
35
- fn normalizers() -> RModule {
36
- *memoize!(RModule: module().const_get("Normalizers").unwrap())
37
- }
31
+ static NORMALIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Normalizers").unwrap());
38
32
 
39
- fn pre_tokenizers() -> RModule {
40
- *memoize!(RModule: module().const_get("PreTokenizers").unwrap())
41
- }
33
+ static PRE_TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("PreTokenizers").unwrap());
42
34
 
43
- fn processors() -> RModule {
44
- *memoize!(RModule: module().const_get("Processors").unwrap())
45
- }
35
+ static PROCESSORS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Processors").unwrap());
46
36
 
47
- fn trainers() -> RModule {
48
- *memoize!(RModule: module().const_get("Trainers").unwrap())
49
- }
37
+ static TRAINERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Trainers").unwrap());
50
38
 
51
39
  #[magnus::init]
52
- fn init() -> RbResult<()> {
53
- let module = module();
40
+ fn init(ruby: &Ruby) -> RbResult<()> {
41
+ let module = ruby.get_inner(&TOKENIZERS);
54
42
 
55
- let class = module.define_class("Tokenizer", Default::default())?;
43
+ let class = module.define_class("Tokenizer", ruby.class_object())?;
56
44
  class.define_singleton_method("new", function!(RbTokenizer::from_model, 1))?;
57
45
  class.define_singleton_method("from_file", function!(RbTokenizer::from_file, 1))?;
58
46
  class.define_method(
@@ -86,7 +74,7 @@ fn init() -> RbResult<()> {
86
74
  class.define_method("_vocab_size", method!(RbTokenizer::vocab_size, 1))?;
87
75
  class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;
88
76
 
89
- let class = module.define_class("Encoding", Default::default())?;
77
+ let class = module.define_class("Encoding", ruby.class_object())?;
90
78
  class.define_method("n_sequences", method!(RbEncoding::n_sequences, 0))?;
91
79
  class.define_method("ids", method!(RbEncoding::ids, 0))?;
92
80
  class.define_method("tokens", method!(RbEncoding::tokens, 0))?;
@@ -111,7 +99,7 @@ fn init() -> RbResult<()> {
111
99
  class.define_method("_char_to_token", method!(RbEncoding::char_to_token, 2))?;
112
100
  class.define_method("_char_to_word", method!(RbEncoding::char_to_word, 2))?;
113
101
 
114
- let class = module.define_class("Regex", Default::default())?;
102
+ let class = module.define_class("Regex", ruby.class_object())?;
115
103
  class.define_singleton_method("new", function!(RbRegex::new, 1))?;
116
104
 
117
105
  let models = module.define_module("Models")?;
@@ -121,12 +109,12 @@ fn init() -> RbResult<()> {
121
109
  let normalizers = module.define_module("Normalizers")?;
122
110
  let trainers = module.define_module("Trainers")?;
123
111
 
124
- models::models(&models)?;
125
- pre_tokenizers::pre_tokenizers(&pre_tokenizers)?;
126
- decoders::decoders(&decoders)?;
127
- processors::processors(&processors)?;
128
- normalizers::normalizers(&normalizers)?;
129
- trainers::trainers(&trainers)?;
112
+ models::init_models(ruby, &models)?;
113
+ pre_tokenizers::init_pre_tokenizers(ruby, &pre_tokenizers)?;
114
+ decoders::init_decoders(ruby, &decoders)?;
115
+ processors::init_processors(ruby, &processors)?;
116
+ normalizers::init_normalizers(ruby, &normalizers)?;
117
+ trainers::init_trainers(ruby, &trainers)?;
130
118
 
131
119
  Ok(())
132
120
  }
@@ -3,10 +3,10 @@ use std::path::{Path, PathBuf};
3
3
  use std::sync::{Arc, RwLock};
4
4
 
5
5
  use crate::trainers::RbTrainer;
6
- use magnus::typed_data::DataTypeBuilder;
6
+ use magnus::prelude::*;
7
7
  use magnus::{
8
- exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
- RClass, RHash, RModule, Symbol, TypedData, Value,
8
+ data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
+ RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
10
10
  };
11
11
  use serde::{Deserialize, Serialize};
12
12
  use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
@@ -16,7 +16,7 @@ use tk::models::wordlevel::WordLevel;
16
16
  use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
17
17
  use tk::{Model, Token};
18
18
 
19
- use super::{RbError, RbResult};
19
+ use super::{MODELS, RbError, RbResult};
20
20
 
21
21
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
22
22
  pub struct RbModel {
@@ -73,32 +73,37 @@ impl RbBPE {
73
73
  fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
74
74
  let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
75
75
  if !value.is_nil() {
76
- builder = builder.cache_capacity(value.try_convert()?);
76
+ builder = builder.cache_capacity(TryConvert::try_convert(value)?);
77
77
  }
78
78
 
79
79
  let value: Value = kwargs.delete(Symbol::new("dropout"))?;
80
80
  if !value.is_nil() {
81
- builder = builder.dropout(value.try_convert()?);
81
+ builder = builder.dropout(TryConvert::try_convert(value)?);
82
82
  }
83
83
 
84
84
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
85
85
  if !value.is_nil() {
86
- builder = builder.unk_token(value.try_convert()?);
86
+ builder = builder.unk_token(TryConvert::try_convert(value)?);
87
87
  }
88
88
 
89
89
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
90
90
  if !value.is_nil() {
91
- builder = builder.continuing_subword_prefix(value.try_convert()?);
91
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
92
92
  }
93
93
 
94
94
  let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
95
95
  if !value.is_nil() {
96
- builder = builder.end_of_word_suffix(value.try_convert()?);
96
+ builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
97
97
  }
98
98
 
99
99
  let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
100
100
  if !value.is_nil() {
101
- builder = builder.fuse_unk(value.try_convert()?);
101
+ builder = builder.fuse_unk(TryConvert::try_convert(value)?);
102
+ }
103
+
104
+ let value: Value = kwargs.delete(Symbol::new("byte_fallback"))?;
105
+ if !value.is_nil() {
106
+ builder = builder.byte_fallback(TryConvert::try_convert(value)?);
102
107
  }
103
108
 
104
109
  if !kwargs.is_empty() {
@@ -169,6 +174,14 @@ impl RbModel {
169
174
  setter!(self, BPE, fuse_unk, fuse_unk);
170
175
  }
171
176
 
177
+ pub fn bpe_byte_fallback(&self) -> bool {
178
+ getter!(self, BPE, byte_fallback)
179
+ }
180
+
181
+ pub fn bpe_set_byte_fallback(&self, byte_fallback: bool) {
182
+ setter!(self, BPE, byte_fallback, byte_fallback);
183
+ }
184
+
172
185
  pub fn bpe_continuing_subword_prefix(&self) -> Option<String> {
173
186
  getter!(self, BPE, continuing_subword_prefix.clone())
174
187
  }
@@ -221,13 +234,13 @@ impl RbModel {
221
234
  pub struct RbUnigram {}
222
235
 
223
236
  impl RbUnigram {
224
- fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
225
- match (vocab, unk_id) {
226
- (Some(vocab), unk_id) => {
227
- let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
237
+ fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>, byte_fallback: Option<bool>) -> RbResult<RbModel> {
238
+ match (vocab, unk_id, byte_fallback) {
239
+ (Some(vocab), unk_id, byte_fallback) => {
240
+ let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(RbError::from)?;
228
241
  Ok(model.into())
229
242
  }
230
- (None, None) => Ok(Unigram::default().into()),
243
+ (None, None, _) => Ok(Unigram::default().into()),
231
244
  _ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
232
245
  }
233
246
  }
@@ -264,17 +277,17 @@ impl RbWordPiece {
264
277
  fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
265
278
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
266
279
  if !value.is_nil() {
267
- builder = builder.unk_token(value.try_convert()?);
280
+ builder = builder.unk_token(TryConvert::try_convert(value)?);
268
281
  }
269
282
 
270
283
  let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
271
284
  if !value.is_nil() {
272
- builder = builder.max_input_chars_per_word(value.try_convert()?);
285
+ builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
273
286
  }
274
287
 
275
288
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
276
289
  if !value.is_nil() {
277
- builder = builder.continuing_subword_prefix(value.try_convert()?);
290
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
278
291
  }
279
292
 
280
293
  if !kwargs.is_empty() {
@@ -301,46 +314,52 @@ impl RbWordPiece {
301
314
  }
302
315
 
303
316
  unsafe impl TypedData for RbModel {
304
- fn class() -> RClass {
305
- *memoize!(RClass: {
306
- let class: RClass = crate::models().const_get("Model").unwrap();
307
- class.undef_alloc_func();
317
+ fn class(ruby: &Ruby) -> RClass {
318
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
319
+ let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
320
+ class.undef_default_alloc_func();
308
321
  class
309
- })
322
+ });
323
+ ruby.get_inner(&CLASS)
310
324
  }
311
325
 
312
326
  fn data_type() -> &'static DataType {
313
- memoize!(DataType: DataTypeBuilder::<RbModel>::new("Tokenizers::Models::Model").build())
314
- }
315
-
316
- fn class_for(value: &Self) -> RClass {
327
+ static DATA_TYPE: DataType = data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
328
+ &DATA_TYPE
329
+ }
330
+
331
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
332
+ static BPE: Lazy<RClass> = Lazy::new(|ruby| {
333
+ let class: RClass = ruby.get_inner(&MODELS).const_get("BPE").unwrap();
334
+ class.undef_default_alloc_func();
335
+ class
336
+ });
337
+ static UNIGRAM: Lazy<RClass> = Lazy::new(|ruby| {
338
+ let class: RClass = ruby.get_inner(&MODELS).const_get("Unigram").unwrap();
339
+ class.undef_default_alloc_func();
340
+ class
341
+ });
342
+ static WORD_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
343
+ let class: RClass = ruby.get_inner(&MODELS).const_get("WordLevel").unwrap();
344
+ class.undef_default_alloc_func();
345
+ class
346
+ });
347
+ static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
348
+ let class: RClass = ruby.get_inner(&MODELS).const_get("WordPiece").unwrap();
349
+ class.undef_default_alloc_func();
350
+ class
351
+ });
317
352
  match *value.model.read().unwrap() {
318
- ModelWrapper::BPE(_) => *memoize!(RClass: {
319
- let class: RClass = crate::models().const_get("BPE").unwrap();
320
- class.undef_alloc_func();
321
- class
322
- }),
323
- ModelWrapper::Unigram(_) => *memoize!(RClass: {
324
- let class: RClass = crate::models().const_get("Unigram").unwrap();
325
- class.undef_alloc_func();
326
- class
327
- }),
328
- ModelWrapper::WordLevel(_) => *memoize!(RClass: {
329
- let class: RClass = crate::models().const_get("WordLevel").unwrap();
330
- class.undef_alloc_func();
331
- class
332
- }),
333
- ModelWrapper::WordPiece(_) => *memoize!(RClass: {
334
- let class: RClass = crate::models().const_get("WordPiece").unwrap();
335
- class.undef_alloc_func();
336
- class
337
- }),
353
+ ModelWrapper::BPE(_) => ruby.get_inner(&BPE),
354
+ ModelWrapper::Unigram(_) => ruby.get_inner(&UNIGRAM),
355
+ ModelWrapper::WordLevel(_) => ruby.get_inner(&WORD_LEVEL),
356
+ ModelWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
338
357
  }
339
358
  }
340
359
  }
341
360
 
342
- pub fn models(module: &RModule) -> RbResult<()> {
343
- let model = module.define_class("Model", Default::default())?;
361
+ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
362
+ let model = module.define_class("Model", ruby.class_object())?;
344
363
 
345
364
  let class = module.define_class("BPE", model)?;
346
365
  class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
@@ -355,9 +374,11 @@ pub fn models(module: &RModule) -> RbResult<()> {
355
374
  class.define_method("end_of_word_suffix=", method!(RbModel::bpe_set_end_of_word_suffix, 1))?;
356
375
  class.define_method("fuse_unk", method!(RbModel::bpe_fuse_unk, 0))?;
357
376
  class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
377
+ class.define_method("byte_fallback", method!(RbModel::bpe_byte_fallback, 0))?;
378
+ class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
358
379
 
359
380
  let class = module.define_class("Unigram", model)?;
360
- class.define_singleton_method("_new", function!(RbUnigram::new, 2))?;
381
+ class.define_singleton_method("_new", function!(RbUnigram::new, 3))?;
361
382
 
362
383
  let class = module.define_class("WordLevel", model)?;
363
384
  class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;