tokenizers 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/Cargo.lock +33 -74
  4. data/README.md +4 -0
  5. data/ext/tokenizers/Cargo.toml +4 -2
  6. data/ext/tokenizers/src/decoders.rs +275 -6
  7. data/ext/tokenizers/src/encoding.rs +78 -3
  8. data/ext/tokenizers/src/error.rs +2 -2
  9. data/ext/tokenizers/src/lib.rs +88 -17
  10. data/ext/tokenizers/src/models.rs +372 -11
  11. data/ext/tokenizers/src/normalizers.rs +435 -7
  12. data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
  13. data/ext/tokenizers/src/processors.rs +210 -0
  14. data/ext/tokenizers/src/tokenizer.rs +448 -20
  15. data/ext/tokenizers/src/trainers.rs +749 -0
  16. data/ext/tokenizers/src/utils/mod.rs +5 -0
  17. data/ext/tokenizers/src/utils/normalization.rs +85 -0
  18. data/ext/tokenizers/src/utils/regex.rs +22 -0
  19. data/lib/tokenizers/char_bpe_tokenizer.rb +11 -8
  20. data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
  21. data/lib/tokenizers/decoders/ctc.rb +9 -0
  22. data/lib/tokenizers/decoders/metaspace.rb +9 -0
  23. data/lib/tokenizers/decoders/word_piece.rb +9 -0
  24. data/lib/tokenizers/encoding.rb +19 -0
  25. data/lib/tokenizers/from_pretrained.rb +1 -1
  26. data/lib/tokenizers/models/bpe.rb +9 -0
  27. data/lib/tokenizers/models/unigram.rb +9 -0
  28. data/lib/tokenizers/models/word_level.rb +13 -0
  29. data/lib/tokenizers/models/word_piece.rb +9 -0
  30. data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
  31. data/lib/tokenizers/normalizers/strip.rb +9 -0
  32. data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
  33. data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
  34. data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
  35. data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
  36. data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
  37. data/lib/tokenizers/processors/byte_level.rb +9 -0
  38. data/lib/tokenizers/processors/roberta_processing.rb +9 -0
  39. data/lib/tokenizers/processors/template_processing.rb +9 -0
  40. data/lib/tokenizers/tokenizer.rb +45 -0
  41. data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
  42. data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
  43. data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
  44. data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
  45. data/lib/tokenizers/version.rb +1 -1
  46. data/lib/tokenizers.rb +49 -7
  47. metadata +32 -3
data/ext/tokenizers/src/error.rs

@@ -1,4 +1,4 @@
-use magnus::{exception, memoize, Error, ExceptionClass, Module};
+use magnus::{memoize, Error, ExceptionClass, Module};
 
 use super::module;
 
@@ -12,5 +12,5 @@ impl RbError {
     }
 
     fn error() -> ExceptionClass {
-        *memoize!(ExceptionClass: module().define_error("Error", exception::standard_error()).unwrap())
+        *memoize!(ExceptionClass: module().const_get("Error").unwrap())
     }
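
The extension no longer defines the Error class itself; it now resolves a constant that the Ruby side is expected to define before the native library loads. A minimal sketch of the Ruby-side definition this const_get lookup assumes (the actual definition lives in data/lib/tokenizers.rb, which is not shown in this excerpt):

    # Sketch only: the Rust code above resolves Tokenizers::Error via
    # const_get, so a definition along these lines must exist in Ruby.
    module Tokenizers
      class Error < StandardError
      end
    end

One plausible motivation is keeping the exception hierarchy defined in a single place, on the Ruby side, rather than duplicating it in the extension.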
data/ext/tokenizers/src/lib.rs

@@ -6,15 +6,15 @@ mod error;
 mod models;
 mod normalizers;
 mod pre_tokenizers;
+mod processors;
 mod tokenizer;
+mod trainers;
+mod utils;
 
-use decoders::RbBPEDecoder;
 use encoding::RbEncoding;
 use error::RbError;
-use models::RbBPE;
-use normalizers::RbBertNormalizer;
-use pre_tokenizers::RbBertPreTokenizer;
 use tokenizer::RbTokenizer;
+use utils::RbRegex;
 
 use magnus::{define_module, function, memoize, method, prelude::*, Error, RModule};
 
@@ -24,38 +24,109 @@ fn module() -> RModule {
     *memoize!(RModule: define_module("Tokenizers").unwrap())
 }
 
+fn decoders() -> RModule {
+    *memoize!(RModule: module().const_get("Decoders").unwrap())
+}
+
+fn models() -> RModule {
+    *memoize!(RModule: module().const_get("Models").unwrap())
+}
+
+fn normalizers() -> RModule {
+    *memoize!(RModule: module().const_get("Normalizers").unwrap())
+}
+
+fn pre_tokenizers() -> RModule {
+    *memoize!(RModule: module().const_get("PreTokenizers").unwrap())
+}
+
+fn processors() -> RModule {
+    *memoize!(RModule: module().const_get("Processors").unwrap())
+}
+
+fn trainers() -> RModule {
+    *memoize!(RModule: module().const_get("Trainers").unwrap())
+}
+
 #[magnus::init]
 fn init() -> RbResult<()> {
     let module = module();
-    module.define_singleton_method("from_file", function!(RbTokenizer::from_file, 1))?;
-
-    let class = module.define_class("BPE", Default::default())?;
-    class.define_singleton_method("new", function!(RbBPE::new, 2))?;
 
     let class = module.define_class("Tokenizer", Default::default())?;
-    class.define_singleton_method("new", function!(RbTokenizer::new, 1))?;
+    class.define_singleton_method("new", function!(RbTokenizer::from_model, 1))?;
+    class.define_singleton_method("from_file", function!(RbTokenizer::from_file, 1))?;
     class.define_method(
         "add_special_tokens",
         method!(RbTokenizer::add_special_tokens, 1),
     )?;
-    class.define_method("encode", method!(RbTokenizer::encode, 1))?;
-    class.define_method("decode", method!(RbTokenizer::decode, 1))?;
+    class.define_method("train", method!(RbTokenizer::train, 2))?;
+    class.define_method("_save", method!(RbTokenizer::save, 2))?;
+    class.define_method("add_tokens", method!(RbTokenizer::add_tokens, 1))?;
+    class.define_method("_encode", method!(RbTokenizer::encode, 4))?;
+    class.define_method("_encode_batch", method!(RbTokenizer::encode_batch, 3))?;
+    class.define_method("_decode", method!(RbTokenizer::decode, 2))?;
+    class.define_method("_decode_batch", method!(RbTokenizer::decode_batch, 2))?;
     class.define_method("decoder=", method!(RbTokenizer::set_decoder, 1))?;
     class.define_method("pre_tokenizer=", method!(RbTokenizer::set_pre_tokenizer, 1))?;
+    class.define_method(
+        "post_processor=",
+        method!(RbTokenizer::set_post_processor, 1),
+    )?;
     class.define_method("normalizer=", method!(RbTokenizer::set_normalizer, 1))?;
+    class.define_method("token_to_id", method!(RbTokenizer::token_to_id, 1))?;
+    class.define_method("id_to_token", method!(RbTokenizer::id_to_token, 1))?;
+    class.define_method("_enable_padding", method!(RbTokenizer::enable_padding, 1))?;
+    class.define_method("padding", method!(RbTokenizer::padding, 0))?;
+    class.define_method("no_padding", method!(RbTokenizer::no_padding, 0))?;
+    class.define_method("_enable_truncation", method!(RbTokenizer::enable_truncation, 2))?;
+    class.define_method("truncation", method!(RbTokenizer::truncation, 0))?;
+    class.define_method("no_truncation", method!(RbTokenizer::no_truncation, 0))?;
+    class.define_method("num_special_tokens_to_add", method!(RbTokenizer::num_special_tokens_to_add, 1))?;
+    class.define_method("_vocab", method!(RbTokenizer::vocab, 1))?;
+    class.define_method("_vocab_size", method!(RbTokenizer::vocab_size, 1))?;
+    class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;
 
     let class = module.define_class("Encoding", Default::default())?;
+    class.define_method("n_sequences", method!(RbEncoding::n_sequences, 0))?;
     class.define_method("ids", method!(RbEncoding::ids, 0))?;
     class.define_method("tokens", method!(RbEncoding::tokens, 0))?;
+    class.define_method("word_ids", method!(RbEncoding::word_ids, 0))?;
+    class.define_method("sequence_ids", method!(RbEncoding::sequence_ids, 0))?;
+    class.define_method("type_ids", method!(RbEncoding::type_ids, 0))?;
+    class.define_method("offsets", method!(RbEncoding::offsets, 0))?;
+    class.define_method(
+        "special_tokens_mask",
+        method!(RbEncoding::special_tokens_mask, 0),
+    )?;
+    class.define_method("attention_mask", method!(RbEncoding::attention_mask, 0))?;
+    class.define_method("overflowing", method!(RbEncoding::overflowing, 0))?;
+    class.define_method("_word_to_tokens", method!(RbEncoding::word_to_tokens, 2))?;
+    class.define_method("_word_to_chars", method!(RbEncoding::word_to_chars, 2))?;
+    class.define_method(
+        "token_to_sequence",
+        method!(RbEncoding::token_to_sequence, 1),
+    )?;
+    class.define_method("token_to_chars", method!(RbEncoding::token_to_chars, 1))?;
+    class.define_method("token_to_word", method!(RbEncoding::token_to_word, 1))?;
+    class.define_method("_char_to_token", method!(RbEncoding::char_to_token, 2))?;
+    class.define_method("_char_to_word", method!(RbEncoding::char_to_word, 2))?;
 
-    let class = module.define_class("BPEDecoder", Default::default())?;
-    class.define_singleton_method("new", function!(RbBPEDecoder::new, 0))?;
+    let class = module.define_class("Regex", Default::default())?;
+    class.define_singleton_method("new", function!(RbRegex::new, 1))?;
 
-    let class = module.define_class("BertPreTokenizer", Default::default())?;
-    class.define_singleton_method("new", function!(RbBertPreTokenizer::new, 0))?;
+    let models = module.define_module("Models")?;
+    let pre_tokenizers = module.define_module("PreTokenizers")?;
+    let decoders = module.define_module("Decoders")?;
+    let processors = module.define_module("Processors")?;
+    let normalizers = module.define_module("Normalizers")?;
+    let trainers = module.define_module("Trainers")?;
 
-    let class = module.define_class("BertNormalizer", Default::default())?;
-    class.define_singleton_method("new", function!(RbBertNormalizer::new, 0))?;
+    models::models(&models)?;
+    pre_tokenizers::pre_tokenizers(&pre_tokenizers)?;
+    decoders::decoders(&decoders)?;
+    processors::processors(&processors)?;
+    normalizers::normalizers(&normalizers)?;
+    trainers::trainers(&trainers)?;
 
     Ok(())
 }
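
Many of the newly registered Tokenizer and Encoding methods carry an underscore prefix (_encode, _decode, _save, _enable_padding, ...), the usual convention for private native entry points that Ruby-level methods wrap with friendlier signatures; the file list above shows the corresponding wrappers landing in data/lib/tokenizers/tokenizer.rb and data/lib/tokenizers/encoding.rb. A hedged sketch of that wrapper pattern (the argument names and defaults here are illustrative assumptions, not taken from this diff):

    # Hypothetical Ruby wrapper around the 4-argument native _encode
    # registered above; parameter names are illustrative only.
    module Tokenizers
      class Tokenizer
        def encode(sequence, pair = nil, is_pretokenized: false, add_special_tokens: true)
          # Forward to the private native method with fixed positional arity.
          _encode(sequence, pair, is_pretokenized, add_special_tokens)
        end
      end
    end

This split keeps the magnus bindings simple (fixed arity, no keyword parsing in Rust) while the Ruby layer provides defaults and keyword arguments.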
data/ext/tokenizers/src/models.rs

@@ -1,19 +1,380 @@
-use tk::models::bpe::BPE;
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::sync::{Arc, RwLock};
+
+use crate::trainers::RbTrainer;
+use magnus::typed_data::DataTypeBuilder;
+use magnus::{
+    exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
+    RClass, RHash, RModule, Symbol, TypedData, Value,
+};
+use serde::{Deserialize, Serialize};
+use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
+use tk::models::ModelWrapper;
+use tk::models::unigram::Unigram;
+use tk::models::wordlevel::WordLevel;
+use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
+use tk::{Model, Token};
 
 use super::{RbError, RbResult};
 
-#[magnus::wrap(class = "Tokenizers::BPE")]
-pub struct RbBPE {
-    pub model: BPE,
+#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
+pub struct RbModel {
+    #[serde(flatten)]
+    pub model: Arc<RwLock<ModelWrapper>>,
 }
 
+impl Model for RbModel {
+    type Trainer = RbTrainer;
+
+    fn tokenize(&self, tokens: &str) -> tk::Result<Vec<Token>> {
+        self.model.read().unwrap().tokenize(tokens)
+    }
+
+    fn token_to_id(&self, token: &str) -> Option<u32> {
+        self.model.read().unwrap().token_to_id(token)
+    }
+
+    fn id_to_token(&self, id: u32) -> Option<String> {
+        self.model.read().unwrap().id_to_token(id)
+    }
+
+    fn get_vocab(&self) -> HashMap<String, u32> {
+        self.model.read().unwrap().get_vocab()
+    }
+
+    fn get_vocab_size(&self) -> usize {
+        self.model.read().unwrap().get_vocab_size()
+    }
+
+    fn save(&self, folder: &Path, name: Option<&str>) -> tk::Result<Vec<PathBuf>> {
+        self.model.read().unwrap().save(folder, name)
+    }
+
+    fn get_trainer(&self) -> Self::Trainer {
+        self.model.read().unwrap().get_trainer().into()
+    }
+}
+
+impl<I> From<I> for RbModel
+where
+    I: Into<ModelWrapper>,
+{
+    fn from(model: I) -> Self {
+        Self {
+            model: Arc::new(RwLock::new(model.into())),
+        }
+    }
+}
+
+pub struct RbBPE {}
+
 impl RbBPE {
-    pub fn new(vocab: String, merges: String) -> RbResult<Self> {
-        BPE::from_file(&vocab, &merges)
-            .unk_token("<unk>".into())
-            .end_of_word_suffix("</w>".into())
-            .build()
-            .map(|v| RbBPE { model: v })
-            .map_err(RbError::from)
+    fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
+        let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
+        if !value.is_nil() {
+            builder = builder.cache_capacity(value.try_convert()?);
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("dropout"))?;
+        if !value.is_nil() {
+            builder = builder.dropout(value.try_convert()?);
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
+        if !value.is_nil() {
+            builder = builder.unk_token(value.try_convert()?);
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
+        if !value.is_nil() {
+            builder = builder.continuing_subword_prefix(value.try_convert()?);
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
+        if !value.is_nil() {
+            builder = builder.end_of_word_suffix(value.try_convert()?);
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
+        if !value.is_nil() {
+            builder = builder.fuse_unk(value.try_convert()?);
+        }
+
+        if !kwargs.is_empty() {
+            // TODO improve message
+            return Err(Error::new(exception::arg_error(), "unknown keyword"));
+        }
+
+        builder.build().map(|v| v.into()).map_err(RbError::from)
+    }
+
+    pub fn new(vocab: Option<Vocab>, merges: Option<Merges>, kwargs: RHash) -> RbResult<RbModel> {
+        let mut builder = BPE::builder();
+        if let (Some(vocab), Some(merges)) = (vocab, merges) {
+            builder = builder.vocab_and_merges(vocab, merges);
+        }
+        RbBPE::with_builder(builder, kwargs)
+    }
+
+    pub fn from_file(vocab: String, merges: String, kwargs: RHash) -> RbResult<RbModel> {
+        let (vocab, merges) = BPE::read_file(&vocab, &merges).map_err(RbError::from)?;
+
+        RbBPE::new(Some(vocab), Some(merges), kwargs)
+    }
+}
+
+macro_rules! getter {
+    ($self: ident, $variant: ident, $($name: tt)+) => {{
+        let model = $self.model.write().unwrap();
+        if let ModelWrapper::$variant(ref mo) = *model {
+            mo.$($name)+
+        } else {
+            unreachable!()
+        }
+    }};
+}
+
+macro_rules! setter {
+    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
+        let mut model = $self.model.write().unwrap();
+        if let ModelWrapper::$variant(ref mut mo) = *model {
+            mo.$name = $value;
+        }
+    }};
+}
+
+impl RbModel {
+    pub fn bpe_dropout(&self) -> Option<f32> {
+        getter!(self, BPE, dropout)
+    }
+
+    pub fn bpe_set_dropout(&self, dropout: Option<f32>) {
+        setter!(self, BPE, dropout, dropout);
+    }
+
+    pub fn bpe_unk_token(&self) -> Option<String> {
+        getter!(self, BPE, unk_token.clone())
+    }
+
+    pub fn bpe_set_unk_token(&self, unk_token: Option<String>) {
+        setter!(self, BPE, unk_token, unk_token);
+    }
+
+    pub fn bpe_fuse_unk(&self) -> bool {
+        getter!(self, BPE, fuse_unk)
+    }
+
+    pub fn bpe_set_fuse_unk(&self, fuse_unk: bool) {
+        setter!(self, BPE, fuse_unk, fuse_unk);
+    }
+
+    pub fn bpe_continuing_subword_prefix(&self) -> Option<String> {
+        getter!(self, BPE, continuing_subword_prefix.clone())
+    }
+
+    pub fn bpe_set_continuing_subword_prefix(&self, continuing_subword_prefix: Option<String>) {
+        setter!(self, BPE, continuing_subword_prefix, continuing_subword_prefix);
+    }
+
+    pub fn bpe_end_of_word_suffix(&self) -> Option<String> {
+        getter!(self, BPE, end_of_word_suffix.clone())
+    }
+
+    pub fn bpe_set_end_of_word_suffix(&self, end_of_word_suffix: Option<String>) {
+        setter!(self, BPE, end_of_word_suffix, end_of_word_suffix);
+    }
+
+    pub fn word_level_unk_token(&self) -> String {
+        getter!(self, WordLevel, unk_token.clone())
+    }
+
+    pub fn word_level_set_unk_token(&self, unk_token: String) {
+        setter!(self, WordLevel, unk_token, unk_token);
+    }
+
+    pub fn word_piece_unk_token(&self) -> String {
+        getter!(self, WordPiece, unk_token.clone())
+    }
+
+    pub fn word_piece_set_unk_token(&self, unk_token: String) {
+        setter!(self, WordPiece, unk_token, unk_token);
+    }
+
+    pub fn word_piece_continuing_subword_prefix(&self) -> String {
+        getter!(self, WordPiece, continuing_subword_prefix.clone())
+    }
+
+    pub fn word_piece_set_continuing_subword_prefix(&self, continuing_subword_prefix: String) {
+        setter!(self, WordPiece, continuing_subword_prefix, continuing_subword_prefix);
+    }
+
+    pub fn word_piece_max_input_chars_per_word(&self) -> usize {
+        getter!(self, WordPiece, max_input_chars_per_word.clone())
     }
+
+    pub fn word_piece_set_max_input_chars_per_word(&self, max_input_chars_per_word: usize) {
+        setter!(self, WordPiece, max_input_chars_per_word, max_input_chars_per_word);
+    }
+}
+
+pub struct RbUnigram {}
+
+impl RbUnigram {
+    fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
+        match (vocab, unk_id) {
+            (Some(vocab), unk_id) => {
+                let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
+                Ok(model.into())
+            }
+            (None, None) => Ok(Unigram::default().into()),
+            _ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
+        }
+    }
+}
+
+pub struct RbWordLevel {}
+
+impl RbWordLevel {
+    pub fn new(vocab: Option<HashMap<String, u32>>, unk_token: Option<String>) -> RbResult<RbModel> {
+        let mut builder = WordLevel::builder();
+        if let Some(vocab) = vocab {
+            builder = builder.vocab(vocab);
+        }
+        if let Some(unk_token) = unk_token {
+            builder = builder.unk_token(unk_token);
+        }
+        builder.build().map(|v| v.into()).map_err(RbError::from)
+    }
+
+    pub fn read_file(vocab: String) -> RbResult<Vocab> {
+        WordLevel::read_file(&vocab).map_err(RbError::from)
+    }
+
+    pub fn from_file(vocab: String, unk_token: Option<String>) -> RbResult<RbModel> {
+        let vocab = WordLevel::read_file(&vocab).map_err(RbError::from)?;
+
+        RbWordLevel::new(Some(vocab), unk_token)
+    }
+}
+
+pub struct RbWordPiece {}
+
+impl RbWordPiece {
+    fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
+        let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
+        if !value.is_nil() {
+            builder = builder.unk_token(value.try_convert()?);
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
+        if !value.is_nil() {
+            builder = builder.max_input_chars_per_word(value.try_convert()?);
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
+        if !value.is_nil() {
+            builder = builder.continuing_subword_prefix(value.try_convert()?);
+        }
+
+        if !kwargs.is_empty() {
+            // TODO improve message
+            return Err(Error::new(exception::arg_error(), "unknown keyword"));
+        }
+
+        builder.build().map(|v| v.into()).map_err(RbError::from)
+    }
+
+    pub fn new(vocab: Option<HashMap<String, u32>>, kwargs: RHash) -> RbResult<RbModel> {
+        let mut builder = WordPiece::builder();
+        if let Some(vocab) = vocab {
+            builder = builder.vocab(vocab);
+        }
+        RbWordPiece::with_builder(builder, kwargs)
+    }
+
+    pub fn from_file(vocab: String, kwargs: RHash) -> RbResult<RbModel> {
+        let vocab = WordPiece::read_file(&vocab).map_err(RbError::from)?;
+
+        RbWordPiece::new(Some(vocab), kwargs)
+    }
+}
+
+unsafe impl TypedData for RbModel {
+    fn class() -> RClass {
+        *memoize!(RClass: {
+            let class: RClass = crate::models().const_get("Model").unwrap();
+            class.undef_alloc_func();
+            class
+        })
+    }
+
+    fn data_type() -> &'static DataType {
+        memoize!(DataType: DataTypeBuilder::<RbModel>::new("Tokenizers::Models::Model").build())
+    }
+
+    fn class_for(value: &Self) -> RClass {
+        match *value.model.read().unwrap() {
+            ModelWrapper::BPE(_) => *memoize!(RClass: {
+                let class: RClass = crate::models().const_get("BPE").unwrap();
+                class.undef_alloc_func();
+                class
+            }),
+            ModelWrapper::Unigram(_) => *memoize!(RClass: {
+                let class: RClass = crate::models().const_get("Unigram").unwrap();
+                class.undef_alloc_func();
+                class
+            }),
+            ModelWrapper::WordLevel(_) => *memoize!(RClass: {
+                let class: RClass = crate::models().const_get("WordLevel").unwrap();
+                class.undef_alloc_func();
+                class
+            }),
+            ModelWrapper::WordPiece(_) => *memoize!(RClass: {
+                let class: RClass = crate::models().const_get("WordPiece").unwrap();
+                class.undef_alloc_func();
+                class
+            }),
+        }
+    }
+}
+
+pub fn models(module: &RModule) -> RbResult<()> {
+    let model = module.define_class("Model", Default::default())?;
+
+    let class = module.define_class("BPE", model)?;
+    class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
+    class.define_singleton_method("_from_file", function!(RbBPE::from_file, 3))?;
+    class.define_method("dropout", method!(RbModel::bpe_dropout, 0))?;
+    class.define_method("dropout=", method!(RbModel::bpe_set_dropout, 1))?;
+    class.define_method("unk_token", method!(RbModel::bpe_unk_token, 0))?;
+    class.define_method("unk_token=", method!(RbModel::bpe_set_unk_token, 1))?;
+    class.define_method("continuing_subword_prefix", method!(RbModel::bpe_continuing_subword_prefix, 0))?;
+    class.define_method("continuing_subword_prefix=", method!(RbModel::bpe_set_continuing_subword_prefix, 1))?;
+    class.define_method("end_of_word_suffix", method!(RbModel::bpe_end_of_word_suffix, 0))?;
+    class.define_method("end_of_word_suffix=", method!(RbModel::bpe_set_end_of_word_suffix, 1))?;
+    class.define_method("fuse_unk", method!(RbModel::bpe_fuse_unk, 0))?;
+    class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
+
+    let class = module.define_class("Unigram", model)?;
+    class.define_singleton_method("_new", function!(RbUnigram::new, 2))?;
+
+    let class = module.define_class("WordLevel", model)?;
+    class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
+    class.define_singleton_method("_from_file", function!(RbWordLevel::from_file, 2))?;
+    class.define_singleton_method("read_file", function!(RbWordLevel::read_file, 1))?;
+    class.define_method("unk_token", method!(RbModel::word_level_unk_token, 0))?;
+    class.define_method("unk_token=", method!(RbModel::word_level_set_unk_token, 1))?;
+
+    let class = module.define_class("WordPiece", model)?;
+    class.define_singleton_method("_new", function!(RbWordPiece::new, 2))?;
+    class.define_singleton_method("_from_file", function!(RbWordPiece::from_file, 2))?;
+    class.define_method("unk_token", method!(RbModel::word_piece_unk_token, 0))?;
+    class.define_method("unk_token=", method!(RbModel::word_piece_set_unk_token, 1))?;
+    class.define_method("continuing_subword_prefix", method!(RbModel::word_piece_continuing_subword_prefix, 0))?;
+    class.define_method("continuing_subword_prefix=", method!(RbModel::word_piece_set_continuing_subword_prefix, 1))?;
+    class.define_method("max_input_chars_per_word", method!(RbModel::word_piece_max_input_chars_per_word, 0))?;
+    class.define_method("max_input_chars_per_word=", method!(RbModel::word_piece_set_max_input_chars_per_word, 1))?;
+
+    Ok(())
 }
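
Taken together, RbBPE::with_builder pops known keys out of a Ruby hash and raises ArgumentError on leftovers, so the model classes are constructed from Ruby with keyword arguments, and the getter/setter pairs registered in models() surface as plain attribute accessors. A hedged usage sketch (the option values are illustrative, and the assumption is that the Ruby-level from_file in data/lib/tokenizers/models/bpe.rb forwards its keywords to the native _from_file registered above):

    # Sketch only: constructing a BPE model with the new keyword options
    # and handing it to a Tokenizer (Tokenizer.new binds RbTokenizer::from_model).
    model = Tokenizers::Models::BPE.from_file(
      "vocab.json", "merges.txt",
      unk_token: "<unk>",
      end_of_word_suffix: "</w>"
    )
    model.dropout = 0.1   # maps to RbModel::bpe_set_dropout via the setter! macro
    tokenizer = Tokenizers::Tokenizer.new(model)

Note that TypedData::class_for means a wrapped model round-trips with its concrete Ruby class (Models::BPE, Models::Unigram, Models::WordLevel, or Models::WordPiece) rather than the abstract Models::Model.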