tokenizers 0.3.2 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/Cargo.lock +160 -96
- data/ext/tokenizers/Cargo.toml +6 -6
- data/ext/tokenizers/src/decoders.rs +149 -39
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +71 -50
- data/ext/tokenizers/src/normalizers.rs +113 -74
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/decoders/strip.rb +9 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/normalizers/prepend.rb +9 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +4 -2
- metadata +6 -4
@@ -1,20 +1,25 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::
|
3
|
+
use magnus::value::Lazy;
|
4
4
|
use magnus::{
|
5
|
-
|
6
|
-
TypedData,
|
5
|
+
data_type_builder, function, method, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
|
6
|
+
Ruby, TypedData,
|
7
7
|
};
|
8
8
|
use serde::{Deserialize, Serialize};
|
9
9
|
use tk::decoders::bpe::BPEDecoder;
|
10
|
+
use tk::decoders::byte_fallback::ByteFallback;
|
10
11
|
use tk::decoders::byte_level::ByteLevel;
|
11
12
|
use tk::decoders::ctc::CTC;
|
13
|
+
use tk::decoders::fuse::Fuse;
|
12
14
|
use tk::decoders::metaspace::Metaspace;
|
15
|
+
use tk::decoders::strip::Strip;
|
13
16
|
use tk::decoders::wordpiece::WordPiece;
|
14
17
|
use tk::decoders::DecoderWrapper;
|
15
18
|
use tk::Decoder;
|
19
|
+
use tk::normalizers::replace::Replace;
|
16
20
|
|
17
|
-
use super::
|
21
|
+
use super::utils::*;
|
22
|
+
use super::{DECODERS, RbError, RbResult};
|
18
23
|
|
19
24
|
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
20
25
|
pub struct RbDecoder {
|
@@ -89,6 +94,30 @@ impl RbDecoder {
|
|
89
94
|
setter!(self, CTC, word_delimiter_token, word_delimiter_token);
|
90
95
|
}
|
91
96
|
|
97
|
+
fn strip_content(&self) -> char {
|
98
|
+
getter!(self, Strip, content)
|
99
|
+
}
|
100
|
+
|
101
|
+
fn strip_set_content(&self, content: char) {
|
102
|
+
setter!(self, Strip, content, content)
|
103
|
+
}
|
104
|
+
|
105
|
+
fn strip_start(&self) -> usize {
|
106
|
+
getter!(self, Strip, start)
|
107
|
+
}
|
108
|
+
|
109
|
+
fn strip_set_start(&self, start: usize) {
|
110
|
+
setter!(self, Strip, start, start)
|
111
|
+
}
|
112
|
+
|
113
|
+
fn strip_stop(&self) -> usize {
|
114
|
+
getter!(self, Strip, stop)
|
115
|
+
}
|
116
|
+
|
117
|
+
fn strip_set_stop(&self, stop: usize) {
|
118
|
+
setter!(self, Strip, stop, stop)
|
119
|
+
}
|
120
|
+
|
92
121
|
pub fn metaspace_replacement(&self) -> char {
|
93
122
|
getter!(self, Metaspace, get_replacement().clone())
|
94
123
|
}
|
@@ -130,6 +159,14 @@ impl RbBPEDecoder {
|
|
130
159
|
}
|
131
160
|
}
|
132
161
|
|
162
|
+
pub struct RbByteFallbackDecoder {}
|
163
|
+
|
164
|
+
impl RbByteFallbackDecoder {
|
165
|
+
pub fn new() -> RbDecoder {
|
166
|
+
ByteFallback::default().into()
|
167
|
+
}
|
168
|
+
}
|
169
|
+
|
133
170
|
pub struct RbByteLevelDecoder {}
|
134
171
|
|
135
172
|
impl RbByteLevelDecoder {
|
@@ -146,6 +183,14 @@ impl RbCTC {
|
|
146
183
|
}
|
147
184
|
}
|
148
185
|
|
186
|
+
pub struct RbFuse {}
|
187
|
+
|
188
|
+
impl RbFuse {
|
189
|
+
pub fn new() -> RbDecoder {
|
190
|
+
Fuse::default().into()
|
191
|
+
}
|
192
|
+
}
|
193
|
+
|
149
194
|
pub struct RbMetaspaceDecoder {}
|
150
195
|
|
151
196
|
impl RbMetaspaceDecoder {
|
@@ -154,6 +199,22 @@ impl RbMetaspaceDecoder {
|
|
154
199
|
}
|
155
200
|
}
|
156
201
|
|
202
|
+
pub struct RbReplaceDecoder {}
|
203
|
+
|
204
|
+
impl RbReplaceDecoder {
|
205
|
+
pub fn new(pattern: RbPattern, content: String) -> RbResult<RbDecoder> {
|
206
|
+
Replace::new(pattern, content).map(|v| v.into()).map_err(RbError::from)
|
207
|
+
}
|
208
|
+
}
|
209
|
+
|
210
|
+
pub struct RbStripDecoder {}
|
211
|
+
|
212
|
+
impl RbStripDecoder {
|
213
|
+
pub fn new(content: char, start: usize, stop: usize) -> RbDecoder {
|
214
|
+
Strip::new(content, start, stop).into()
|
215
|
+
}
|
216
|
+
}
|
217
|
+
|
157
218
|
pub struct RbWordPieceDecoder {}
|
158
219
|
|
159
220
|
impl RbWordPieceDecoder {
|
@@ -199,60 +260,94 @@ impl Decoder for RbDecoderWrapper {
|
|
199
260
|
}
|
200
261
|
|
201
262
|
unsafe impl TypedData for RbDecoder {
|
202
|
-
fn class() -> RClass {
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
})
|
263
|
+
fn class(ruby: &Ruby) -> RClass {
|
264
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
265
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Decoder").unwrap();
|
266
|
+
class.undef_default_alloc_func();
|
267
|
+
class
|
268
|
+
});
|
269
|
+
ruby.get_inner(&CLASS)
|
208
270
|
}
|
209
271
|
|
210
272
|
fn data_type() -> &'static DataType {
|
211
|
-
|
273
|
+
static DATA_TYPE: DataType = data_type_builder!(RbDecoder, "Tokenizers::Decoders::Decoder").build();
|
274
|
+
&DATA_TYPE
|
212
275
|
}
|
213
276
|
|
214
|
-
fn class_for(value: &Self) -> RClass {
|
277
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
278
|
+
static BPE_DECODER: Lazy<RClass> = Lazy::new(|ruby| {
|
279
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("BPEDecoder").unwrap();
|
280
|
+
class.undef_default_alloc_func();
|
281
|
+
class
|
282
|
+
});
|
283
|
+
static BYTE_FALLBACK: Lazy<RClass> = Lazy::new(|ruby| {
|
284
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("ByteFallback").unwrap();
|
285
|
+
class.undef_default_alloc_func();
|
286
|
+
class
|
287
|
+
});
|
288
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
289
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("ByteLevel").unwrap();
|
290
|
+
class.undef_default_alloc_func();
|
291
|
+
class
|
292
|
+
});
|
293
|
+
static CTC: Lazy<RClass> = Lazy::new(|ruby| {
|
294
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("CTC").unwrap();
|
295
|
+
class.undef_default_alloc_func();
|
296
|
+
class
|
297
|
+
});
|
298
|
+
static FUSE: Lazy<RClass> = Lazy::new(|ruby| {
|
299
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Fuse").unwrap();
|
300
|
+
class.undef_default_alloc_func();
|
301
|
+
class
|
302
|
+
});
|
303
|
+
static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
304
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Metaspace").unwrap();
|
305
|
+
class.undef_default_alloc_func();
|
306
|
+
class
|
307
|
+
});
|
308
|
+
static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
|
309
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Replace").unwrap();
|
310
|
+
class.undef_default_alloc_func();
|
311
|
+
class
|
312
|
+
});
|
313
|
+
static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
|
314
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Strip").unwrap();
|
315
|
+
class.undef_default_alloc_func();
|
316
|
+
class
|
317
|
+
});
|
318
|
+
static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
|
319
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("WordPiece").unwrap();
|
320
|
+
class.undef_default_alloc_func();
|
321
|
+
class
|
322
|
+
});
|
215
323
|
match &value.decoder {
|
216
324
|
RbDecoderWrapper::Wrapped(inner) => match *inner.read().unwrap() {
|
217
|
-
DecoderWrapper::BPE(_) =>
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
DecoderWrapper::
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
}),
|
227
|
-
DecoderWrapper::CTC(_) => *memoize!(RClass: {
|
228
|
-
let class: RClass = crate::decoders().const_get("CTC").unwrap();
|
229
|
-
class.undef_alloc_func();
|
230
|
-
class
|
231
|
-
}),
|
232
|
-
DecoderWrapper::Metaspace(_) => *memoize!(RClass: {
|
233
|
-
let class: RClass = crate::decoders().const_get("Metaspace").unwrap();
|
234
|
-
class.undef_alloc_func();
|
235
|
-
class
|
236
|
-
}),
|
237
|
-
DecoderWrapper::WordPiece(_) => *memoize!(RClass: {
|
238
|
-
let class: RClass = crate::decoders().const_get("WordPiece").unwrap();
|
239
|
-
class.undef_alloc_func();
|
240
|
-
class
|
241
|
-
}),
|
325
|
+
DecoderWrapper::BPE(_) => ruby.get_inner(&BPE_DECODER),
|
326
|
+
DecoderWrapper::ByteFallback(_) => ruby.get_inner(&BYTE_FALLBACK),
|
327
|
+
DecoderWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
328
|
+
DecoderWrapper::CTC(_) => ruby.get_inner(&CTC),
|
329
|
+
DecoderWrapper::Fuse(_) => ruby.get_inner(&FUSE),
|
330
|
+
DecoderWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
|
331
|
+
DecoderWrapper::Replace(_) => ruby.get_inner(&REPLACE),
|
332
|
+
DecoderWrapper::Strip(_) => ruby.get_inner(&STRIP),
|
333
|
+
DecoderWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
|
242
334
|
_ => todo!(),
|
243
335
|
},
|
244
336
|
}
|
245
337
|
}
|
246
338
|
}
|
247
339
|
|
248
|
-
pub fn
|
249
|
-
let decoder = module.define_class("Decoder",
|
340
|
+
pub fn init_decoders(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
341
|
+
let decoder = module.define_class("Decoder", ruby.class_object())?;
|
250
342
|
|
251
343
|
let class = module.define_class("BPEDecoder", decoder)?;
|
252
344
|
class.define_singleton_method("_new", function!(RbBPEDecoder::new, 1))?;
|
253
345
|
class.define_method("suffix", method!(RbDecoder::bpe_suffix, 0))?;
|
254
346
|
class.define_method("suffix=", method!(RbDecoder::bpe_set_suffix, 1))?;
|
255
347
|
|
348
|
+
let class = module.define_class("ByteFallback", decoder)?;
|
349
|
+
class.define_singleton_method("new", function!(RbByteFallbackDecoder::new, 0))?;
|
350
|
+
|
256
351
|
let class = module.define_class("ByteLevel", decoder)?;
|
257
352
|
class.define_singleton_method("new", function!(RbByteLevelDecoder::new, 0))?;
|
258
353
|
|
@@ -265,6 +360,9 @@ pub fn decoders(module: &RModule) -> RbResult<()> {
|
|
265
360
|
class.define_method("word_delimiter_token", method!(RbDecoder::ctc_word_delimiter_token, 0))?;
|
266
361
|
class.define_method("word_delimiter_token=", method!(RbDecoder::ctc_set_word_delimiter_token, 1))?;
|
267
362
|
|
363
|
+
let class = module.define_class("Fuse", decoder)?;
|
364
|
+
class.define_singleton_method("new", function!(RbFuse::new, 0))?;
|
365
|
+
|
268
366
|
let class = module.define_class("Metaspace", decoder)?;
|
269
367
|
class.define_singleton_method("_new", function!(RbMetaspaceDecoder::new, 2))?;
|
270
368
|
class.define_method("add_prefix_space", method!(RbDecoder::metaspace_add_prefix_space, 0))?;
|
@@ -272,6 +370,18 @@ pub fn decoders(module: &RModule) -> RbResult<()> {
|
|
272
370
|
class.define_method("replacement", method!(RbDecoder::metaspace_replacement, 0))?;
|
273
371
|
class.define_method("replacement=", method!(RbDecoder::metaspace_set_replacement, 1))?;
|
274
372
|
|
373
|
+
let class = module.define_class("Replace", decoder)?;
|
374
|
+
class.define_singleton_method("new", function!(RbReplaceDecoder::new, 2))?;
|
375
|
+
|
376
|
+
let class = module.define_class("Strip", decoder)?;
|
377
|
+
class.define_singleton_method("_new", function!(RbStripDecoder::new, 3))?;
|
378
|
+
class.define_method("content", method!(RbDecoder::strip_content, 0))?;
|
379
|
+
class.define_method("content=", method!(RbDecoder::strip_set_content, 1))?;
|
380
|
+
class.define_method("start", method!(RbDecoder::strip_start, 0))?;
|
381
|
+
class.define_method("start=", method!(RbDecoder::strip_set_start, 1))?;
|
382
|
+
class.define_method("stop", method!(RbDecoder::strip_stop, 0))?;
|
383
|
+
class.define_method("stop=", method!(RbDecoder::strip_set_stop, 1))?;
|
384
|
+
|
275
385
|
let class = module.define_class("WordPiece", decoder)?;
|
276
386
|
class.define_singleton_method("_new", function!(RbWordPieceDecoder::new, 2))?;
|
277
387
|
class.define_method("cleanup", method!(RbDecoder::word_piece_cleanup, 0))?;
|
data/ext/tokenizers/src/error.rs
CHANGED
@@ -1,6 +1,6 @@
|
|
1
|
-
use magnus::{
|
1
|
+
use magnus::{prelude::*, value::Lazy, Error, ExceptionClass, Ruby};
|
2
2
|
|
3
|
-
use super::
|
3
|
+
use super::TOKENIZERS;
|
4
4
|
|
5
5
|
pub struct RbError {}
|
6
6
|
|
@@ -11,6 +11,8 @@ impl RbError {
|
|
11
11
|
}
|
12
12
|
}
|
13
13
|
|
14
|
+
static ERROR: Lazy<ExceptionClass> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Error").unwrap());
|
15
|
+
|
14
16
|
fn error() -> ExceptionClass {
|
15
|
-
|
17
|
+
Ruby::get().unwrap().get_inner(&ERROR)
|
16
18
|
}
|
data/ext/tokenizers/src/lib.rs
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
#![allow(clippy::new_ret_no_self)]
|
2
|
+
|
1
3
|
extern crate tokenizers as tk;
|
2
4
|
|
3
5
|
mod decoders;
|
@@ -16,43 +18,29 @@ use error::RbError;
|
|
16
18
|
use tokenizer::RbTokenizer;
|
17
19
|
use utils::RbRegex;
|
18
20
|
|
19
|
-
use magnus::{
|
21
|
+
use magnus::{function, method, prelude::*, value::Lazy, Error, RModule, Ruby};
|
20
22
|
|
21
23
|
type RbResult<T> = Result<T, Error>;
|
22
24
|
|
23
|
-
|
24
|
-
*memoize!(RModule: define_module("Tokenizers").unwrap())
|
25
|
-
}
|
25
|
+
static TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.class_object().const_get("Tokenizers").unwrap());
|
26
26
|
|
27
|
-
|
28
|
-
*memoize!(RModule: module().const_get("Decoders").unwrap())
|
29
|
-
}
|
27
|
+
static DECODERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Decoders").unwrap());
|
30
28
|
|
31
|
-
|
32
|
-
*memoize!(RModule: module().const_get("Models").unwrap())
|
33
|
-
}
|
29
|
+
static MODELS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Models").unwrap());
|
34
30
|
|
35
|
-
|
36
|
-
*memoize!(RModule: module().const_get("Normalizers").unwrap())
|
37
|
-
}
|
31
|
+
static NORMALIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Normalizers").unwrap());
|
38
32
|
|
39
|
-
|
40
|
-
*memoize!(RModule: module().const_get("PreTokenizers").unwrap())
|
41
|
-
}
|
33
|
+
static PRE_TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("PreTokenizers").unwrap());
|
42
34
|
|
43
|
-
|
44
|
-
*memoize!(RModule: module().const_get("Processors").unwrap())
|
45
|
-
}
|
35
|
+
static PROCESSORS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Processors").unwrap());
|
46
36
|
|
47
|
-
|
48
|
-
*memoize!(RModule: module().const_get("Trainers").unwrap())
|
49
|
-
}
|
37
|
+
static TRAINERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Trainers").unwrap());
|
50
38
|
|
51
39
|
#[magnus::init]
|
52
|
-
fn init() -> RbResult<()> {
|
53
|
-
let module =
|
40
|
+
fn init(ruby: &Ruby) -> RbResult<()> {
|
41
|
+
let module = ruby.get_inner(&TOKENIZERS);
|
54
42
|
|
55
|
-
let class = module.define_class("Tokenizer",
|
43
|
+
let class = module.define_class("Tokenizer", ruby.class_object())?;
|
56
44
|
class.define_singleton_method("new", function!(RbTokenizer::from_model, 1))?;
|
57
45
|
class.define_singleton_method("from_file", function!(RbTokenizer::from_file, 1))?;
|
58
46
|
class.define_method(
|
@@ -86,7 +74,7 @@ fn init() -> RbResult<()> {
|
|
86
74
|
class.define_method("_vocab_size", method!(RbTokenizer::vocab_size, 1))?;
|
87
75
|
class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;
|
88
76
|
|
89
|
-
let class = module.define_class("Encoding",
|
77
|
+
let class = module.define_class("Encoding", ruby.class_object())?;
|
90
78
|
class.define_method("n_sequences", method!(RbEncoding::n_sequences, 0))?;
|
91
79
|
class.define_method("ids", method!(RbEncoding::ids, 0))?;
|
92
80
|
class.define_method("tokens", method!(RbEncoding::tokens, 0))?;
|
@@ -111,7 +99,7 @@ fn init() -> RbResult<()> {
|
|
111
99
|
class.define_method("_char_to_token", method!(RbEncoding::char_to_token, 2))?;
|
112
100
|
class.define_method("_char_to_word", method!(RbEncoding::char_to_word, 2))?;
|
113
101
|
|
114
|
-
let class = module.define_class("Regex",
|
102
|
+
let class = module.define_class("Regex", ruby.class_object())?;
|
115
103
|
class.define_singleton_method("new", function!(RbRegex::new, 1))?;
|
116
104
|
|
117
105
|
let models = module.define_module("Models")?;
|
@@ -121,12 +109,12 @@ fn init() -> RbResult<()> {
|
|
121
109
|
let normalizers = module.define_module("Normalizers")?;
|
122
110
|
let trainers = module.define_module("Trainers")?;
|
123
111
|
|
124
|
-
models::
|
125
|
-
pre_tokenizers::
|
126
|
-
decoders::
|
127
|
-
processors::
|
128
|
-
normalizers::
|
129
|
-
trainers::
|
112
|
+
models::init_models(ruby, &models)?;
|
113
|
+
pre_tokenizers::init_pre_tokenizers(ruby, &pre_tokenizers)?;
|
114
|
+
decoders::init_decoders(ruby, &decoders)?;
|
115
|
+
processors::init_processors(ruby, &processors)?;
|
116
|
+
normalizers::init_normalizers(ruby, &normalizers)?;
|
117
|
+
trainers::init_trainers(ruby, &trainers)?;
|
130
118
|
|
131
119
|
Ok(())
|
132
120
|
}
|
@@ -3,10 +3,10 @@ use std::path::{Path, PathBuf};
|
|
3
3
|
use std::sync::{Arc, RwLock};
|
4
4
|
|
5
5
|
use crate::trainers::RbTrainer;
|
6
|
-
use magnus::
|
6
|
+
use magnus::prelude::*;
|
7
7
|
use magnus::{
|
8
|
-
exception, function,
|
9
|
-
RClass, RHash, RModule, Symbol, TypedData, Value,
|
8
|
+
data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
|
9
|
+
RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
|
10
10
|
};
|
11
11
|
use serde::{Deserialize, Serialize};
|
12
12
|
use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
|
@@ -16,7 +16,7 @@ use tk::models::wordlevel::WordLevel;
|
|
16
16
|
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
|
17
17
|
use tk::{Model, Token};
|
18
18
|
|
19
|
-
use super::{RbError, RbResult};
|
19
|
+
use super::{MODELS, RbError, RbResult};
|
20
20
|
|
21
21
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
22
22
|
pub struct RbModel {
|
@@ -73,32 +73,37 @@ impl RbBPE {
|
|
73
73
|
fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
74
74
|
let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
|
75
75
|
if !value.is_nil() {
|
76
|
-
builder = builder.cache_capacity(
|
76
|
+
builder = builder.cache_capacity(TryConvert::try_convert(value)?);
|
77
77
|
}
|
78
78
|
|
79
79
|
let value: Value = kwargs.delete(Symbol::new("dropout"))?;
|
80
80
|
if !value.is_nil() {
|
81
|
-
builder = builder.dropout(
|
81
|
+
builder = builder.dropout(TryConvert::try_convert(value)?);
|
82
82
|
}
|
83
83
|
|
84
84
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
85
85
|
if !value.is_nil() {
|
86
|
-
builder = builder.unk_token(
|
86
|
+
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
87
87
|
}
|
88
88
|
|
89
89
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
90
90
|
if !value.is_nil() {
|
91
|
-
builder = builder.continuing_subword_prefix(
|
91
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
92
92
|
}
|
93
93
|
|
94
94
|
let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
|
95
95
|
if !value.is_nil() {
|
96
|
-
builder = builder.end_of_word_suffix(
|
96
|
+
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
97
97
|
}
|
98
98
|
|
99
99
|
let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
|
100
100
|
if !value.is_nil() {
|
101
|
-
builder = builder.fuse_unk(
|
101
|
+
builder = builder.fuse_unk(TryConvert::try_convert(value)?);
|
102
|
+
}
|
103
|
+
|
104
|
+
let value: Value = kwargs.delete(Symbol::new("byte_fallback"))?;
|
105
|
+
if !value.is_nil() {
|
106
|
+
builder = builder.byte_fallback(TryConvert::try_convert(value)?);
|
102
107
|
}
|
103
108
|
|
104
109
|
if !kwargs.is_empty() {
|
@@ -169,6 +174,14 @@ impl RbModel {
|
|
169
174
|
setter!(self, BPE, fuse_unk, fuse_unk);
|
170
175
|
}
|
171
176
|
|
177
|
+
pub fn bpe_byte_fallback(&self) -> bool {
|
178
|
+
getter!(self, BPE, byte_fallback)
|
179
|
+
}
|
180
|
+
|
181
|
+
pub fn bpe_set_byte_fallback(&self, byte_fallback: bool) {
|
182
|
+
setter!(self, BPE, byte_fallback, byte_fallback);
|
183
|
+
}
|
184
|
+
|
172
185
|
pub fn bpe_continuing_subword_prefix(&self) -> Option<String> {
|
173
186
|
getter!(self, BPE, continuing_subword_prefix.clone())
|
174
187
|
}
|
@@ -221,13 +234,13 @@ impl RbModel {
|
|
221
234
|
pub struct RbUnigram {}
|
222
235
|
|
223
236
|
impl RbUnigram {
|
224
|
-
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
|
225
|
-
match (vocab, unk_id) {
|
226
|
-
(Some(vocab), unk_id) => {
|
227
|
-
let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
|
237
|
+
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>, byte_fallback: Option<bool>) -> RbResult<RbModel> {
|
238
|
+
match (vocab, unk_id, byte_fallback) {
|
239
|
+
(Some(vocab), unk_id, byte_fallback) => {
|
240
|
+
let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(RbError::from)?;
|
228
241
|
Ok(model.into())
|
229
242
|
}
|
230
|
-
(None, None) => Ok(Unigram::default().into()),
|
243
|
+
(None, None, _) => Ok(Unigram::default().into()),
|
231
244
|
_ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
|
232
245
|
}
|
233
246
|
}
|
@@ -264,17 +277,17 @@ impl RbWordPiece {
|
|
264
277
|
fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
265
278
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
266
279
|
if !value.is_nil() {
|
267
|
-
builder = builder.unk_token(
|
280
|
+
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
268
281
|
}
|
269
282
|
|
270
283
|
let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
|
271
284
|
if !value.is_nil() {
|
272
|
-
builder = builder.max_input_chars_per_word(
|
285
|
+
builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
|
273
286
|
}
|
274
287
|
|
275
288
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
276
289
|
if !value.is_nil() {
|
277
|
-
builder = builder.continuing_subword_prefix(
|
290
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
278
291
|
}
|
279
292
|
|
280
293
|
if !kwargs.is_empty() {
|
@@ -301,46 +314,52 @@ impl RbWordPiece {
|
|
301
314
|
}
|
302
315
|
|
303
316
|
unsafe impl TypedData for RbModel {
|
304
|
-
fn class() -> RClass {
|
305
|
-
|
306
|
-
let class: RClass =
|
307
|
-
class.
|
317
|
+
fn class(ruby: &Ruby) -> RClass {
|
318
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
319
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
|
320
|
+
class.undef_default_alloc_func();
|
308
321
|
class
|
309
|
-
})
|
322
|
+
});
|
323
|
+
ruby.get_inner(&CLASS)
|
310
324
|
}
|
311
325
|
|
312
326
|
fn data_type() -> &'static DataType {
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
327
|
+
static DATA_TYPE: DataType = data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
|
328
|
+
&DATA_TYPE
|
329
|
+
}
|
330
|
+
|
331
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
332
|
+
static BPE: Lazy<RClass> = Lazy::new(|ruby| {
|
333
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("BPE").unwrap();
|
334
|
+
class.undef_default_alloc_func();
|
335
|
+
class
|
336
|
+
});
|
337
|
+
static UNIGRAM: Lazy<RClass> = Lazy::new(|ruby| {
|
338
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("Unigram").unwrap();
|
339
|
+
class.undef_default_alloc_func();
|
340
|
+
class
|
341
|
+
});
|
342
|
+
static WORD_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
343
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("WordLevel").unwrap();
|
344
|
+
class.undef_default_alloc_func();
|
345
|
+
class
|
346
|
+
});
|
347
|
+
static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
|
348
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("WordPiece").unwrap();
|
349
|
+
class.undef_default_alloc_func();
|
350
|
+
class
|
351
|
+
});
|
317
352
|
match *value.model.read().unwrap() {
|
318
|
-
ModelWrapper::BPE(_) =>
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
}),
|
323
|
-
ModelWrapper::Unigram(_) => *memoize!(RClass: {
|
324
|
-
let class: RClass = crate::models().const_get("Unigram").unwrap();
|
325
|
-
class.undef_alloc_func();
|
326
|
-
class
|
327
|
-
}),
|
328
|
-
ModelWrapper::WordLevel(_) => *memoize!(RClass: {
|
329
|
-
let class: RClass = crate::models().const_get("WordLevel").unwrap();
|
330
|
-
class.undef_alloc_func();
|
331
|
-
class
|
332
|
-
}),
|
333
|
-
ModelWrapper::WordPiece(_) => *memoize!(RClass: {
|
334
|
-
let class: RClass = crate::models().const_get("WordPiece").unwrap();
|
335
|
-
class.undef_alloc_func();
|
336
|
-
class
|
337
|
-
}),
|
353
|
+
ModelWrapper::BPE(_) => ruby.get_inner(&BPE),
|
354
|
+
ModelWrapper::Unigram(_) => ruby.get_inner(&UNIGRAM),
|
355
|
+
ModelWrapper::WordLevel(_) => ruby.get_inner(&WORD_LEVEL),
|
356
|
+
ModelWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
|
338
357
|
}
|
339
358
|
}
|
340
359
|
}
|
341
360
|
|
342
|
-
pub fn
|
343
|
-
let model = module.define_class("Model",
|
361
|
+
pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
362
|
+
let model = module.define_class("Model", ruby.class_object())?;
|
344
363
|
|
345
364
|
let class = module.define_class("BPE", model)?;
|
346
365
|
class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
|
@@ -355,9 +374,11 @@ pub fn models(module: &RModule) -> RbResult<()> {
|
|
355
374
|
class.define_method("end_of_word_suffix=", method!(RbModel::bpe_set_end_of_word_suffix, 1))?;
|
356
375
|
class.define_method("fuse_unk", method!(RbModel::bpe_fuse_unk, 0))?;
|
357
376
|
class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
|
377
|
+
class.define_method("byte_fallback", method!(RbModel::bpe_byte_fallback, 0))?;
|
378
|
+
class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
|
358
379
|
|
359
380
|
let class = module.define_class("Unigram", model)?;
|
360
|
-
class.define_singleton_method("_new", function!(RbUnigram::new,
|
381
|
+
class.define_singleton_method("_new", function!(RbUnigram::new, 3))?;
|
361
382
|
|
362
383
|
let class = module.define_class("WordLevel", model)?;
|
363
384
|
class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
|