tokenizers 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,25 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
3
+ use magnus::value::Lazy;
4
4
  use magnus::{
5
- function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
- TypedData,
5
+ data_type_builder, function, method, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
+ Ruby, TypedData,
7
7
  };
8
8
  use serde::{Deserialize, Serialize};
9
9
  use tk::decoders::bpe::BPEDecoder;
10
+ use tk::decoders::byte_fallback::ByteFallback;
10
11
  use tk::decoders::byte_level::ByteLevel;
11
12
  use tk::decoders::ctc::CTC;
13
+ use tk::decoders::fuse::Fuse;
12
14
  use tk::decoders::metaspace::Metaspace;
15
+ use tk::decoders::strip::Strip;
13
16
  use tk::decoders::wordpiece::WordPiece;
14
17
  use tk::decoders::DecoderWrapper;
15
18
  use tk::Decoder;
19
+ use tk::normalizers::replace::Replace;
16
20
 
17
- use super::RbResult;
21
+ use super::utils::*;
22
+ use super::{DECODERS, RbError, RbResult};
18
23
 
19
24
  #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
20
25
  pub struct RbDecoder {
@@ -89,6 +94,30 @@ impl RbDecoder {
89
94
  setter!(self, CTC, word_delimiter_token, word_delimiter_token);
90
95
  }
91
96
 
97
+ fn strip_content(&self) -> char {
98
+ getter!(self, Strip, content)
99
+ }
100
+
101
+ fn strip_set_content(&self, content: char) {
102
+ setter!(self, Strip, content, content)
103
+ }
104
+
105
+ fn strip_start(&self) -> usize {
106
+ getter!(self, Strip, start)
107
+ }
108
+
109
+ fn strip_set_start(&self, start: usize) {
110
+ setter!(self, Strip, start, start)
111
+ }
112
+
113
+ fn strip_stop(&self) -> usize {
114
+ getter!(self, Strip, stop)
115
+ }
116
+
117
+ fn strip_set_stop(&self, stop: usize) {
118
+ setter!(self, Strip, stop, stop)
119
+ }
120
+
92
121
  pub fn metaspace_replacement(&self) -> char {
93
122
  getter!(self, Metaspace, get_replacement().clone())
94
123
  }
@@ -130,6 +159,14 @@ impl RbBPEDecoder {
130
159
  }
131
160
  }
132
161
 
162
+ pub struct RbByteFallbackDecoder {}
163
+
164
+ impl RbByteFallbackDecoder {
165
+ pub fn new() -> RbDecoder {
166
+ ByteFallback::default().into()
167
+ }
168
+ }
169
+
133
170
  pub struct RbByteLevelDecoder {}
134
171
 
135
172
  impl RbByteLevelDecoder {
@@ -146,6 +183,14 @@ impl RbCTC {
146
183
  }
147
184
  }
148
185
 
186
+ pub struct RbFuse {}
187
+
188
+ impl RbFuse {
189
+ pub fn new() -> RbDecoder {
190
+ Fuse::default().into()
191
+ }
192
+ }
193
+
149
194
  pub struct RbMetaspaceDecoder {}
150
195
 
151
196
  impl RbMetaspaceDecoder {
@@ -154,6 +199,22 @@ impl RbMetaspaceDecoder {
154
199
  }
155
200
  }
156
201
 
202
+ pub struct RbReplaceDecoder {}
203
+
204
+ impl RbReplaceDecoder {
205
+ pub fn new(pattern: RbPattern, content: String) -> RbResult<RbDecoder> {
206
+ Replace::new(pattern, content).map(|v| v.into()).map_err(RbError::from)
207
+ }
208
+ }
209
+
210
+ pub struct RbStripDecoder {}
211
+
212
+ impl RbStripDecoder {
213
+ pub fn new(content: char, start: usize, stop: usize) -> RbDecoder {
214
+ Strip::new(content, start, stop).into()
215
+ }
216
+ }
217
+
157
218
  pub struct RbWordPieceDecoder {}
158
219
 
