tokenizers 0.2.3 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/Cargo.lock +32 -73
- data/README.md +4 -0
- data/ext/tokenizers/Cargo.toml +3 -1
- data/ext/tokenizers/src/decoders.rs +275 -6
- data/ext/tokenizers/src/encoding.rs +3 -2
- data/ext/tokenizers/src/error.rs +2 -2
- data/ext/tokenizers/src/lib.rs +64 -17
- data/ext/tokenizers/src/models.rs +372 -11
- data/ext/tokenizers/src/normalizers.rs +435 -7
- data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
- data/ext/tokenizers/src/processors.rs +210 -0
- data/ext/tokenizers/src/tokenizer.rs +437 -23
- data/ext/tokenizers/src/trainers.rs +749 -0
- data/ext/tokenizers/src/utils/mod.rs +5 -0
- data/ext/tokenizers/src/utils/normalization.rs +85 -0
- data/ext/tokenizers/src/utils/regex.rs +22 -0
- data/lib/tokenizers/char_bpe_tokenizer.rb +9 -6
- data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
- data/lib/tokenizers/decoders/ctc.rb +9 -0
- data/lib/tokenizers/decoders/metaspace.rb +9 -0
- data/lib/tokenizers/decoders/word_piece.rb +9 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/bpe.rb +9 -0
- data/lib/tokenizers/models/unigram.rb +9 -0
- data/lib/tokenizers/models/word_level.rb +13 -0
- data/lib/tokenizers/models/word_piece.rb +9 -0
- data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
- data/lib/tokenizers/normalizers/strip.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
- data/lib/tokenizers/processors/byte_level.rb +9 -0
- data/lib/tokenizers/processors/roberta_processing.rb +9 -0
- data/lib/tokenizers/processors/template_processing.rb +9 -0
- data/lib/tokenizers/tokenizer.rb +40 -7
- data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
- data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
- data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
- data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +42 -2
- metadata +30 -3
@@ -0,0 +1,749 @@
|
|
1
|
+
use std::collections::HashSet;
|
2
|
+
use std::sync::{Arc, RwLock};
|
3
|
+
|
4
|
+
use crate::models::RbModel;
|
5
|
+
use crate::tokenizer::RbAddedToken;
|
6
|
+
use magnus::typed_data::DataTypeBuilder;
|
7
|
+
use magnus::{
|
8
|
+
exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
|
9
|
+
RArray, RClass, RHash, RModule, Symbol, TypedData, Value,
|
10
|
+
};
|
11
|
+
use serde::{Deserialize, Serialize};
|
12
|
+
use tk::models::TrainerWrapper;
|
13
|
+
use tk::Trainer;
|
14
|
+
|
15
|
+
use super::RbResult;
|
16
|
+
|
17
|
+
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
18
|
+
pub struct RbTrainer {
|
19
|
+
#[serde(flatten)]
|
20
|
+
pub trainer: Arc<RwLock<TrainerWrapper>>,
|
21
|
+
}
|
22
|
+
|
23
|
+
impl Trainer for RbTrainer {
|
24
|
+
type Model = RbModel;
|
25
|
+
|
26
|
+
fn should_show_progress(&self) -> bool {
|
27
|
+
self.trainer.read().unwrap().should_show_progress()
|
28
|
+
}
|
29
|
+
|
30
|
+
fn train(&self, model: &mut RbModel) -> tk::Result<Vec<tk::AddedToken>> {
|
31
|
+
self.trainer
|
32
|
+
.read()
|
33
|
+
.unwrap()
|
34
|
+
.train(&mut model.model.write().unwrap())
|
35
|
+
}
|
36
|
+
|
37
|
+
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> tk::Result<()>
|
38
|
+
where
|
39
|
+
I: Iterator<Item = S> + Send,
|
40
|
+
S: AsRef<str> + Send,
|
41
|
+
F: Fn(&str) -> tk::Result<Vec<String>> + Sync,
|
42
|
+
{
|
43
|
+
self.trainer.write().unwrap().feed(iterator, process)
|
44
|
+
}
|
45
|
+
}
|
46
|
+
|
47
|
+
/// Reads from the `$variant` payload of `$self.trainer` under a read lock.
///
/// The trailing tokens are appended after `trainer.`, so either a field
/// (`vocab_size`) or a method chain (`special_tokens().iter()...`) works.
/// Panics via `unreachable!()` if the wrapper holds a different variant.
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        match *$self.trainer.read().unwrap() {
            TrainerWrapper::$variant(ref trainer) => trainer.$($name)+,
            _ => unreachable!(),
        }
    }};
}
|
56
|
+
|
57
|
+
/// Writes to the `$variant` payload of `$self.trainer` under a write lock.
///
/// Form 1 assigns `$value` to field `$name`; form 2 (`@$name`) calls the
/// setter method `$name($value)`. If the wrapper holds another variant, nothing
/// happens and `$value` is never evaluated.
macro_rules! setter {
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        match *$self.trainer.write().unwrap() {
            TrainerWrapper::$variant(ref mut trainer) => {
                trainer.$name = $value;
            }
            _ => {}
        }
    }};
    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
        match *$self.trainer.write().unwrap() {
            TrainerWrapper::$variant(ref mut trainer) => {
                trainer.$name($value);
            }
            _ => {}
        }
    }};
}
|
69
|
+
|
70
|
+
impl RbTrainer {
|
71
|
+
|
72
|
+
fn bpe_trainer_vocab_size(&self) -> usize {
|
73
|
+
getter!(self, BpeTrainer, vocab_size)
|
74
|
+
}
|
75
|
+
|
76
|
+
fn bpe_trainer_set_vocab_size(&self, vocab_size: usize) {
|
77
|
+
setter!(self, BpeTrainer, vocab_size, vocab_size);
|
78
|
+
}
|
79
|
+
|
80
|
+
fn bpe_trainer_min_frequency(&self) -> u32 {
|
81
|
+
getter!(self, BpeTrainer, min_frequency)
|
82
|
+
}
|
83
|
+
|
84
|
+
fn bpe_trainer_set_min_frequency(&self, freq: u32) {
|
85
|
+
setter!(self, BpeTrainer, min_frequency, freq);
|
86
|
+
}
|
87
|
+
|
88
|
+
fn bpe_trainer_show_progress(&self) -> bool {
|
89
|
+
getter!(self, BpeTrainer, show_progress)
|
90
|
+
}
|
91
|
+
|
92
|
+
fn bpe_trainer_set_show_progress(&self, show_progress: bool) {
|
93
|
+
setter!(self, BpeTrainer, show_progress, show_progress);
|
94
|
+
}
|
95
|
+
|
96
|
+
fn bpe_trainer_special_tokens(&self) -> Vec<String> {
|
97
|
+
getter!(
|
98
|
+
self,
|
99
|
+
BpeTrainer,
|
100
|
+
special_tokens
|
101
|
+
.iter()
|
102
|
+
.map(|tok| tok.content.clone())
|
103
|
+
.collect()
|
104
|
+
)
|
105
|
+
}
|
106
|
+
|
107
|
+
fn bpe_trainer_set_special_tokens(&self, special_tokens: RArray) -> RbResult<()> {
|
108
|
+
setter!(
|
109
|
+
self,
|
110
|
+
BpeTrainer,
|
111
|
+
special_tokens,
|
112
|
+
special_tokens
|
113
|
+
.each()
|
114
|
+
.map(|token| {
|
115
|
+
if let Ok(content) = token?.try_convert::<String>() {
|
116
|
+
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
117
|
+
} else {
|
118
|
+
todo!()
|
119
|
+
}
|
120
|
+
})
|
121
|
+
.collect::<RbResult<Vec<_>>>()?
|
122
|
+
);
|
123
|
+
Ok(())
|
124
|
+
}
|
125
|
+
|
126
|
+
fn bpe_trainer_limit_alphabet(&self) -> Option<usize> {
|
127
|
+
getter!(self, BpeTrainer, limit_alphabet)
|
128
|
+
}
|
129
|
+
|
130
|
+
fn bpe_trainer_set_limit_alphabet(&self, limit: Option<usize>) {
|
131
|
+
setter!(self, BpeTrainer, limit_alphabet, limit);
|
132
|
+
}
|
133
|
+
|
134
|
+
fn bpe_trainer_initial_alphabet(&self) -> Vec<String> {
|
135
|
+
getter!(
|
136
|
+
self,
|
137
|
+
BpeTrainer,
|
138
|
+
initial_alphabet.iter().map(|c| c.to_string()).collect()
|
139
|
+
)
|
140
|
+
}
|
141
|
+
|
142
|
+
fn bpe_trainer_set_initial_alphabet(&self, alphabet: Vec<char>) {
|
143
|
+
setter!(
|
144
|
+
self,
|
145
|
+
BpeTrainer,
|
146
|
+
initial_alphabet,
|
147
|
+
alphabet.into_iter().map(|c| c).collect()
|
148
|
+
);
|
149
|
+
}
|
150
|
+
|
151
|
+
fn bpe_trainer_continuing_subword_prefix(&self) -> Option<String> {
|
152
|
+
getter!(self, BpeTrainer, continuing_subword_prefix.clone())
|
153
|
+
}
|
154
|
+
|
155
|
+
fn bpe_trainer_set_continuing_subword_prefix(&self, prefix: Option<String>) {
|
156
|
+
setter!(self, BpeTrainer, continuing_subword_prefix, prefix);
|
157
|
+
}
|
158
|
+
|
159
|
+
fn bpe_trainer_end_of_word_suffix(&self) -> Option<String> {
|
160
|
+
getter!(self, BpeTrainer, end_of_word_suffix.clone())
|
161
|
+
}
|
162
|
+
|
163
|
+
fn bpe_trainer_set_end_of_word_suffix(&self, suffix: Option<String>) {
|
164
|
+
setter!(self, BpeTrainer, end_of_word_suffix, suffix);
|
165
|
+
}
|
166
|
+
|
167
|
+
fn unigram_trainer_vocab_size(&self) -> u32 {
|
168
|
+
getter!(self, UnigramTrainer, vocab_size)
|
169
|
+
}
|
170
|
+
|
171
|
+
fn unigram_trainer_set_vocab_size(&self, vocab_size: u32) {
|
172
|
+
setter!(self, UnigramTrainer, vocab_size, vocab_size);
|
173
|
+
}
|
174
|
+
|
175
|
+
fn unigram_trainer_show_progress(&self) -> bool {
|
176
|
+
getter!(self, UnigramTrainer, show_progress)
|
177
|
+
}
|
178
|
+
|
179
|
+
fn unigram_trainer_set_show_progress(&self, show_progress: bool) {
|
180
|
+
setter!(self, UnigramTrainer, show_progress, show_progress);
|
181
|
+
}
|
182
|
+
|
183
|
+
fn unigram_trainer_special_tokens(&self) -> Vec<String> {
|
184
|
+
getter!(
|
185
|
+
self,
|
186
|
+
UnigramTrainer,
|
187
|
+
special_tokens
|
188
|
+
.iter()
|
189
|
+
.map(|tok| tok.content.clone())
|
190
|
+
.collect()
|
191
|
+
)
|
192
|
+
}
|
193
|
+
|
194
|
+
fn unigram_trainer_set_special_tokens(&self, special_tokens: RArray) -> RbResult<()> {
|
195
|
+
setter!(
|
196
|
+
self,
|
197
|
+
UnigramTrainer,
|
198
|
+
special_tokens,
|
199
|
+
special_tokens
|
200
|
+
.each()
|
201
|
+
.map(|token| {
|
202
|
+
if let Ok(content) = token?.try_convert::<String>() {
|
203
|
+
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
204
|
+
} else {
|
205
|
+
todo!()
|
206
|
+
}
|
207
|
+
})
|
208
|
+
.collect::<RbResult<Vec<_>>>()?
|
209
|
+
);
|
210
|
+
Ok(())
|
211
|
+
}
|
212
|
+
|
213
|
+
fn unigram_trainer_initial_alphabet(&self) -> Vec<String> {
|
214
|
+
getter!(
|
215
|
+
self,
|
216
|
+
UnigramTrainer,
|
217
|
+
initial_alphabet.iter().map(|c| c.to_string()).collect()
|
218
|
+
)
|
219
|
+
}
|
220
|
+
|
221
|
+
fn unigram_trainer_set_initial_alphabet(&self, alphabet: Vec<char>) {
|
222
|
+
setter!(
|
223
|
+
self,
|
224
|
+
UnigramTrainer,
|
225
|
+
initial_alphabet,
|
226
|
+
alphabet.into_iter().map(|c| c).collect()
|
227
|
+
);
|
228
|
+
}
|
229
|
+
|
230
|
+
fn word_level_trainer_vocab_size(&self) -> usize {
|
231
|
+
getter!(self, WordLevelTrainer, vocab_size)
|
232
|
+
}
|
233
|
+
|
234
|
+
fn word_level_trainer_set_vocab_size(&self, vocab_size: usize) {
|
235
|
+
setter!(self, WordLevelTrainer, vocab_size, vocab_size);
|
236
|
+
}
|
237
|
+
|
238
|
+
fn word_level_trainer_min_frequency(&self) -> u32 {
|
239
|
+
getter!(self, WordLevelTrainer, min_frequency)
|
240
|
+
}
|
241
|
+
|
242
|
+
fn word_level_trainer_set_min_frequency(&self, freq: u32) {
|
243
|
+
setter!(self, WordLevelTrainer, min_frequency, freq);
|
244
|
+
}
|
245
|
+
|
246
|
+
fn word_level_trainer_show_progress(&self) -> bool {
|
247
|
+
getter!(self, WordLevelTrainer, show_progress)
|
248
|
+
}
|
249
|
+
|
250
|
+
fn word_level_trainer_set_show_progress(&self, show_progress: bool) {
|
251
|
+
setter!(self, WordLevelTrainer, show_progress, show_progress);
|
252
|
+
}
|
253
|
+
|
254
|
+
fn word_level_trainer_special_tokens(&self) -> Vec<String> {
|
255
|
+
getter!(
|
256
|
+
self,
|
257
|
+
WordLevelTrainer,
|
258
|
+
special_tokens
|
259
|
+
.iter()
|
260
|
+
.map(|tok| tok.content.clone())
|
261
|
+
.collect()
|
262
|
+
)
|
263
|
+
}
|
264
|
+
|
265
|
+
fn word_level_trainer_set_special_tokens(&self, special_tokens: RArray) -> RbResult<()> {
|
266
|
+
setter!(
|
267
|
+
self,
|
268
|
+
WordLevelTrainer,
|
269
|
+
special_tokens,
|
270
|
+
special_tokens
|
271
|
+
.each()
|
272
|
+
.map(|token| {
|
273
|
+
if let Ok(content) = token?.try_convert::<String>() {
|
274
|
+
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
275
|
+
} else {
|
276
|
+
todo!()
|
277
|
+
}
|
278
|
+
})
|
279
|
+
.collect::<RbResult<Vec<_>>>()?
|
280
|
+
);
|
281
|
+
Ok(())
|
282
|
+
}
|
283
|
+
|
284
|
+
fn word_piece_trainer_vocab_size(&self) -> usize {
|
285
|
+
getter!(self, WordPieceTrainer, vocab_size())
|
286
|
+
}
|
287
|
+
|
288
|
+
fn word_piece_trainer_set_vocab_size(&self, vocab_size: usize) {
|
289
|
+
setter!(self, WordPieceTrainer, @set_vocab_size, vocab_size);
|
290
|
+
}
|
291
|
+
|
292
|
+
fn word_piece_trainer_min_frequency(&self) -> u32 {
|
293
|
+
getter!(self, WordPieceTrainer, min_frequency())
|
294
|
+
}
|
295
|
+
|
296
|
+
fn word_piece_trainer_set_min_frequency(&self, freq: u32) {
|
297
|
+
setter!(self, WordPieceTrainer, @set_min_frequency, freq);
|
298
|
+
}
|
299
|
+
|
300
|
+
fn word_piece_trainer_show_progress(&self) -> bool {
|
301
|
+
getter!(self, WordPieceTrainer, show_progress())
|
302
|
+
}
|
303
|
+
|
304
|
+
fn word_piece_trainer_set_show_progress(&self, show_progress: bool) {
|
305
|
+
setter!(self, WordPieceTrainer, @set_show_progress, show_progress);
|
306
|
+
}
|
307
|
+
|
308
|
+
fn word_piece_trainer_special_tokens(&self) -> Vec<String> {
|
309
|
+
getter!(
|
310
|
+
self,
|
311
|
+
WordPieceTrainer,
|
312
|
+
special_tokens()
|
313
|
+
.iter()
|
314
|
+
.map(|tok| tok.content.clone())
|
315
|
+
.collect()
|
316
|
+
)
|
317
|
+
}
|
318
|
+
|
319
|
+
fn word_piece_trainer_set_special_tokens(&self, special_tokens: RArray) -> RbResult<()> {
|
320
|
+
setter!(
|
321
|
+
self,
|
322
|
+
WordPieceTrainer,
|
323
|
+
@set_special_tokens,
|
324
|
+
special_tokens
|
325
|
+
.each()
|
326
|
+
.map(|token| {
|
327
|
+
if let Ok(content) = token?.try_convert::<String>() {
|
328
|
+
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
329
|
+
} else {
|
330
|
+
todo!()
|
331
|
+
}
|
332
|
+
})
|
333
|
+
.collect::<RbResult<Vec<_>>>()?
|
334
|
+
);
|
335
|
+
Ok(())
|
336
|
+
}
|
337
|
+
|
338
|
+
fn word_piece_trainer_limit_alphabet(&self) -> Option<usize> {
|
339
|
+
getter!(self, WordPieceTrainer, limit_alphabet())
|
340
|
+
}
|
341
|
+
|
342
|
+
fn word_piece_trainer_set_limit_alphabet(&self, limit: Option<usize>) {
|
343
|
+
setter!(self, WordPieceTrainer, @set_limit_alphabet, limit);
|
344
|
+
}
|
345
|
+
|
346
|
+
fn word_piece_trainer_initial_alphabet(&self) -> Vec<String> {
|
347
|
+
getter!(
|
348
|
+
self,
|
349
|
+
WordPieceTrainer,
|
350
|
+
initial_alphabet().iter().map(|c| c.to_string()).collect()
|
351
|
+
)
|
352
|
+
}
|
353
|
+
|
354
|
+
fn word_piece_trainer_set_initial_alphabet(&self, alphabet: Vec<char>) {
|
355
|
+
setter!(
|
356
|
+
self,
|
357
|
+
WordPieceTrainer,
|
358
|
+
@set_initial_alphabet,
|
359
|
+
alphabet.into_iter().map(|c| c).collect()
|
360
|
+
);
|
361
|
+
}
|
362
|
+
|
363
|
+
fn word_piece_trainer_continuing_subword_prefix(&self) -> Option<String> {
|
364
|
+
getter!(self, WordPieceTrainer, continuing_subword_prefix().clone())
|
365
|
+
}
|
366
|
+
|
367
|
+
fn word_piece_trainer_set_continuing_subword_prefix(&self, prefix: Option<String>) {
|
368
|
+
setter!(self, WordPieceTrainer, @set_continuing_subword_prefix, prefix);
|
369
|
+
}
|
370
|
+
|
371
|
+
fn word_piece_trainer_end_of_word_suffix(&self) -> Option<String> {
|
372
|
+
getter!(self, WordPieceTrainer, end_of_word_suffix().clone())
|
373
|
+
}
|
374
|
+
|
375
|
+
fn word_piece_trainer_set_end_of_word_suffix(&self, suffix: Option<String>) {
|
376
|
+
setter!(self, WordPieceTrainer, @set_end_of_word_suffix, suffix);
|
377
|
+
}
|
378
|
+
}
|
379
|
+
|
380
|
+
impl<I> From<I> for RbTrainer
|
381
|
+
where
|
382
|
+
I: Into<TrainerWrapper>,
|
383
|
+
{
|
384
|
+
fn from(trainer: I) -> Self {
|
385
|
+
RbTrainer {
|
386
|
+
trainer: Arc::new(RwLock::new(trainer.into())),
|
387
|
+
}
|
388
|
+
}
|
389
|
+
}
|
390
|
+
|
391
|
+
pub struct RbBpeTrainer {}
|
392
|
+
|
393
|
+
impl RbBpeTrainer {
|
394
|
+
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
395
|
+
let mut builder = tk::models::bpe::BpeTrainer::builder();
|
396
|
+
|
397
|
+
let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
|
398
|
+
if !value.is_nil() {
|
399
|
+
builder = builder.special_tokens(
|
400
|
+
value
|
401
|
+
.try_convert::<RArray>()?
|
402
|
+
.each()
|
403
|
+
.map(|token| {
|
404
|
+
if let Ok(content) = token?.try_convert::<String>() {
|
405
|
+
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
406
|
+
} else {
|
407
|
+
todo!()
|
408
|
+
}
|
409
|
+
})
|
410
|
+
.collect::<RbResult<Vec<_>>>()?,
|
411
|
+
);
|
412
|
+
}
|
413
|
+
|
414
|
+
let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
|
415
|
+
if !value.is_nil() {
|
416
|
+
let arr = value.try_convert::<Vec<char>>()?;
|
417
|
+
let set: HashSet<char> = HashSet::from_iter(arr);
|
418
|
+
builder = builder.initial_alphabet(set);
|
419
|
+
}
|
420
|
+
|
421
|
+
let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
|
422
|
+
if !value.is_nil() {
|
423
|
+
builder = builder.vocab_size(value.try_convert()?);
|
424
|
+
}
|
425
|
+
|
426
|
+
let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
|
427
|
+
if !value.is_nil() {
|
428
|
+
builder = builder.min_frequency(value.try_convert()?);
|
429
|
+
}
|
430
|
+
|
431
|
+
let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
|
432
|
+
if !value.is_nil() {
|
433
|
+
builder = builder.show_progress(value.try_convert()?);
|
434
|
+
}
|
435
|
+
|
436
|
+
let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
|
437
|
+
if !value.is_nil() {
|
438
|
+
builder = builder.limit_alphabet(value.try_convert()?);
|
439
|
+
}
|
440
|
+
|
441
|
+
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
442
|
+
if !value.is_nil() {
|
443
|
+
builder = builder.continuing_subword_prefix(value.try_convert()?);
|
444
|
+
}
|
445
|
+
|
446
|
+
let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
|
447
|
+
if !value.is_nil() {
|
448
|
+
builder = builder.end_of_word_suffix(value.try_convert()?);
|
449
|
+
}
|
450
|
+
|
451
|
+
if !kwargs.is_empty() {
|
452
|
+
// TODO improve message
|
453
|
+
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
454
|
+
}
|
455
|
+
|
456
|
+
Ok(builder.build().into())
|
457
|
+
}
|
458
|
+
}
|
459
|
+
|
460
|
+
pub struct RbUnigramTrainer {}
|
461
|
+
|
462
|
+
impl RbUnigramTrainer {
|
463
|
+
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
464
|
+
let mut builder = tk::models::unigram::UnigramTrainer::builder();
|
465
|
+
|
466
|
+
let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
|
467
|
+
if !value.is_nil() {
|
468
|
+
builder.special_tokens(
|
469
|
+
value
|
470
|
+
.try_convert::<RArray>()?
|
471
|
+
.each()
|
472
|
+
.map(|token| {
|
473
|
+
if let Ok(content) = token?.try_convert::<String>() {
|
474
|
+
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
475
|
+
} else {
|
476
|
+
todo!()
|
477
|
+
}
|
478
|
+
})
|
479
|
+
.collect::<RbResult<Vec<_>>>()?,
|
480
|
+
);
|
481
|
+
}
|
482
|
+
|
483
|
+
let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
|
484
|
+
if !value.is_nil() {
|
485
|
+
let arr = value.try_convert::<Vec<char>>()?;
|
486
|
+
let set: HashSet<char> = HashSet::from_iter(arr);
|
487
|
+
builder.initial_alphabet(set);
|
488
|
+
}
|
489
|
+
|
490
|
+
let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
|
491
|
+
if !value.is_nil() {
|
492
|
+
builder.vocab_size(value.try_convert()?);
|
493
|
+
}
|
494
|
+
|
495
|
+
let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
|
496
|
+
if !value.is_nil() {
|
497
|
+
builder.show_progress(value.try_convert()?);
|
498
|
+
}
|
499
|
+
|
500
|
+
let value: Value = kwargs.delete(Symbol::new("n_sub_iterations"))?;
|
501
|
+
if !value.is_nil() {
|
502
|
+
builder.n_sub_iterations(value.try_convert()?);
|
503
|
+
}
|
504
|
+
|
505
|
+
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
506
|
+
if !value.is_nil() {
|
507
|
+
builder.unk_token(Some(value.try_convert()?));
|
508
|
+
}
|
509
|
+
|
510
|
+
let value: Value = kwargs.delete(Symbol::new("max_piece_length"))?;
|
511
|
+
if !value.is_nil() {
|
512
|
+
builder.max_piece_length(value.try_convert()?);
|
513
|
+
}
|
514
|
+
|
515
|
+
let value: Value = kwargs.delete(Symbol::new("seed_size"))?;
|
516
|
+
if !value.is_nil() {
|
517
|
+
builder.seed_size(value.try_convert()?);
|
518
|
+
}
|
519
|
+
|
520
|
+
let value: Value = kwargs.delete(Symbol::new("shrinking_factor"))?;
|
521
|
+
if !value.is_nil() {
|
522
|
+
builder.shrinking_factor(value.try_convert()?);
|
523
|
+
}
|
524
|
+
|
525
|
+
if !kwargs.is_empty() {
|
526
|
+
// TODO improve message
|
527
|
+
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
528
|
+
}
|
529
|
+
|
530
|
+
let trainer = builder.build().map_err(|_| { Error::new(exception::arg_error(), "Cannot build UnigramTrainer") })?;
|
531
|
+
Ok(trainer.into())
|
532
|
+
}
|
533
|
+
}
|
534
|
+
|
535
|
+
pub struct RbWordLevelTrainer {}
|
536
|
+
|
537
|
+
impl RbWordLevelTrainer {
|
538
|
+
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
539
|
+
let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
|
540
|
+
|
541
|
+
let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
|
542
|
+
if !value.is_nil() {
|
543
|
+
builder.special_tokens(
|
544
|
+
value
|
545
|
+
.try_convert::<RArray>()?
|
546
|
+
.each()
|
547
|
+
.map(|token| {
|
548
|
+
if let Ok(content) = token?.try_convert::<String>() {
|
549
|
+
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
550
|
+
} else {
|
551
|
+
todo!()
|
552
|
+
}
|
553
|
+
})
|
554
|
+
.collect::<RbResult<Vec<_>>>()?,
|
555
|
+
);
|
556
|
+
}
|
557
|
+
|
558
|
+
let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
|
559
|
+
if !value.is_nil() {
|
560
|
+
builder.vocab_size(value.try_convert()?);
|
561
|
+
}
|
562
|
+
|
563
|
+
let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
|
564
|
+
if !value.is_nil() {
|
565
|
+
builder.min_frequency(value.try_convert()?);
|
566
|
+
}
|
567
|
+
|
568
|
+
let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
|
569
|
+
if !value.is_nil() {
|
570
|
+
builder.show_progress(value.try_convert()?);
|
571
|
+
}
|
572
|
+
|
573
|
+
Ok(builder.build().expect("WordLevelTrainerBuilder cannot fail").into())
|
574
|
+
}
|
575
|
+
}
|
576
|
+
|
577
|
+
pub struct RbWordPieceTrainer {}
|
578
|
+
|
579
|
+
impl RbWordPieceTrainer {
|
580
|
+
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
581
|
+
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
|
582
|
+
|
583
|
+
let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
|
584
|
+
if !value.is_nil() {
|
585
|
+
builder = builder.special_tokens(
|
586
|
+
value
|
587
|
+
.try_convert::<RArray>()?
|
588
|
+
.each()
|
589
|
+
.map(|token| {
|
590
|
+
if let Ok(content) = token?.try_convert::<String>() {
|
591
|
+
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
592
|
+
} else {
|
593
|
+
todo!()
|
594
|
+
}
|
595
|
+
})
|
596
|
+
.collect::<RbResult<Vec<_>>>()?,
|
597
|
+
);
|
598
|
+
}
|
599
|
+
|
600
|
+
let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
|
601
|
+
if !value.is_nil() {
|
602
|
+
let arr = value.try_convert::<Vec<char>>()?;
|
603
|
+
let set: HashSet<char> = HashSet::from_iter(arr);
|
604
|
+
builder = builder.initial_alphabet(set);
|
605
|
+
}
|
606
|
+
|
607
|
+
let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
|
608
|
+
if !value.is_nil() {
|
609
|
+
builder = builder.vocab_size(value.try_convert()?);
|
610
|
+
}
|
611
|
+
|
612
|
+
let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
|
613
|
+
if !value.is_nil() {
|
614
|
+
builder = builder.min_frequency(value.try_convert()?);
|
615
|
+
}
|
616
|
+
|
617
|
+
let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
|
618
|
+
if !value.is_nil() {
|
619
|
+
builder = builder.show_progress(value.try_convert()?);
|
620
|
+
}
|
621
|
+
|
622
|
+
let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
|
623
|
+
if !value.is_nil() {
|
624
|
+
builder = builder.limit_alphabet(value.try_convert()?);
|
625
|
+
}
|
626
|
+
|
627
|
+
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
628
|
+
if !value.is_nil() {
|
629
|
+
builder = builder.continuing_subword_prefix(value.try_convert()?);
|
630
|
+
}
|
631
|
+
|
632
|
+
let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
|
633
|
+
if !value.is_nil() {
|
634
|
+
builder = builder.end_of_word_suffix(value.try_convert()?);
|
635
|
+
}
|
636
|
+
|
637
|
+
if !kwargs.is_empty() {
|
638
|
+
// TODO improve message
|
639
|
+
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
640
|
+
}
|
641
|
+
|
642
|
+
Ok(builder.build().into())
|
643
|
+
}
|
644
|
+
}
|
645
|
+
|
646
|
+
unsafe impl TypedData for RbTrainer {
|
647
|
+
fn class() -> RClass {
|
648
|
+
*memoize!(RClass: {
|
649
|
+
let class: RClass = crate::trainers().const_get("Trainer").unwrap();
|
650
|
+
class.undef_alloc_func();
|
651
|
+
class
|
652
|
+
})
|
653
|
+
}
|
654
|
+
|
655
|
+
fn data_type() -> &'static DataType {
|
656
|
+
memoize!(DataType: DataTypeBuilder::<RbTrainer>::new("Tokenizers::Trainers::Trainer").build())
|
657
|
+
}
|
658
|
+
|
659
|
+
fn class_for(value: &Self) -> RClass {
|
660
|
+
match *value.trainer.read().unwrap() {
|
661
|
+
TrainerWrapper::BpeTrainer(_) => *memoize!(RClass: {
|
662
|
+
let class: RClass = crate::trainers().const_get("BpeTrainer").unwrap();
|
663
|
+
class.undef_alloc_func();
|
664
|
+
class
|
665
|
+
}),
|
666
|
+
TrainerWrapper::UnigramTrainer(_) => *memoize!(RClass: {
|
667
|
+
let class: RClass = crate::trainers().const_get("UnigramTrainer").unwrap();
|
668
|
+
class.undef_alloc_func();
|
669
|
+
class
|
670
|
+
}),
|
671
|
+
TrainerWrapper::WordLevelTrainer(_) => *memoize!(RClass: {
|
672
|
+
let class: RClass = crate::trainers().const_get("WordLevelTrainer").unwrap();
|
673
|
+
class.undef_alloc_func();
|
674
|
+
class
|
675
|
+
}),
|
676
|
+
TrainerWrapper::WordPieceTrainer(_) => *memoize!(RClass: {
|
677
|
+
let class: RClass = crate::trainers().const_get("WordPieceTrainer").unwrap();
|
678
|
+
class.undef_alloc_func();
|
679
|
+
class
|
680
|
+
}),
|
681
|
+
}
|
682
|
+
}
|
683
|
+
}
|
684
|
+
|
685
|
+
pub fn trainers(module: &RModule) -> RbResult<()> {
|
686
|
+
let trainer = module.define_class("Trainer", Default::default())?;
|
687
|
+
|
688
|
+
let class = module.define_class("BpeTrainer", trainer)?;
|
689
|
+
class.define_singleton_method("_new", function!(RbBpeTrainer::new, 1))?;
|
690
|
+
class.define_method("vocab_size", method!(RbTrainer::bpe_trainer_vocab_size, 0))?;
|
691
|
+
class.define_method("vocab_size=", method!(RbTrainer::bpe_trainer_set_vocab_size, 1))?;
|
692
|
+
class.define_method("min_frequency", method!(RbTrainer::bpe_trainer_min_frequency, 0))?;
|
693
|
+
class.define_method("min_frequency=", method!(RbTrainer::bpe_trainer_set_min_frequency, 1))?;
|
694
|
+
class.define_method("show_progress", method!(RbTrainer::bpe_trainer_show_progress, 0))?;
|
695
|
+
class.define_method("show_progress=", method!(RbTrainer::bpe_trainer_set_show_progress, 1))?;
|
696
|
+
class.define_method("special_tokens", method!(RbTrainer::bpe_trainer_special_tokens, 0))?;
|
697
|
+
class.define_method("special_tokens=", method!(RbTrainer::bpe_trainer_set_special_tokens, 1))?;
|
698
|
+
class.define_method("limit_alphabet", method!(RbTrainer::bpe_trainer_limit_alphabet, 0))?;
|
699
|
+
class.define_method("limit_alphabet=", method!(RbTrainer::bpe_trainer_set_limit_alphabet, 1))?;
|
700
|
+
class.define_method("initial_alphabet", method!(RbTrainer::bpe_trainer_initial_alphabet, 0))?;
|
701
|
+
class.define_method("initial_alphabet=", method!(RbTrainer::bpe_trainer_set_initial_alphabet, 1))?;
|
702
|
+
class.define_method("continuing_subword_prefix", method!(RbTrainer::bpe_trainer_continuing_subword_prefix, 0))?;
|
703
|
+
class.define_method("continuing_subword_prefix=", method!(RbTrainer::bpe_trainer_set_continuing_subword_prefix, 1))?;
|
704
|
+
class.define_method("end_of_word_suffix", method!(RbTrainer::bpe_trainer_end_of_word_suffix, 0))?;
|
705
|
+
class.define_method("end_of_word_suffix=", method!(RbTrainer::bpe_trainer_set_end_of_word_suffix, 1))?;
|
706
|
+
|
707
|
+
let class = module.define_class("UnigramTrainer", trainer)?;
|
708
|
+
class.define_singleton_method("_new", function!(RbUnigramTrainer::new, 1))?;
|
709
|
+
class.define_method("vocab_size", method!(RbTrainer::unigram_trainer_vocab_size, 0))?;
|
710
|
+
class.define_method("vocab_size=", method!(RbTrainer::unigram_trainer_set_vocab_size, 1))?;
|
711
|
+
class.define_method("show_progress", method!(RbTrainer::unigram_trainer_show_progress, 0))?;
|
712
|
+
class.define_method("show_progress=", method!(RbTrainer::unigram_trainer_set_show_progress, 1))?;
|
713
|
+
class.define_method("special_tokens", method!(RbTrainer::unigram_trainer_special_tokens, 0))?;
|
714
|
+
class.define_method("special_tokens=", method!(RbTrainer::unigram_trainer_set_special_tokens, 1))?;
|
715
|
+
class.define_method("initial_alphabet", method!(RbTrainer::unigram_trainer_initial_alphabet, 0))?;
|
716
|
+
class.define_method("initial_alphabet=", method!(RbTrainer::unigram_trainer_set_initial_alphabet, 1))?;
|
717
|
+
|
718
|
+
let class = module.define_class("WordLevelTrainer", trainer)?;
|
719
|
+
class.define_singleton_method("_new", function!(RbWordLevelTrainer::new, 1))?;
|
720
|
+
class.define_method("vocab_size", method!(RbTrainer::word_level_trainer_vocab_size, 0))?;
|
721
|
+
class.define_method("vocab_size=", method!(RbTrainer::word_level_trainer_set_vocab_size, 1))?;
|
722
|
+
class.define_method("min_frequency", method!(RbTrainer::word_level_trainer_min_frequency, 0))?;
|
723
|
+
class.define_method("min_frequency=", method!(RbTrainer::word_level_trainer_set_min_frequency, 1))?;
|
724
|
+
class.define_method("show_progress", method!(RbTrainer::word_level_trainer_show_progress, 0))?;
|
725
|
+
class.define_method("show_progress=", method!(RbTrainer::word_level_trainer_set_show_progress, 1))?;
|
726
|
+
class.define_method("special_tokens", method!(RbTrainer::word_level_trainer_special_tokens, 0))?;
|
727
|
+
class.define_method("special_tokens=", method!(RbTrainer::word_level_trainer_set_special_tokens, 1))?;
|
728
|
+
|
729
|
+
let class = module.define_class("WordPieceTrainer", trainer)?;
|
730
|
+
class.define_singleton_method("_new", function!(RbWordPieceTrainer::new, 1))?;
|
731
|
+
class.define_method("vocab_size", method!(RbTrainer::word_piece_trainer_vocab_size, 0))?;
|
732
|
+
class.define_method("vocab_size=", method!(RbTrainer::word_piece_trainer_set_vocab_size, 1))?;
|
733
|
+
class.define_method("min_frequency", method!(RbTrainer::word_piece_trainer_min_frequency, 0))?;
|
734
|
+
class.define_method("min_frequency=", method!(RbTrainer::word_piece_trainer_set_min_frequency, 1))?;
|
735
|
+
class.define_method("show_progress", method!(RbTrainer::word_piece_trainer_show_progress, 0))?;
|
736
|
+
class.define_method("show_progress=", method!(RbTrainer::word_piece_trainer_set_show_progress, 1))?;
|
737
|
+
class.define_method("special_tokens", method!(RbTrainer::word_piece_trainer_special_tokens, 0))?;
|
738
|
+
class.define_method("special_tokens=", method!(RbTrainer::word_piece_trainer_set_special_tokens, 1))?;
|
739
|
+
class.define_method("limit_alphabet", method!(RbTrainer::word_piece_trainer_limit_alphabet, 0))?;
|
740
|
+
class.define_method("limit_alphabet=", method!(RbTrainer::word_piece_trainer_set_limit_alphabet, 1))?;
|
741
|
+
class.define_method("initial_alphabet", method!(RbTrainer::word_piece_trainer_initial_alphabet, 0))?;
|
742
|
+
class.define_method("initial_alphabet=", method!(RbTrainer::word_piece_trainer_set_initial_alphabet, 1))?;
|
743
|
+
class.define_method("continuing_subword_prefix", method!(RbTrainer::word_piece_trainer_continuing_subword_prefix, 0))?;
|
744
|
+
class.define_method("continuing_subword_prefix=", method!(RbTrainer::word_piece_trainer_set_continuing_subword_prefix, 1))?;
|
745
|
+
class.define_method("end_of_word_suffix", method!(RbTrainer::word_piece_trainer_end_of_word_suffix, 0))?;
|
746
|
+
class.define_method("end_of_word_suffix=", method!(RbTrainer::word_piece_trainer_set_end_of_word_suffix, 1))?;
|
747
|
+
|
748
|
+
Ok(())
|
749
|
+
}
|