tokenizers 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/Cargo.lock +160 -96
- data/ext/tokenizers/Cargo.toml +6 -6
- data/ext/tokenizers/src/decoders.rs +149 -39
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +71 -50
- data/ext/tokenizers/src/normalizers.rs +113 -74
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/decoders/strip.rb +9 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/normalizers/prepend.rb +9 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +4 -2
- metadata +6 -4
@@ -1,20 +1,25 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::
|
3
|
+
use magnus::value::Lazy;
|
4
4
|
use magnus::{
|
5
|
-
|
6
|
-
TypedData,
|
5
|
+
data_type_builder, function, method, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
|
6
|
+
Ruby, TypedData,
|
7
7
|
};
|
8
8
|
use serde::{Deserialize, Serialize};
|
9
9
|
use tk::decoders::bpe::BPEDecoder;
|
10
|
+
use tk::decoders::byte_fallback::ByteFallback;
|
10
11
|
use tk::decoders::byte_level::ByteLevel;
|
11
12
|
use tk::decoders::ctc::CTC;
|
13
|
+
use tk::decoders::fuse::Fuse;
|
12
14
|
use tk::decoders::metaspace::Metaspace;
|
15
|
+
use tk::decoders::strip::Strip;
|
13
16
|
use tk::decoders::wordpiece::WordPiece;
|
14
17
|
use tk::decoders::DecoderWrapper;
|
15
18
|
use tk::Decoder;
|
19
|
+
use tk::normalizers::replace::Replace;
|
16
20
|
|
17
|
-
use super::
|
21
|
+
use super::utils::*;
|
22
|
+
use super::{DECODERS, RbError, RbResult};
|
18
23
|
|
19
24
|
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
20
25
|
pub struct RbDecoder {
|
@@ -89,6 +94,30 @@ impl RbDecoder {
|
|
89
94
|
setter!(self, CTC, word_delimiter_token, word_delimiter_token);
|
90
95
|
}
|
91
96
|
|
97
|
+
fn strip_content(&self) -> char {
|
98
|
+
getter!(self, Strip, content)
|
99
|
+
}
|
100
|
+
|
101
|
+
fn strip_set_content(&self, content: char) {
|
102
|
+
setter!(self, Strip, content, content)
|
103
|
+
}
|
104
|
+
|
105
|
+
fn strip_start(&self) -> usize {
|
106
|
+
getter!(self, Strip, start)
|
107
|
+
}
|
108
|
+
|
109
|
+
fn strip_set_start(&self, start: usize) {
|
110
|
+
setter!(self, Strip, start, start)
|
111
|
+
}
|
112
|
+
|
113
|
+
fn strip_stop(&self) -> usize {
|
114
|
+
getter!(self, Strip, stop)
|
115
|
+
}
|
116
|
+
|
117
|
+
fn strip_set_stop(&self, stop: usize) {
|
118
|
+
setter!(self, Strip, stop, stop)
|
119
|
+
}
|
120
|
+
|
92
121
|
pub fn metaspace_replacement(&self) -> char {
|
93
122
|
getter!(self, Metaspace, get_replacement().clone())
|
94
123
|
}
|
@@ -130,6 +159,14 @@ impl RbBPEDecoder {
|
|
130
159
|
}
|
131
160
|
}
|
132
161
|
|
162
|
+
pub struct RbByteFallbackDecoder {}
|
163
|
+
|
164
|
+
impl RbByteFallbackDecoder {
|
165
|
+
pub fn new() -> RbDecoder {
|
166
|
+
ByteFallback::default().into()
|
167
|
+
}
|
168
|
+
}
|
169
|
+
|
133
170
|
pub struct RbByteLevelDecoder {}
|
134
171
|
|
135
172
|
impl RbByteLevelDecoder {
|
@@ -146,6 +183,14 @@ impl RbCTC {
|
|
146
183
|
}
|
147
184
|
}
|
148
185
|
|
186
|
+
pub struct RbFuse {}
|
187
|
+
|
188
|
+
impl RbFuse {
|
189
|
+
pub fn new() -> RbDecoder {
|
190
|
+
Fuse::default().into()
|
191
|
+
}
|
192
|
+
}
|
193
|
+
|
149
194
|
pub struct RbMetaspaceDecoder {}
|
150
195
|
|
151
196
|
impl RbMetaspaceDecoder {
|
@@ -154,6 +199,22 @@ impl RbMetaspaceDecoder {
|
|
154
199
|
}
|
155
200
|
}
|
156
201
|
|
202
|
+
pub struct RbReplaceDecoder {}
|
203
|
+
|
204
|
+
impl RbReplaceDecoder {
|
205
|
+
pub fn new(pattern: RbPattern, content: String) -> RbResult<RbDecoder> {
|
206
|
+
Replace::new(pattern, content).map(|v| v.into()).map_err(RbError::from)
|
207
|
+
}
|
208
|
+
}
|
209
|
+
|
210
|
+
pub struct RbStripDecoder {}
|
211
|
+
|
212
|
+
impl RbStripDecoder {
|
213
|
+
pub fn new(content: char, start: usize, stop: usize) -> RbDecoder {
|
214
|
+
Strip::new(content, start, stop).into()
|
215
|
+
}
|
216
|
+
}
|
217
|
+
|
157
218
|
pub struct RbWordPieceDecoder {}
|
158
219
|
|
159
220
|
impl RbWordPieceDecoder {
|
@@ -199,60 +260,94 @@ impl Decoder for RbDecoderWrapper {
|
|
199
260
|
}
|
200
261
|
|
201
262
|
unsafe impl TypedData for RbDecoder {
|
202
|
-
fn class() -> RClass {
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
})
|
263
|
+
fn class(ruby: &Ruby) -> RClass {
|
264
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
265
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Decoder").unwrap();
|
266
|
+
class.undef_default_alloc_func();
|
267
|
+
class
|
268
|
+
});
|
269
|
+
ruby.get_inner(&CLASS)
|
208
270
|
}
|
209
271
|
|
210
272
|
fn data_type() -> &'static DataType {
|
211
|
-
|
273
|
+
static DATA_TYPE: DataType = data_type_builder!(RbDecoder, "Tokenizers::Decoders::Decoder").build();
|
274
|
+
&DATA_TYPE
|
212
275
|
}
|
213
276
|
|
214
|
-
fn class_for(value: &Self) -> RClass {
|
277
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
278
|
+
static BPE_DECODER: Lazy<RClass> = Lazy::new(|ruby| {
|
279
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("BPEDecoder").unwrap();
|
280
|
+
class.undef_default_alloc_func();
|
281
|
+
class
|
282
|
+
});
|
283
|
+
static BYTE_FALLBACK: Lazy<RClass> = Lazy::new(|ruby| {
|
284
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("ByteFallback").unwrap();
|
285
|
+
class.undef_default_alloc_func();
|
286
|
+
class
|
287
|
+
});
|
288
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
289
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("ByteLevel").unwrap();
|
290
|
+
class.undef_default_alloc_func();
|
291
|
+
class
|
292
|
+
});
|
293
|
+
static CTC: Lazy<RClass> = Lazy::new(|ruby| {
|
294
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("CTC").unwrap();
|
295
|
+
class.undef_default_alloc_func();
|
296
|
+
class
|
297
|
+
});
|
298
|
+
static FUSE: Lazy<RClass> = Lazy::new(|ruby| {
|
299
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Fuse").unwrap();
|
300
|
+
class.undef_default_alloc_func();
|
301
|
+
class
|
302
|
+
});
|
303
|
+
static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
304
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Metaspace").unwrap();
|
305
|
+
class.undef_default_alloc_func();
|
306
|
+
class
|
307
|
+
});
|
308
|
+
static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
|
309
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Replace").unwrap();
|
310
|
+
class.undef_default_alloc_func();
|
311
|
+
class
|
312
|
+
});
|
313
|
+
static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
|
314
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("Strip").unwrap();
|
315
|
+
class.undef_default_alloc_func();
|
316
|
+
class
|
317
|
+
});
|
318
|
+
static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
|
319
|
+
let class: RClass = ruby.get_inner(&DECODERS).const_get("WordPiece").unwrap();
|
320
|
+
class.undef_default_alloc_func();
|
321
|
+
class
|
322
|
+
});
|
215
323
|
match &value.decoder {
|
216
324
|
RbDecoderWrapper::Wrapped(inner) => match *inner.read().unwrap() {
|
217
|
-
DecoderWrapper::BPE(_) =>
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
DecoderWrapper::
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
}),
|
227
|
-
DecoderWrapper::CTC(_) => *memoize!(RClass: {
|
228
|
-
let class: RClass = crate::decoders().const_get("CTC").unwrap();
|
229
|
-
class.undef_alloc_func();
|
230
|
-
class
|
231
|
-
}),
|
232
|
-
DecoderWrapper::Metaspace(_) => *memoize!(RClass: {
|
233
|
-
let class: RClass = crate::decoders().const_get("Metaspace").unwrap();
|
234
|
-
class.undef_alloc_func();
|
235
|
-
class
|
236
|
-
}),
|
237
|
-
DecoderWrapper::WordPiece(_) => *memoize!(RClass: {
|
238
|
-
let class: RClass = crate::decoders().const_get("WordPiece").unwrap();
|
239
|
-
class.undef_alloc_func();
|
240
|
-
class
|
241
|
-
}),
|
325
|
+
DecoderWrapper::BPE(_) => ruby.get_inner(&BPE_DECODER),
|
326
|
+
DecoderWrapper::ByteFallback(_) => ruby.get_inner(&BYTE_FALLBACK),
|
327
|
+
DecoderWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
328
|
+
DecoderWrapper::CTC(_) => ruby.get_inner(&CTC),
|
329
|
+
DecoderWrapper::Fuse(_) => ruby.get_inner(&FUSE),
|
330
|
+
DecoderWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
|
331
|
+
DecoderWrapper::Replace(_) => ruby.get_inner(&REPLACE),
|
332
|
+
DecoderWrapper::Strip(_) => ruby.get_inner(&STRIP),
|
333
|
+
DecoderWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
|
242
334
|
_ => todo!(),
|
243
335
|
},
|
244
336
|
}
|
245
337
|
}
|
246
338
|
}
|
247
339
|
|
248
|
-
pub fn
|
249
|
-
let decoder = module.define_class("Decoder",
|
340
|
+
pub fn init_decoders(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
341
|
+
let decoder = module.define_class("Decoder", ruby.class_object())?;
|
250
342
|
|
251
343
|
let class = module.define_class("BPEDecoder", decoder)?;
|
252
344
|
class.define_singleton_method("_new", function!(RbBPEDecoder::new, 1))?;
|
253
345
|
class.define_method("suffix", method!(RbDecoder::bpe_suffix, 0))?;
|
254
346
|
class.define_method("suffix=", method!(RbDecoder::bpe_set_suffix, 1))?;
|
255
347
|
|
348
|
+
let class = module.define_class("ByteFallback", decoder)?;
|
349
|
+
class.define_singleton_method("new", function!(RbByteFallbackDecoder::new, 0))?;
|
350
|
+
|
256
351
|
let class = module.define_class("ByteLevel", decoder)?;
|
257
352
|
class.define_singleton_method("new", function!(RbByteLevelDecoder::new, 0))?;
|
258
353
|
|
@@ -265,6 +360,9 @@ pub fn decoders(module: &RModule) -> RbResult<()> {
|
|
265
360
|
class.define_method("word_delimiter_token", method!(RbDecoder::ctc_word_delimiter_token, 0))?;
|
266
361
|
class.define_method("word_delimiter_token=", method!(RbDecoder::ctc_set_word_delimiter_token, 1))?;
|
267
362
|
|
363
|
+
let class = module.define_class("Fuse", decoder)?;
|
364
|
+
class.define_singleton_method("new", function!(RbFuse::new, 0))?;
|
365
|
+
|
268
366
|
let class = module.define_class("Metaspace", decoder)?;
|
269
367
|
class.define_singleton_method("_new", function!(RbMetaspaceDecoder::new, 2))?;
|
270
368
|
class.define_method("add_prefix_space", method!(RbDecoder::metaspace_add_prefix_space, 0))?;
|
@@ -272,6 +370,18 @@ pub fn decoders(module: &RModule) -> RbResult<()> {
|
|
272
370
|
class.define_method("replacement", method!(RbDecoder::metaspace_replacement, 0))?;
|
273
371
|
class.define_method("replacement=", method!(RbDecoder::metaspace_set_replacement, 1))?;
|
274
372
|
|
373
|
+
let class = module.define_class("Replace", decoder)?;
|
374
|
+
class.define_singleton_method("new", function!(RbReplaceDecoder::new, 2))?;
|
375
|
+
|
376
|
+
let class = module.define_class("Strip", decoder)?;
|
377
|
+
class.define_singleton_method("_new", function!(RbStripDecoder::new, 3))?;
|
378
|
+
class.define_method("content", method!(RbDecoder::strip_content, 0))?;
|
379
|
+
class.define_method("content=", method!(RbDecoder::strip_set_content, 1))?;
|
380
|
+
class.define_method("start", method!(RbDecoder::strip_start, 0))?;
|
381
|
+
class.define_method("start=", method!(RbDecoder::strip_set_start, 1))?;
|
382
|
+
class.define_method("stop", method!(RbDecoder::strip_stop, 0))?;
|
383
|
+
class.define_method("stop=", method!(RbDecoder::strip_set_stop, 1))?;
|
384
|
+
|
275
385
|
let class = module.define_class("WordPiece", decoder)?;
|
276
386
|
class.define_singleton_method("_new", function!(RbWordPieceDecoder::new, 2))?;
|
277
387
|
class.define_method("cleanup", method!(RbDecoder::word_piece_cleanup, 0))?;
|
data/ext/tokenizers/src/error.rs
CHANGED
@@ -1,6 +1,6 @@
|
|
1
|
-
use magnus::{
|
1
|
+
use magnus::{prelude::*, value::Lazy, Error, ExceptionClass, Ruby};
|
2
2
|
|
3
|
-
use super::
|
3
|
+
use super::TOKENIZERS;
|
4
4
|
|
5
5
|
pub struct RbError {}
|
6
6
|
|
@@ -11,6 +11,8 @@ impl RbError {
|
|
11
11
|
}
|
12
12
|
}
|
13
13
|
|
14
|
+
static ERROR: Lazy<ExceptionClass> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Error").unwrap());
|
15
|
+
|
14
16
|
fn error() -> ExceptionClass {
|
15
|
-
|
17
|
+
Ruby::get().unwrap().get_inner(&ERROR)
|
16
18
|
}
|
data/ext/tokenizers/src/lib.rs
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
#![allow(clippy::new_ret_no_self)]
|
2
|
+
|
1
3
|
extern crate tokenizers as tk;
|
2
4
|
|
3
5
|
mod decoders;
|
@@ -16,43 +18,29 @@ use error::RbError;
|
|
16
18
|
use tokenizer::RbTokenizer;
|
17
19
|
use utils::RbRegex;
|
18
20
|
|
19
|
-
use magnus::{
|
21
|
+
use magnus::{function, method, prelude::*, value::Lazy, Error, RModule, Ruby};
|
20
22
|
|
21
23
|
type RbResult<T> = Result<T, Error>;
|
22
24
|
|
23
|
-
|
24
|
-
*memoize!(RModule: define_module("Tokenizers").unwrap())
|
25
|
-
}
|
25
|
+
static TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.class_object().const_get("Tokenizers").unwrap());
|
26
26
|
|
27
|
-
|
28
|
-
*memoize!(RModule: module().const_get("Decoders").unwrap())
|
29
|
-
}
|
27
|
+
static DECODERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Decoders").unwrap());
|
30
28
|
|
31
|
-
|
32
|
-
*memoize!(RModule: module().const_get("Models").unwrap())
|
33
|
-
}
|
29
|
+
static MODELS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Models").unwrap());
|
34
30
|
|
35
|
-
|
36
|
-
*memoize!(RModule: module().const_get("Normalizers").unwrap())
|
37
|
-
}
|
31
|
+
static NORMALIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Normalizers").unwrap());
|
38
32
|
|
39
|
-
|
40
|
-
*memoize!(RModule: module().const_get("PreTokenizers").unwrap())
|
41
|
-
}
|
33
|
+
static PRE_TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("PreTokenizers").unwrap());
|
42
34
|
|
43
|
-
|
44
|
-
*memoize!(RModule: module().const_get("Processors").unwrap())
|
45
|
-
}
|
35
|
+
static PROCESSORS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Processors").unwrap());
|
46
36
|
|
47
|
-
|
48
|
-
*memoize!(RModule: module().const_get("Trainers").unwrap())
|
49
|
-
}
|
37
|
+
static TRAINERS: Lazy<RModule> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Trainers").unwrap());
|
50
38
|
|
51
39
|
#[magnus::init]
|
52
|
-
fn init() -> RbResult<()> {
|
53
|
-
let module =
|
40
|
+
fn init(ruby: &Ruby) -> RbResult<()> {
|
41
|
+
let module = ruby.get_inner(&TOKENIZERS);
|
54
42
|
|
55
|
-
let class = module.define_class("Tokenizer",
|
43
|
+
let class = module.define_class("Tokenizer", ruby.class_object())?;
|
56
44
|
class.define_singleton_method("new", function!(RbTokenizer::from_model, 1))?;
|
57
45
|
class.define_singleton_method("from_file", function!(RbTokenizer::from_file, 1))?;
|
58
46
|
class.define_method(
|
@@ -86,7 +74,7 @@ fn init() -> RbResult<()> {
|
|
86
74
|
class.define_method("_vocab_size", method!(RbTokenizer::vocab_size, 1))?;
|
87
75
|
class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;
|
88
76
|
|
89
|
-
let class = module.define_class("Encoding",
|
77
|
+
let class = module.define_class("Encoding", ruby.class_object())?;
|
90
78
|
class.define_method("n_sequences", method!(RbEncoding::n_sequences, 0))?;
|
91
79
|
class.define_method("ids", method!(RbEncoding::ids, 0))?;
|
92
80
|
class.define_method("tokens", method!(RbEncoding::tokens, 0))?;
|
@@ -111,7 +99,7 @@ fn init() -> RbResult<()> {
|
|
111
99
|
class.define_method("_char_to_token", method!(RbEncoding::char_to_token, 2))?;
|
112
100
|
class.define_method("_char_to_word", method!(RbEncoding::char_to_word, 2))?;
|
113
101
|
|
114
|
-
let class = module.define_class("Regex",
|
102
|
+
let class = module.define_class("Regex", ruby.class_object())?;
|
115
103
|
class.define_singleton_method("new", function!(RbRegex::new, 1))?;
|
116
104
|
|
117
105
|
let models = module.define_module("Models")?;
|
@@ -121,12 +109,12 @@ fn init() -> RbResult<()> {
|
|
121
109
|
let normalizers = module.define_module("Normalizers")?;
|
122
110
|
let trainers = module.define_module("Trainers")?;
|
123
111
|
|
124
|
-
models::
|
125
|
-
pre_tokenizers::
|
126
|
-
decoders::
|
127
|
-
processors::
|
128
|
-
normalizers::
|
129
|
-
trainers::
|
112
|
+
models::init_models(ruby, &models)?;
|
113
|
+
pre_tokenizers::init_pre_tokenizers(ruby, &pre_tokenizers)?;
|
114
|
+
decoders::init_decoders(ruby, &decoders)?;
|
115
|
+
processors::init_processors(ruby, &processors)?;
|
116
|
+
normalizers::init_normalizers(ruby, &normalizers)?;
|
117
|
+
trainers::init_trainers(ruby, &trainers)?;
|
130
118
|
|
131
119
|
Ok(())
|
132
120
|
}
|
@@ -3,10 +3,10 @@ use std::path::{Path, PathBuf};
|
|
3
3
|
use std::sync::{Arc, RwLock};
|
4
4
|
|
5
5
|
use crate::trainers::RbTrainer;
|
6
|
-
use magnus::
|
6
|
+
use magnus::prelude::*;
|
7
7
|
use magnus::{
|
8
|
-
exception, function,
|
9
|
-
RClass, RHash, RModule, Symbol, TypedData, Value,
|
8
|
+
data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
|
9
|
+
RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
|
10
10
|
};
|
11
11
|
use serde::{Deserialize, Serialize};
|
12
12
|
use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
|
@@ -16,7 +16,7 @@ use tk::models::wordlevel::WordLevel;
|
|
16
16
|
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
|
17
17
|
use tk::{Model, Token};
|
18
18
|
|
19
|
-
use super::{RbError, RbResult};
|
19
|
+
use super::{MODELS, RbError, RbResult};
|
20
20
|
|
21
21
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
22
22
|
pub struct RbModel {
|
@@ -73,32 +73,37 @@ impl RbBPE {
|
|
73
73
|
fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
74
74
|
let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
|
75
75
|
if !value.is_nil() {
|
76
|
-
builder = builder.cache_capacity(
|
76
|
+
builder = builder.cache_capacity(TryConvert::try_convert(value)?);
|
77
77
|
}
|
78
78
|
|
79
79
|
let value: Value = kwargs.delete(Symbol::new("dropout"))?;
|
80
80
|
if !value.is_nil() {
|
81
|
-
builder = builder.dropout(
|
81
|
+
builder = builder.dropout(TryConvert::try_convert(value)?);
|
82
82
|
}
|
83
83
|
|
84
84
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
85
85
|
if !value.is_nil() {
|
86
|
-
builder = builder.unk_token(
|
86
|
+
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
87
87
|
}
|
88
88
|
|
89
89
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
90
90
|
if !value.is_nil() {
|
91
|
-
builder = builder.continuing_subword_prefix(
|
91
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
92
92
|
}
|
93
93
|
|
94
94
|
let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
|
95
95
|
if !value.is_nil() {
|
96
|
-
builder = builder.end_of_word_suffix(
|
96
|
+
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
97
97
|
}
|
98
98
|
|
99
99
|
let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
|
100
100
|
if !value.is_nil() {
|
101
|
-
builder = builder.fuse_unk(
|
101
|
+
builder = builder.fuse_unk(TryConvert::try_convert(value)?);
|
102
|
+
}
|
103
|
+
|
104
|
+
let value: Value = kwargs.delete(Symbol::new("byte_fallback"))?;
|
105
|
+
if !value.is_nil() {
|
106
|
+
builder = builder.byte_fallback(TryConvert::try_convert(value)?);
|
102
107
|
}
|
103
108
|
|
104
109
|
if !kwargs.is_empty() {
|
@@ -169,6 +174,14 @@ impl RbModel {
|
|
169
174
|
setter!(self, BPE, fuse_unk, fuse_unk);
|
170
175
|
}
|
171
176
|
|
177
|
+
pub fn bpe_byte_fallback(&self) -> bool {
|
178
|
+
getter!(self, BPE, byte_fallback)
|
179
|
+
}
|
180
|
+
|
181
|
+
pub fn bpe_set_byte_fallback(&self, byte_fallback: bool) {
|
182
|
+
setter!(self, BPE, byte_fallback, byte_fallback);
|
183
|
+
}
|
184
|
+
|
172
185
|
pub fn bpe_continuing_subword_prefix(&self) -> Option<String> {
|
173
186
|
getter!(self, BPE, continuing_subword_prefix.clone())
|
174
187
|
}
|
@@ -221,13 +234,13 @@ impl RbModel {
|
|
221
234
|
pub struct RbUnigram {}
|
222
235
|
|
223
236
|
impl RbUnigram {
|
224
|
-
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
|
225
|
-
match (vocab, unk_id) {
|
226
|
-
(Some(vocab), unk_id) => {
|
227
|
-
let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
|
237
|
+
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>, byte_fallback: Option<bool>) -> RbResult<RbModel> {
|
238
|
+
match (vocab, unk_id, byte_fallback) {
|
239
|
+
(Some(vocab), unk_id, byte_fallback) => {
|
240
|
+
let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(RbError::from)?;
|
228
241
|
Ok(model.into())
|
229
242
|
}
|
230
|
-
(None, None) => Ok(Unigram::default().into()),
|
243
|
+
(None, None, _) => Ok(Unigram::default().into()),
|
231
244
|
_ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
|
232
245
|
}
|
233
246
|
}
|
@@ -264,17 +277,17 @@ impl RbWordPiece {
|
|
264
277
|
fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
265
278
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
266
279
|
if !value.is_nil() {
|
267
|
-
builder = builder.unk_token(
|
280
|
+
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
268
281
|
}
|
269
282
|
|
270
283
|
let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
|
271
284
|
if !value.is_nil() {
|
272
|
-
builder = builder.max_input_chars_per_word(
|
285
|
+
builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
|
273
286
|
}
|
274
287
|
|
275
288
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
276
289
|
if !value.is_nil() {
|
277
|
-
builder = builder.continuing_subword_prefix(
|
290
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
278
291
|
}
|
279
292
|
|
280
293
|
if !kwargs.is_empty() {
|
@@ -301,46 +314,52 @@ impl RbWordPiece {
|
|
301
314
|
}
|
302
315
|
|
303
316
|
unsafe impl TypedData for RbModel {
|
304
|
-
fn class() -> RClass {
|
305
|
-
|
306
|
-
let class: RClass =
|
307
|
-
class.
|
317
|
+
fn class(ruby: &Ruby) -> RClass {
|
318
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
319
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
|
320
|
+
class.undef_default_alloc_func();
|
308
321
|
class
|
309
|
-
})
|
322
|
+
});
|
323
|
+
ruby.get_inner(&CLASS)
|
310
324
|
}
|
311
325
|
|
312
326
|
fn data_type() -> &'static DataType {
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
327
|
+
static DATA_TYPE: DataType = data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
|
328
|
+
&DATA_TYPE
|
329
|
+
}
|
330
|
+
|
331
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
332
|
+
static BPE: Lazy<RClass> = Lazy::new(|ruby| {
|
333
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("BPE").unwrap();
|
334
|
+
class.undef_default_alloc_func();
|
335
|
+
class
|
336
|
+
});
|
337
|
+
static UNIGRAM: Lazy<RClass> = Lazy::new(|ruby| {
|
338
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("Unigram").unwrap();
|
339
|
+
class.undef_default_alloc_func();
|
340
|
+
class
|
341
|
+
});
|
342
|
+
static WORD_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
343
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("WordLevel").unwrap();
|
344
|
+
class.undef_default_alloc_func();
|
345
|
+
class
|
346
|
+
});
|
347
|
+
static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
|
348
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("WordPiece").unwrap();
|
349
|
+
class.undef_default_alloc_func();
|
350
|
+
class
|
351
|
+
});
|
317
352
|
match *value.model.read().unwrap() {
|
318
|
-
ModelWrapper::BPE(_) =>
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
}),
|
323
|
-
ModelWrapper::Unigram(_) => *memoize!(RClass: {
|
324
|
-
let class: RClass = crate::models().const_get("Unigram").unwrap();
|
325
|
-
class.undef_alloc_func();
|
326
|
-
class
|
327
|
-
}),
|
328
|
-
ModelWrapper::WordLevel(_) => *memoize!(RClass: {
|
329
|
-
let class: RClass = crate::models().const_get("WordLevel").unwrap();
|
330
|
-
class.undef_alloc_func();
|
331
|
-
class
|
332
|
-
}),
|
333
|
-
ModelWrapper::WordPiece(_) => *memoize!(RClass: {
|
334
|
-
let class: RClass = crate::models().const_get("WordPiece").unwrap();
|
335
|
-
class.undef_alloc_func();
|
336
|
-
class
|
337
|
-
}),
|
353
|
+
ModelWrapper::BPE(_) => ruby.get_inner(&BPE),
|
354
|
+
ModelWrapper::Unigram(_) => ruby.get_inner(&UNIGRAM),
|
355
|
+
ModelWrapper::WordLevel(_) => ruby.get_inner(&WORD_LEVEL),
|
356
|
+
ModelWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
|
338
357
|
}
|
339
358
|
}
|
340
359
|
}
|
341
360
|
|
342
|
-
pub fn
|
343
|
-
let model = module.define_class("Model",
|
361
|
+
pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
362
|
+
let model = module.define_class("Model", ruby.class_object())?;
|
344
363
|
|
345
364
|
let class = module.define_class("BPE", model)?;
|
346
365
|
class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
|
@@ -355,9 +374,11 @@ pub fn models(module: &RModule) -> RbResult<()> {
|
|
355
374
|
class.define_method("end_of_word_suffix=", method!(RbModel::bpe_set_end_of_word_suffix, 1))?;
|
356
375
|
class.define_method("fuse_unk", method!(RbModel::bpe_fuse_unk, 0))?;
|
357
376
|
class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
|
377
|
+
class.define_method("byte_fallback", method!(RbModel::bpe_byte_fallback, 0))?;
|
378
|
+
class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
|
358
379
|
|
359
380
|
let class = module.define_class("Unigram", model)?;
|
360
|
-
class.define_singleton_method("_new", function!(RbUnigram::new,
|
381
|
+
class.define_singleton_method("_new", function!(RbUnigram::new, 3))?;
|
361
382
|
|
362
383
|
let class = module.define_class("WordLevel", model)?;
|
363
384
|
class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
|