159
220
  impl RbWordPieceDecoder {
@@ -199,60 +260,94 @@ impl Decoder for RbDecoderWrapper {
199
260
  }
200
261
 
201
262
  unsafe impl TypedData for RbDecoder {
202
- fn class() -> RClass {
203
- *memoize!(RClass: {
204
- let class: RClass = crate::decoders().const_get("Decoder").unwrap();
205
- class.undef_alloc_func();
206
- class
207
- })
263
+ fn class(ruby: &Ruby) -> RClass {
264
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
265
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Decoder").unwrap();
266
+ class.undef_default_alloc_func();
267
+ class
268
+ });
269
+ ruby.get_inner(&CLASS)
208
270
  }
209
271
 
210
272
  fn data_type() -> &'static DataType {
211
- memoize!(DataType: DataTypeBuilder::<RbDecoder>::new("Tokenizers::Decoders::Decoder").build())
273
+ static DATA_TYPE: DataType = data_type_builder!(RbDecoder, "Tokenizers::Decoders::Decoder").build();
274
+ &DATA_TYPE
212
275
  }
213
276
 
214
- fn class_for(value: &Self) -> RClass {
277
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
278
+ static BPE_DECODER: Lazy<RClass> = Lazy::new(|ruby| {
279
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("BPEDecoder").unwrap();
280
+ class.undef_default_alloc_func();
281
+ class
282
+ });
283
+ static BYTE_FALLBACK: Lazy<RClass> = Lazy::new(|ruby| {
284
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("ByteFallback").unwrap();
285
+ class.undef_default_alloc_func();
286
+ class
287
+ });
288
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
289
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("ByteLevel").unwrap();
290
+ class.undef_default_alloc_func();
291
+ class
292
+ });
293
+ static CTC: Lazy<RClass> = Lazy::new(|ruby| {
294
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("CTC").unwrap();
295
+ class.undef_default_alloc_func();
296
+ class
297
+ });
298
+ static FUSE: Lazy<RClass> = Lazy::new(|ruby| {
299
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Fuse").unwrap();
300
+ class.undef_default_alloc_func();
301
+ class
302
+ });
303
+ static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
304
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Metaspace").unwrap();
305
+ class.undef_default_alloc_func();
306
+ class
307
+ });
308
+ static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
309
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Replace").unwrap();
310
+ class.undef_default_alloc_func();
311
+ class
312
+ });
313
+ static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
314
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("Strip").unwrap();
315
+ class.undef_default_alloc_func();
316
+ class
317
+ });
318
+ static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
319
+ let class: RClass = ruby.get_inner(&DECODERS).const_get("WordPiece").unwrap();
320
+ class.undef_default_alloc_func();
321
+ class
322
+ });
215
323
  match &value.decoder {
216
324
  RbDecoderWrapper::Wrapped(inner) => match *inner.read().unwrap() {
217
- DecoderWrapper::BPE(_) => *memoize!(RClass: {
218
- let class: RClass = crate::decoders().const_get("BPEDecoder").unwrap();
219
- class.undef_alloc_func();
220
- class
221
- }),
222
- DecoderWrapper::ByteLevel(_) => *memoize!(RClass: {
223
- let class: RClass = crate::decoders().const_get("ByteLevel").unwrap();
224
- class.undef_alloc_func();
225
- class
226
- }),
227
- DecoderWrapper::CTC(_) => *memoize!(RClass: {
228
- let class: RClass = crate::decoders().const_get("CTC").unwrap();
229
- class.undef_alloc_func();
230
- class
231
- }),
232
- DecoderWrapper::Metaspace(_) => *memoize!(RClass: {
233
- let class: RClass = crate::decoders().const_get("Metaspace").unwrap();
234
- class.undef_alloc_func();
235
- class
236
- }),
237
- DecoderWrapper::WordPiece(_) => *memoize!(RClass: {
238
- let class: RClass = crate::decoders().const_get("WordPiece").unwrap();
239
- class.undef_alloc_func();
240
- class
241
- }),
325
+ DecoderWrapper::BPE(_) => ruby.get_inner(&BPE_DECODER),
326
+ DecoderWrapper::ByteFallback(_) => ruby.get_inner(&BYTE_FALLBACK),
327
+ DecoderWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
328
+ DecoderWrapper::CTC(_) => ruby.get_inner(&CTC),
329
+ DecoderWrapper::Fuse(_) => ruby.get_inner(&FUSE),
330
+ DecoderWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
331
+ DecoderWrapper::Replace(_) => ruby.get_inner(&REPLACE),
332
+ DecoderWrapper::Strip(_) => ruby.get_inner(&STRIP),
333
+ DecoderWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
242
334
  _ => todo!(),
243
335
  },
244
336
  }
245
337
  }
246
338
  }
247
339
 
248
- pub fn decoders(module: &RModule) -> RbResult<()> {
249
- let decoder = module.define_class("Decoder", Default::default())?;
340
+ pub fn init_decoders(ruby: &Ruby, module: &RModule) -> RbResult<()> {
341
+ let decoder = module.define_class("Decoder", ruby.class_object())?;
250
342
 
251
343
  let class = module.define_class("BPEDecoder", decoder)?;
252
344
  class.define_singleton_method("_new", function!(RbBPEDecoder::new, 1))?;
253
345
  class.define_method("suffix", method!(RbDecoder::bpe_suffix, 0))?;
254
346
  class.define_method("suffix=", method!(RbDecoder::bpe_set_suffix, 1))?;
255
347
 
348
+ let class = module.define_class("ByteFallback", decoder)?;
349
+ class.define_singleton_method("new", function!(RbByteFallbackDecoder::new, 0))?;
350
+
256
351
  let class = module.define_class("ByteLevel", decoder)?;
257
352
  class.define_singleton_method("new", function!(RbByteLevelDecoder::new, 0))?;
258
353
 
@@ -265,6 +360,9 @@ pub fn decoders(module: &RModule) -> RbResult<()> {
265
360
  class.define_method("word_delimiter_token", method!(RbDecoder::ctc_word_delimiter_token, 0))?;
266
361
  class.define_method("word_delimiter_token=", method!(RbDecoder::ctc_set_word_delimiter_token, 1))?;
267
362
 
363
+ let class = module.define_class("Fuse", decoder)?;
364
+ class.define_singleton_method("new", function!(RbFuse::new, 0))?;
365
+
268
366
  let class = module.define_class("Metaspace", decoder)?;
269
367
  class.define_singleton_method("_new", function!(RbMetaspaceDecoder::new, 2))?;
270
368
  class.define_method("add_prefix_space", method!(RbDecoder::metaspace_add_prefix_space, 0))?;
@@ -272,6 +370,18 @@ pub fn decoders(module: &RModule) -> RbResult<()> {
272
370
  class.define_method("replacement", method!(RbDecoder::metaspace_replacement, 0))?;
273
371
  class.define_method("replacement=", method!(RbDecoder::metaspace_set_replacement, 1))?;
274
372
 
373
+ let class = module.define_class("Replace", decoder)?;
374
+ class.define_singleton_method("new", function!(RbReplaceDecoder::new, 2))?;
375
+
376
+ let class = module.define_class("Strip", decoder)?;
377
+ class.define_singleton_method("_new", function!(RbStripDecoder::new, 3))?;
378
+ class.define_method("content", method!(RbDecoder::strip_content, 0))?;
379
+ class.define_method("content=", method!(RbDecoder::strip_set_content, 1))?;
380
+ class.define_method("start", method!(RbDecoder::strip_start, 0))?;
381
+ class.define_method("start=", method!(RbDecoder::strip_set_start, 1))?;
382
+ class.define_method("stop", method!(RbDecoder::strip_stop, 0))?;
383
+ class.define_method("stop=", method!(RbDecoder::strip_set_stop, 1))?;
384
+
275
385
  let class = module.define_class("WordPiece", decoder)?;
276
386
  class.define_singleton_method("_new", function!(RbWordPieceDecoder::new, 2))?;
277
387
  class.define_method("cleanup", method!(RbDecoder::word_piece_cleanup, 0))?;
@@ -1,6 +1,6 @@
1
- use magnus::{memoize, Error, ExceptionClass, Module};
1
+ use magnus::{prelude::*, value::Lazy, Error, ExceptionClass, Ruby};
2
2
 
3
- use super::module;
3
+ use super::TOKENIZERS;
4
4
 
5
5
  pub struct RbError {}
6
6
 
@@ -11,6 +11,8 @@ impl RbError {
11
11
  }
12
12
  }
13
13
 
14
+ static ERROR: Lazy<ExceptionClass> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Error").unwrap());
15
+
14
16
  fn error() -> ExceptionClass {
15
- *memoize!(ExceptionClass: module().const_get("Error").unwrap())
17
+ Ruby::get().unwrap().get_inner(&ERROR)
16
18
  }
@@ -1,3 +1,5 @@
1
+ #![allow(clippy::new_ret_no_self)]
2
+
1
3
  extern crate tokenizers as tk;
2
4
 
3
5
  mod decoders;
@@ -16,43 +18,29 @@ use error::RbError;
16
18
  use tokenizer::RbTokenizer;
17
19
  use utils::RbRegex;
18
20
 
19
- use magnus::{define_module, function, memoize, method, prelude::*, Error, RModule};
21
+ use magnus::{function, method, prelude::*, value::Lazy, Error, RModule, Ruby};
20
22
 
21
23
  type RbResult<T> = Result<T, Error>;
22
24
 
23
- fn module() -> RModule {
24
- *memoize!(RModule: define_module("Tokenizers").unwrap())
25
- }
25
+ static TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.class_object().const_get("Tokenizers").unwrap());
26
26
 
27
- fn decoders() -> RModule {
28
- *memoize!(RModule: module().const_get("Decoders").unwrap())
29
- }
27
+ static DECODERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Decoders").unwrap());
30
28
 
31
- fn models() -> RModule {
32
- *memoize!(RModule: module().const_get("Models").unwrap())
33
- }
29
+ static MODELS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Models").unwrap());
34
30
 
35
- fn normalizers() -> RModule {
36
- *memoize!(RModule: module().const_get("Normalizers").unwrap())
37
- }
31
+ static NORMALIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Normalizers").unwrap());
38
32
 
39
- fn pre_tokenizers() -> RModule {
40
- *memoize!(RModule: module().const_get("PreTokenizers").unwrap())
41
- }
33
+ static PRE_TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("PreTokenizers").unwrap());
42
34
 
43
- fn processors() -> RModule {
44
- *memoize!(RModule: module().const_get("Processors").unwrap())
45
- }
35
+ static PROCESSORS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Processors").unwrap());
46
36
 
47
- fn trainers() -> RModule {
48
- *memoize!(RModule: module().const_get("Trainers").unwrap())
49
- }
37
+ static TRAINERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Trainers").unwrap());
50
38
 
51
39
  #[magnus::init]
52
- fn init() -> RbResult<()> {
53
- let module = module();
40
+ fn init(ruby: &Ruby) -> RbResult<()> {
41
+ let module = ruby.get_inner(&TOKENIZERS);
54
42
 
55
- let class = module.define_class("Tokenizer", Default::default())?;
43
+ let class = module.define_class("Tokenizer", ruby.class_object())?;
56
44
  class.define_singleton_method("new", function!(RbTokenizer::from_model, 1))?;
57
45
  class.define_singleton_method("from_file", function!(RbTokenizer::from_file, 1))?;
58
46
  class.define_method(
@@ -86,7 +74,7 @@ fn init() -> RbResult<()> {
86
74
  class.define_method("_vocab_size", method!(RbTokenizer::vocab_size, 1))?;
87
75
  class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;
88
76
 
89
- let class = module.define_class("Encoding", Default::default())?;
77
+ let class = module.define_class("Encoding", ruby.class_object())?;
90
78
  class.define_method("n_sequences", method!(RbEncoding::n_sequences, 0))?;
91
79
  class.define_method("ids", method!(RbEncoding::ids, 0))?;
92
80
  class.define_method("tokens", method!(RbEncoding::tokens, 0))?;
@@ -111,7 +99,7 @@ fn init() -> RbResult<()> {
111
99
  class.define_method("_char_to_token", method!(RbEncoding::char_to_token, 2))?;
112
100
  class.define_method("_char_to_word", method!(RbEncoding::char_to_word, 2))?;
113
101
 
114
- let class = module.define_class("Regex", Default::default())?;
102
+ let class = module.define_class("Regex", ruby.class_object())?;
115
103
  class.define_singleton_method("new", function!(RbRegex::new, 1))?;
116
104
 
117
105
  let models = module.define_module("Models")?;
@@ -121,12 +109,12 @@ fn init() -> RbResult<()> {
121
109
  let normalizers = module.define_module("Normalizers")?;
122
110
  let trainers = module.define_module("Trainers")?;
123
111
 
124
- models::models(&models)?;
125
- pre_tokenizers::pre_tokenizers(&pre_tokenizers)?;
126
- decoders::decoders(&decoders)?;
127
- processors::processors(&processors)?;
128
- normalizers::normalizers(&normalizers)?;
129
- trainers::trainers(&trainers)?;
112
+ models::init_models(ruby, &models)?;
113
+ pre_tokenizers::init_pre_tokenizers(ruby, &pre_tokenizers)?;
114
+ decoders::init_decoders(ruby, &decoders)?;
115
+ processors::init_processors(ruby, &processors)?;
116
+ normalizers::init_normalizers(ruby, &normalizers)?;
117
+ trainers::init_trainers(ruby, &trainers)?;
130
118
 
131
119
  Ok(())
132
120
  }
@@ -3,10 +3,10 @@ use std::path::{Path, PathBuf};
3
3
  use std::sync::{Arc, RwLock};
4
4
 
5
5
  use crate::trainers::RbTrainer;
6
- use magnus::typed_data::DataTypeBuilder;
6
+ use magnus::prelude::*;
7
7
  use magnus::{
8
- exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
- RClass, RHash, RModule, Symbol, TypedData, Value,
8
+ data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
+ RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
10
10
  };
11
11
  use serde::{Deserialize, Serialize};
12
12
  use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
@@ -16,7 +16,7 @@ use tk::models::wordlevel::WordLevel;
16
16
  use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
17
17
  use tk::{Model, Token};
18
18
 
19
- use super::{RbError, RbResult};
19
+ use super::{MODELS, RbError, RbResult};
20
20
 
21
21
  #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
22
22
  pub struct RbModel {
@@ -73,32 +73,37 @@ impl RbBPE {
73
73
  fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
74
74
  let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
75
75
  if !value.is_nil() {
76
- builder = builder.cache_capacity(value.try_convert()?);
76
+ builder = builder.cache_capacity(TryConvert::try_convert(value)?);
77
77
  }
78
78
 
79
79
  let value: Value = kwargs.delete(Symbol::new("dropout"))?;
80
80
  if !value.is_nil() {
81
- builder = builder.dropout(value.try_convert()?);
81
+ builder = builder.dropout(TryConvert::try_convert(value)?);
82
82
  }
83
83
 
84
84
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
85
85
  if !value.is_nil() {
86
- builder = builder.unk_token(value.try_convert()?);
86
+ builder = builder.unk_token(TryConvert::try_convert(value)?);
87
87
  }
88
88
 
89
89
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
90
90
  if !value.is_nil() {
91
- builder = builder.continuing_subword_prefix(value.try_convert()?);
91
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
92
92
  }
93
93
 
94
94
  let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
95
95
  if !value.is_nil() {
96
- builder = builder.end_of_word_suffix(value.try_convert()?);
96
+ builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
97
97
  }
98
98
 
99
99
  let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
100
100
  if !value.is_nil() {
101
- builder = builder.fuse_unk(value.try_convert()?);
101
+ builder = builder.fuse_unk(TryConvert::try_convert(value)?);
102
+ }
103
+
104
+ let value: Value = kwargs.delete(Symbol::new("byte_fallback"))?;
105
+ if !value.is_nil() {
106
+ builder = builder.byte_fallback(TryConvert::try_convert(value)?);
102
107
  }
103
108
 
104
109
  if !kwargs.is_empty() {
@@ -169,6 +174,14 @@ impl RbModel {
169
174
  setter!(self, BPE, fuse_unk, fuse_unk);
170
175
  }
171
176
 
177
+ pub fn bpe_byte_fallback(&self) -> bool {
178
+ getter!(self, BPE, byte_fallback)
179
+ }
180
+
181
+ pub fn bpe_set_byte_fallback(&self, byte_fallback: bool) {
182
+ setter!(self, BPE, byte_fallback, byte_fallback);
183
+ }
184
+
172
185
  pub fn bpe_continuing_subword_prefix(&self) -> Option<String> {
173
186
  getter!(self, BPE, continuing_subword_prefix.clone())
174
187
  }
@@ -221,13 +234,13 @@ impl RbModel {
221
234
  pub struct RbUnigram {}
222
235
 
223
236
  impl RbUnigram {
224
- fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
225
- match (vocab, unk_id) {
226
- (Some(vocab), unk_id) => {
227
- let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
237
+ fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>, byte_fallback: Option<bool>) -> RbResult<RbModel> {
238
+ match (vocab, unk_id, byte_fallback) {
239
+ (Some(vocab), unk_id, byte_fallback) => {
240
+ let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(RbError::from)?;
228
241
  Ok(model.into())
229
242
  }
230
- (None, None) => Ok(Unigram::default().into()),
243
+ (None, None, _) => Ok(Unigram::default().into()),
231
244
  _ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
232
245
  }
233
246
  }
@@ -264,17 +277,17 @@ impl RbWordPiece {
264
277
  fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
265
278
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
266
279
  if !value.is_nil() {
267
- builder = builder.unk_token(value.try_convert()?);
280
+ builder = builder.unk_token(TryConvert::try_convert(value)?);
268
281
  }
269
282
 
270
283
  let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
271
284
  if !value.is_nil() {
272
- builder = builder.max_input_chars_per_word(value.try_convert()?);
285
+ builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
273
286
  }
274
287
 
275
288
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
276
289
  if !value.is_nil() {
277
- builder = builder.continuing_subword_prefix(value.try_convert()?);
290
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
278
291
  }
279
292
 
280
293
  if !kwargs.is_empty() {
@@ -301,46 +314,52 @@ impl RbWordPiece {
301
314
  }
302
315
 
303
316
  unsafe impl TypedData for RbModel {
304
- fn class() -> RClass {
305
- *memoize!(RClass: {
306
- let class: RClass = crate::models().const_get("Model").unwrap();
307
- class.undef_alloc_func();
317
+ fn class(ruby: &Ruby) -> RClass {
318
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
319
+ let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
320
+ class.undef_default_alloc_func();
308
321
  class
309
- })
322
+ });
323
+ ruby.get_inner(&CLASS)
310
324
  }
311
325
 
312
326
  fn data_type() -> &'static DataType {
313
- memoize!(DataType: DataTypeBuilder::<RbModel>::new("Tokenizers::Models::Model").build())
314
- }
315
-
316
- fn class_for(value: &Self) -> RClass {
327
+ static DATA_TYPE: DataType = data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
328
+ &DATA_TYPE
329
+ }
330
+
331
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
332
+ static BPE: Lazy<RClass> = Lazy::new(|ruby| {
333
+ let class: RClass = ruby.get_inner(&MODELS).const_get("BPE").unwrap();
334
+ class.undef_default_alloc_func();
335
+ class
336
+ });
337
+ static UNIGRAM: Lazy<RClass> = Lazy::new(|ruby| {
338
+ let class: RClass = ruby.get_inner(&MODELS).const_get("Unigram").unwrap();
339
+ class.undef_default_alloc_func();
340
+ class
341
+ });
342
+ static WORD_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
343
+ let class: RClass = ruby.get_inner(&MODELS).const_get("WordLevel").unwrap();
344
+ class.undef_default_alloc_func();
345
+ class
346
+ });
347
+ static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
348
+ let class: RClass = ruby.get_inner(&MODELS).const_get("WordPiece").unwrap();
349
+ class.undef_default_alloc_func();
350
+ class
351
+ });
317
352
  match *value.model.read().unwrap() {
318
- ModelWrapper::BPE(_) => *memoize!(RClass: {
319
- let class: RClass = crate::models().const_get("BPE").unwrap();
320
- class.undef_alloc_func();
321
- class
322
- }),
323
- ModelWrapper::Unigram(_) => *memoize!(RClass: {
324
- let class: RClass = crate::models().const_get("Unigram").unwrap();
325
- class.undef_alloc_func();
326
- class
327
- }),
328
- ModelWrapper::WordLevel(_) => *memoize!(RClass: {
329
- let class: RClass = crate::models().const_get("WordLevel").unwrap();
330
- class.undef_alloc_func();
331
- class
332
- }),
333
- ModelWrapper::WordPiece(_) => *memoize!(RClass: {
334
- let class: RClass = crate::models().const_get("WordPiece").unwrap();
335
- class.undef_alloc_func();
336
- class
337
- }),
353
+ ModelWrapper::BPE(_) => ruby.get_inner(&BPE),
354
+ ModelWrapper::Unigram(_) => ruby.get_inner(&UNIGRAM),
355
+ ModelWrapper::WordLevel(_) => ruby.get_inner(&WORD_LEVEL),
356
+ ModelWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
338
357
  }
339
358
  }
340
359
  }
341
360
 
342
- pub fn models(module: &RModule) -> RbResult<()> {
343
- let model = module.define_class("Model", Default::default())?;
361
+ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
362
+ let model = module.define_class("Model", ruby.class_object())?;
344
363
 
345
364
  let class = module.define_class("BPE", model)?;
346
365
  class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
@@ -355,9 +374,11 @@ pub fn models(module: &RModule) -> RbResult<()> {
355
374
  class.define_method("end_of_word_suffix=", method!(RbModel::bpe_set_end_of_word_suffix, 1))?;
356
375
  class.define_method("fuse_unk", method!(RbModel::bpe_fuse_unk, 0))?;
357
376
  class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
377
+ class.define_method("byte_fallback", method!(RbModel::bpe_byte_fallback, 0))?;
378
+ class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
358
379
 
359
380
  let class = module.define_class("Unigram", model)?;
360
- class.define_singleton_method("_new", function!(RbUnigram::new, 2))?;
381
+ class.define_singleton_method("_new", function!(RbUnigram::new, 3))?;
361
382
 
362
383
  let class = module.define_class("WordLevel", model)?;
363
384
  class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;