tokenizers 0.2.3 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/Cargo.lock +32 -73
  4. data/README.md +4 -0
  5. data/ext/tokenizers/Cargo.toml +3 -1
  6. data/ext/tokenizers/src/decoders.rs +275 -6
  7. data/ext/tokenizers/src/encoding.rs +3 -2
  8. data/ext/tokenizers/src/error.rs +2 -2
  9. data/ext/tokenizers/src/lib.rs +64 -17
  10. data/ext/tokenizers/src/models.rs +372 -11
  11. data/ext/tokenizers/src/normalizers.rs +435 -7
  12. data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
  13. data/ext/tokenizers/src/processors.rs +210 -0
  14. data/ext/tokenizers/src/tokenizer.rs +437 -23
  15. data/ext/tokenizers/src/trainers.rs +749 -0
  16. data/ext/tokenizers/src/utils/mod.rs +5 -0
  17. data/ext/tokenizers/src/utils/normalization.rs +85 -0
  18. data/ext/tokenizers/src/utils/regex.rs +22 -0
  19. data/lib/tokenizers/char_bpe_tokenizer.rb +9 -6
  20. data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
  21. data/lib/tokenizers/decoders/ctc.rb +9 -0
  22. data/lib/tokenizers/decoders/metaspace.rb +9 -0
  23. data/lib/tokenizers/decoders/word_piece.rb +9 -0
  24. data/lib/tokenizers/from_pretrained.rb +1 -1
  25. data/lib/tokenizers/models/bpe.rb +9 -0
  26. data/lib/tokenizers/models/unigram.rb +9 -0
  27. data/lib/tokenizers/models/word_level.rb +13 -0
  28. data/lib/tokenizers/models/word_piece.rb +9 -0
  29. data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
  30. data/lib/tokenizers/normalizers/strip.rb +9 -0
  31. data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
  32. data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
  33. data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
  34. data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
  35. data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
  36. data/lib/tokenizers/processors/byte_level.rb +9 -0
  37. data/lib/tokenizers/processors/roberta_processing.rb +9 -0
  38. data/lib/tokenizers/processors/template_processing.rb +9 -0
  39. data/lib/tokenizers/tokenizer.rb +40 -7
  40. data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
  41. data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
  42. data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
  43. data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
  44. data/lib/tokenizers/version.rb +1 -1
  45. data/lib/tokenizers.rb +42 -2
  46. metadata +30 -3
@@ -0,0 +1,749 @@
1
+ use std::collections::HashSet;
2
+ use std::sync::{Arc, RwLock};
3
+
4
+ use crate::models::RbModel;
5
+ use crate::tokenizer::RbAddedToken;
6
+ use magnus::typed_data::DataTypeBuilder;
7
+ use magnus::{
8
+ exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
+ RArray, RClass, RHash, RModule, Symbol, TypedData, Value,
10
+ };
11
+ use serde::{Deserialize, Serialize};
12
+ use tk::models::TrainerWrapper;
13
+ use tk::Trainer;
14
+
15
+ use super::RbResult;
16
+
17
+ #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
18
+ pub struct RbTrainer {
19
+ #[serde(flatten)]
20
+ pub trainer: Arc<RwLock<TrainerWrapper>>,
21
+ }
22
+
23
+ impl Trainer for RbTrainer {
24
+ type Model = RbModel;
25
+
26
+ fn should_show_progress(&self) -> bool {
27
+ self.trainer.read().unwrap().should_show_progress()
28
+ }
29
+
30
+ fn train(&self, model: &mut RbModel) -> tk::Result<Vec<tk::AddedToken>> {
31
+ self.trainer
32
+ .read()
33
+ .unwrap()
34
+ .train(&mut model.model.write().unwrap())
35
+ }
36
+
37
+ fn feed<I, S, F>(&mut self, iterator: I, process: F) -> tk::Result<()>
38
+ where
39
+ I: Iterator<Item = S> + Send,
40
+ S: AsRef<str> + Send,
41
+ F: Fn(&str) -> tk::Result<Vec<String>> + Sync,
42
+ {
43
+ self.trainer.write().unwrap().feed(iterator, process)
44
+ }
45
+ }
46
+
47
// Reads a value out of the wrapped trainer while holding the read lock.
//
// `$variant` names the expected `TrainerWrapper` variant; the remaining tokens
// are appended after `trainer.` (a field name, a method call, or a longer
// chain). Panics through `unreachable!()` if the wrapper holds any other
// variant, and through `unwrap()` if the lock is poisoned.
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        match *$self.trainer.read().unwrap() {
            TrainerWrapper::$variant(ref trainer) => trainer.$($name)+,
            _ => unreachable!(),
        }
    }};
}
56
+
57
// Updates the wrapped trainer while holding the write lock.
//
// `$value` is only evaluated when the wrapper holds `$variant`; for any other
// variant the macro does nothing. Panics through `unwrap()` if the lock is
// poisoned.
macro_rules! setter {
    // Field form: assigns `$value` to `trainer.$name`.
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        let mut guard = $self.trainer.write().unwrap();
        if let TrainerWrapper::$variant(ref mut trainer) = *guard {
            trainer.$name = $value;
        }
    }};
    // Method form (name prefixed with `@`): calls `trainer.$name($value)`.
    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
        let mut guard = $self.trainer.write().unwrap();
        if let TrainerWrapper::$variant(ref mut trainer) = *guard {
            trainer.$name($value);
        }
    }};
}
69
+
70
+ impl RbTrainer {
71
+
72
+ fn bpe_trainer_vocab_size(&self) -> usize {
73
+ getter!(self, BpeTrainer, vocab_size)
74
+ }
75
+
76
+ fn bpe_trainer_set_vocab_size(&self, vocab_size: usize) {
77
+ setter!(self, BpeTrainer, vocab_size, vocab_size);
78
+ }
79
+
80
+ fn bpe_trainer_min_frequency(&self) -> u32 {
81
+ getter!(self, BpeTrainer, min_frequency)
82
+ }
83
+
84
+ fn bpe_trainer_set_min_frequency(&self, freq: u32) {
85
+ setter!(self, BpeTrainer, min_frequency, freq);
86
+ }
87
+
88
+ fn bpe_trainer_show_progress(&self) -> bool {
89
+ getter!(self, BpeTrainer, show_progress)
90
+ }
91
+
92
+ fn bpe_trainer_set_show_progress(&self, show_progress: bool) {
93
+ setter!(self, BpeTrainer, show_progress, show_progress);
94
+ }
95
+
96
+ fn bpe_trainer_special_tokens(&self) -> Vec<String> {
97
+ getter!(
98
+ self,
99
+ BpeTrainer,
100
+ special_tokens
101
+ .iter()
102
+ .map(|tok| tok.content.clone())
103
+ .collect()
104
+ )
105
+ }
106
+
107
+ fn bpe_trainer_set_special_tokens(&self, special_tokens: RArray) -> RbResult<()> {
108
+ setter!(
109
+ self,
110
+ BpeTrainer,
111
+ special_tokens,
112
+ special_tokens
113
+ .each()
114
+ .map(|token| {
115
+ if let Ok(content) = token?.try_convert::<String>() {
116
+ Ok(RbAddedToken::from(content, Some(true)).get_token())
117
+ } else {
118
+ todo!()
119
+ }
120
+ })
121
+ .collect::<RbResult<Vec<_>>>()?
122
+ );
123
+ Ok(())
124
+ }
125
+
126
+ fn bpe_trainer_limit_alphabet(&self) -> Option<usize> {
127
+ getter!(self, BpeTrainer, limit_alphabet)
128
+ }
129
+
130
+ fn bpe_trainer_set_limit_alphabet(&self, limit: Option<usize>) {
131
+ setter!(self, BpeTrainer, limit_alphabet, limit);
132
+ }
133
+
134
+ fn bpe_trainer_initial_alphabet(&self) -> Vec<String> {
135
+ getter!(
136
+ self,
137
+ BpeTrainer,
138
+ initial_alphabet.iter().map(|c| c.to_string()).collect()
139
+ )
140
+ }
141
+
142
+ fn bpe_trainer_set_initial_alphabet(&self, alphabet: Vec<char>) {
143
+ setter!(
144
+ self,
145
+ BpeTrainer,
146
+ initial_alphabet,
147
+ alphabet.into_iter().map(|c| c).collect()
148
+ );
149
+ }
150
+
151
+ fn bpe_trainer_continuing_subword_prefix(&self) -> Option<String> {
152
+ getter!(self, BpeTrainer, continuing_subword_prefix.clone())
153
+ }
154
+
155
+ fn bpe_trainer_set_continuing_subword_prefix(&self, prefix: Option<String>) {
156
+ setter!(self, BpeTrainer, continuing_subword_prefix, prefix);
157
+ }
158
+
159
+ fn bpe_trainer_end_of_word_suffix(&self) -> Option<String> {
160
+ getter!(self, BpeTrainer, end_of_word_suffix.clone())
161
+ }
162
+
163
+ fn bpe_trainer_set_end_of_word_suffix(&self, suffix: Option<String>) {
164
+ setter!(self, BpeTrainer, end_of_word_suffix, suffix);
165
+ }
166
+
167
+ fn unigram_trainer_vocab_size(&self) -> u32 {
168
+ getter!(self, UnigramTrainer, vocab_size)
169
+ }
170
+
171
+ fn unigram_trainer_set_vocab_size(&self, vocab_size: u32) {
172
+ setter!(self, UnigramTrainer, vocab_size, vocab_size);
173
+ }
174
+
175
+ fn unigram_trainer_show_progress(&self) -> bool {
176
+ getter!(self, UnigramTrainer, show_progress)
177
+ }
178
+
179
+ fn unigram_trainer_set_show_progress(&self, show_progress: bool) {
180
+ setter!(self, UnigramTrainer, show_progress, show_progress);
181
+ }
182
+
183
+ fn unigram_trainer_special_tokens(&self) -> Vec<String> {
184
+ getter!(
185
+ self,
186
+ UnigramTrainer,
187
+ special_tokens
188
+ .iter()
189
+ .map(|tok| tok.content.clone())
190
+ .collect()
191
+ )
192
+ }
193
+
194
+ fn unigram_trainer_set_special_tokens(&self, special_tokens: RArray) -> RbResult<()> {
195
+ setter!(
196
+ self,
197
+ UnigramTrainer,
198
+ special_tokens,
199
+ special_tokens
200
+ .each()
201
+ .map(|token| {
202
+ if let Ok(content) = token?.try_convert::<String>() {
203
+ Ok(RbAddedToken::from(content, Some(true)).get_token())
204
+ } else {
205
+ todo!()
206
+ }
207
+ })
208
+ .collect::<RbResult<Vec<_>>>()?
209
+ );
210
+ Ok(())
211
+ }
212
+
213
+ fn unigram_trainer_initial_alphabet(&self) -> Vec<String> {
214
+ getter!(
215
+ self,
216
+ UnigramTrainer,
217
+ initial_alphabet.iter().map(|c| c.to_string()).collect()
218
+ )
219
+ }
220
+
221
+ fn unigram_trainer_set_initial_alphabet(&self, alphabet: Vec<char>) {
222
+ setter!(
223
+ self,
224
+ UnigramTrainer,
225
+ initial_alphabet,
226
+ alphabet.into_iter().map(|c| c).collect()
227
+ );
228
+ }
229
+
230
+ fn word_level_trainer_vocab_size(&self) -> usize {
231
+ getter!(self, WordLevelTrainer, vocab_size)
232
+ }
233
+
234
+ fn word_level_trainer_set_vocab_size(&self, vocab_size: usize) {
235
+ setter!(self, WordLevelTrainer, vocab_size, vocab_size);
236
+ }
237
+
238
+ fn word_level_trainer_min_frequency(&self) -> u32 {
239
+ getter!(self, WordLevelTrainer, min_frequency)
240
+ }
241
+
242
+ fn word_level_trainer_set_min_frequency(&self, freq: u32) {
243
+ setter!(self, WordLevelTrainer, min_frequency, freq);
244
+ }
245
+
246
+ fn word_level_trainer_show_progress(&self) -> bool {
247
+ getter!(self, WordLevelTrainer, show_progress)
248
+ }
249
+
250
+ fn word_level_trainer_set_show_progress(&self, show_progress: bool) {
251
+ setter!(self, WordLevelTrainer, show_progress, show_progress);
252
+ }
253
+
254
+ fn word_level_trainer_special_tokens(&self) -> Vec<String> {
255
+ getter!(
256
+ self,
257
+ WordLevelTrainer,
258
+ special_tokens
259
+ .iter()
260
+ .map(|tok| tok.content.clone())
261
+ .collect()
262
+ )
263
+ }
264
+
265
+ fn word_level_trainer_set_special_tokens(&self, special_tokens: RArray) -> RbResult<()> {
266
+ setter!(
267
+ self,
268
+ WordLevelTrainer,
269
+ special_tokens,
270
+ special_tokens
271
+ .each()
272
+ .map(|token| {
273
+ if let Ok(content) = token?.try_convert::<String>() {
274
+ Ok(RbAddedToken::from(content, Some(true)).get_token())
275
+ } else {
276
+ todo!()
277
+ }
278
+ })
279
+ .collect::<RbResult<Vec<_>>>()?
280
+ );
281
+ Ok(())
282
+ }
283
+
284
+ fn word_piece_trainer_vocab_size(&self) -> usize {
285
+ getter!(self, WordPieceTrainer, vocab_size())
286
+ }
287
+
288
+ fn word_piece_trainer_set_vocab_size(&self, vocab_size: usize) {
289
+ setter!(self, WordPieceTrainer, @set_vocab_size, vocab_size);
290
+ }
291
+
292
+ fn word_piece_trainer_min_frequency(&self) -> u32 {
293
+ getter!(self, WordPieceTrainer, min_frequency())
294
+ }
295
+
296
+ fn word_piece_trainer_set_min_frequency(&self, freq: u32) {
297
+ setter!(self, WordPieceTrainer, @set_min_frequency, freq);
298
+ }
299
+
300
+ fn word_piece_trainer_show_progress(&self) -> bool {
301
+ getter!(self, WordPieceTrainer, show_progress())
302
+ }
303
+
304
+ fn word_piece_trainer_set_show_progress(&self, show_progress: bool) {
305
+ setter!(self, WordPieceTrainer, @set_show_progress, show_progress);
306
+ }
307
+
308
+ fn word_piece_trainer_special_tokens(&self) -> Vec<String> {
309
+ getter!(
310
+ self,
311
+ WordPieceTrainer,
312
+ special_tokens()
313
+ .iter()
314
+ .map(|tok| tok.content.clone())
315
+ .collect()
316
+ )
317
+ }
318
+
319
+ fn word_piece_trainer_set_special_tokens(&self, special_tokens: RArray) -> RbResult<()> {
320
+ setter!(
321
+ self,
322
+ WordPieceTrainer,
323
+ @set_special_tokens,
324
+ special_tokens
325
+ .each()
326
+ .map(|token| {
327
+ if let Ok(content) = token?.try_convert::<String>() {
328
+ Ok(RbAddedToken::from(content, Some(true)).get_token())
329
+ } else {
330
+ todo!()
331
+ }
332
+ })
333
+ .collect::<RbResult<Vec<_>>>()?
334
+ );
335
+ Ok(())
336
+ }
337
+
338
+ fn word_piece_trainer_limit_alphabet(&self) -> Option<usize> {
339
+ getter!(self, WordPieceTrainer, limit_alphabet())
340
+ }
341
+
342
+ fn word_piece_trainer_set_limit_alphabet(&self, limit: Option<usize>) {
343
+ setter!(self, WordPieceTrainer, @set_limit_alphabet, limit);
344
+ }
345
+
346
+ fn word_piece_trainer_initial_alphabet(&self) -> Vec<String> {
347
+ getter!(
348
+ self,
349
+ WordPieceTrainer,
350
+ initial_alphabet().iter().map(|c| c.to_string()).collect()
351
+ )
352
+ }
353
+
354
+ fn word_piece_trainer_set_initial_alphabet(&self, alphabet: Vec<char>) {
355
+ setter!(
356
+ self,
357
+ WordPieceTrainer,
358
+ @set_initial_alphabet,
359
+ alphabet.into_iter().map(|c| c).collect()
360
+ );
361
+ }
362
+
363
+ fn word_piece_trainer_continuing_subword_prefix(&self) -> Option<String> {
364
+ getter!(self, WordPieceTrainer, continuing_subword_prefix().clone())
365
+ }
366
+
367
+ fn word_piece_trainer_set_continuing_subword_prefix(&self, prefix: Option<String>) {
368
+ setter!(self, WordPieceTrainer, @set_continuing_subword_prefix, prefix);
369
+ }
370
+
371
+ fn word_piece_trainer_end_of_word_suffix(&self) -> Option<String> {
372
+ getter!(self, WordPieceTrainer, end_of_word_suffix().clone())
373
+ }
374
+
375
+ fn word_piece_trainer_set_end_of_word_suffix(&self, suffix: Option<String>) {
376
+ setter!(self, WordPieceTrainer, @set_end_of_word_suffix, suffix);
377
+ }
378
+ }
379
+
380
+ impl<I> From<I> for RbTrainer
381
+ where
382
+ I: Into<TrainerWrapper>,
383
+ {
384
+ fn from(trainer: I) -> Self {
385
+ RbTrainer {
386
+ trainer: Arc::new(RwLock::new(trainer.into())),
387
+ }
388
+ }
389
+ }
390
+
391
+ pub struct RbBpeTrainer {}
392
+
393
+ impl RbBpeTrainer {
394
+ pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
395
+ let mut builder = tk::models::bpe::BpeTrainer::builder();
396
+
397
+ let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
398
+ if !value.is_nil() {
399
+ builder = builder.special_tokens(
400
+ value
401
+ .try_convert::<RArray>()?
402
+ .each()
403
+ .map(|token| {
404
+ if let Ok(content) = token?.try_convert::<String>() {
405
+ Ok(RbAddedToken::from(content, Some(true)).get_token())
406
+ } else {
407
+ todo!()
408
+ }
409
+ })
410
+ .collect::<RbResult<Vec<_>>>()?,
411
+ );
412
+ }
413
+
414
+ let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
415
+ if !value.is_nil() {
416
+ let arr = value.try_convert::<Vec<char>>()?;
417
+ let set: HashSet<char> = HashSet::from_iter(arr);
418
+ builder = builder.initial_alphabet(set);
419
+ }
420
+
421
+ let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
422
+ if !value.is_nil() {
423
+ builder = builder.vocab_size(value.try_convert()?);
424
+ }
425
+
426
+ let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
427
+ if !value.is_nil() {
428
+ builder = builder.min_frequency(value.try_convert()?);
429
+ }
430
+
431
+ let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
432
+ if !value.is_nil() {
433
+ builder = builder.show_progress(value.try_convert()?);
434
+ }
435
+
436
+ let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
437
+ if !value.is_nil() {
438
+ builder = builder.limit_alphabet(value.try_convert()?);
439
+ }
440
+
441
+ let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
442
+ if !value.is_nil() {
443
+ builder = builder.continuing_subword_prefix(value.try_convert()?);
444
+ }
445
+
446
+ let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
447
+ if !value.is_nil() {
448
+ builder = builder.end_of_word_suffix(value.try_convert()?);
449
+ }
450
+
451
+ if !kwargs.is_empty() {
452
+ // TODO improve message
453
+ return Err(Error::new(exception::arg_error(), "unknown keyword"));
454
+ }
455
+
456
+ Ok(builder.build().into())
457
+ }
458
+ }
459
+
460
+ pub struct RbUnigramTrainer {}
461
+
462
+ impl RbUnigramTrainer {
463
+ pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
464
+ let mut builder = tk::models::unigram::UnigramTrainer::builder();
465
+
466
+ let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
467
+ if !value.is_nil() {
468
+ builder.special_tokens(
469
+ value
470
+ .try_convert::<RArray>()?
471
+ .each()
472
+ .map(|token| {
473
+ if let Ok(content) = token?.try_convert::<String>() {
474
+ Ok(RbAddedToken::from(content, Some(true)).get_token())
475
+ } else {
476
+ todo!()
477
+ }
478
+ })
479
+ .collect::<RbResult<Vec<_>>>()?,
480
+ );
481
+ }
482
+
483
+ let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
484
+ if !value.is_nil() {
485
+ let arr = value.try_convert::<Vec<char>>()?;
486
+ let set: HashSet<char> = HashSet::from_iter(arr);
487
+ builder.initial_alphabet(set);
488
+ }
489
+
490
+ let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
491
+ if !value.is_nil() {
492
+ builder.vocab_size(value.try_convert()?);
493
+ }
494
+
495
+ let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
496
+ if !value.is_nil() {
497
+ builder.show_progress(value.try_convert()?);
498
+ }
499
+
500
+ let value: Value = kwargs.delete(Symbol::new("n_sub_iterations"))?;
501
+ if !value.is_nil() {
502
+ builder.n_sub_iterations(value.try_convert()?);
503
+ }
504
+
505
+ let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
506
+ if !value.is_nil() {
507
+ builder.unk_token(Some(value.try_convert()?));
508
+ }
509
+
510
+ let value: Value = kwargs.delete(Symbol::new("max_piece_length"))?;
511
+ if !value.is_nil() {
512
+ builder.max_piece_length(value.try_convert()?);
513
+ }
514
+
515
+ let value: Value = kwargs.delete(Symbol::new("seed_size"))?;
516
+ if !value.is_nil() {
517
+ builder.seed_size(value.try_convert()?);
518
+ }
519
+
520
+ let value: Value = kwargs.delete(Symbol::new("shrinking_factor"))?;
521
+ if !value.is_nil() {
522
+ builder.shrinking_factor(value.try_convert()?);
523
+ }
524
+
525
+ if !kwargs.is_empty() {
526
+ // TODO improve message
527
+ return Err(Error::new(exception::arg_error(), "unknown keyword"));
528
+ }
529
+
530
+ let trainer = builder.build().map_err(|_| { Error::new(exception::arg_error(), "Cannot build UnigramTrainer") })?;
531
+ Ok(trainer.into())
532
+ }
533
+ }
534
+
535
+ pub struct RbWordLevelTrainer {}
536
+
537
+ impl RbWordLevelTrainer {
538
+ pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
539
+ let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
540
+
541
+ let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
542
+ if !value.is_nil() {
543
+ builder.special_tokens(
544
+ value
545
+ .try_convert::<RArray>()?
546
+ .each()
547
+ .map(|token| {
548
+ if let Ok(content) = token?.try_convert::<String>() {
549
+ Ok(RbAddedToken::from(content, Some(true)).get_token())
550
+ } else {
551
+ todo!()
552
+ }
553
+ })
554
+ .collect::<RbResult<Vec<_>>>()?,
555
+ );
556
+ }
557
+
558
+ let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
559
+ if !value.is_nil() {
560
+ builder.vocab_size(value.try_convert()?);
561
+ }
562
+
563
+ let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
564
+ if !value.is_nil() {
565
+ builder.min_frequency(value.try_convert()?);
566
+ }
567
+
568
+ let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
569
+ if !value.is_nil() {
570
+ builder.show_progress(value.try_convert()?);
571
+ }
572
+
573
+ Ok(builder.build().expect("WordLevelTrainerBuilder cannot fail").into())
574
+ }
575
+ }
576
+
577
+ pub struct RbWordPieceTrainer {}
578
+
579
+ impl RbWordPieceTrainer {
580
+ pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
581
+ let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
582
+
583
+ let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
584
+ if !value.is_nil() {
585
+ builder = builder.special_tokens(
586
+ value
587
+ .try_convert::<RArray>()?
588
+ .each()
589
+ .map(|token| {
590
+ if let Ok(content) = token?.try_convert::<String>() {
591
+ Ok(RbAddedToken::from(content, Some(true)).get_token())
592
+ } else {
593
+ todo!()
594
+ }
595
+ })
596
+ .collect::<RbResult<Vec<_>>>()?,
597
+ );
598
+ }
599
+
600
+ let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
601
+ if !value.is_nil() {
602
+ let arr = value.try_convert::<Vec<char>>()?;
603
+ let set: HashSet<char> = HashSet::from_iter(arr);
604
+ builder = builder.initial_alphabet(set);
605
+ }
606
+
607
+ let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
608
+ if !value.is_nil() {
609
+ builder = builder.vocab_size(value.try_convert()?);
610
+ }
611
+
612
+ let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
613
+ if !value.is_nil() {
614
+ builder = builder.min_frequency(value.try_convert()?);
615
+ }
616
+
617
+ let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
618
+ if !value.is_nil() {
619
+ builder = builder.show_progress(value.try_convert()?);
620
+ }
621
+
622
+ let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
623
+ if !value.is_nil() {
624
+ builder = builder.limit_alphabet(value.try_convert()?);
625
+ }
626
+
627
+ let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
628
+ if !value.is_nil() {
629
+ builder = builder.continuing_subword_prefix(value.try_convert()?);
630
+ }
631
+
632
+ let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
633
+ if !value.is_nil() {
634
+ builder = builder.end_of_word_suffix(value.try_convert()?);
635
+ }
636
+
637
+ if !kwargs.is_empty() {
638
+ // TODO improve message
639
+ return Err(Error::new(exception::arg_error(), "unknown keyword"));
640
+ }
641
+
642
+ Ok(builder.build().into())
643
+ }
644
+ }
645
+
646
+ unsafe impl TypedData for RbTrainer {
647
+ fn class() -> RClass {
648
+ *memoize!(RClass: {
649
+ let class: RClass = crate::trainers().const_get("Trainer").unwrap();
650
+ class.undef_alloc_func();
651
+ class
652
+ })
653
+ }
654
+
655
+ fn data_type() -> &'static DataType {
656
+ memoize!(DataType: DataTypeBuilder::<RbTrainer>::new("Tokenizers::Trainers::Trainer").build())
657
+ }
658
+
659
+ fn class_for(value: &Self) -> RClass {
660
+ match *value.trainer.read().unwrap() {
661
+ TrainerWrapper::BpeTrainer(_) => *memoize!(RClass: {
662
+ let class: RClass = crate::trainers().const_get("BpeTrainer").unwrap();
663
+ class.undef_alloc_func();
664
+ class
665
+ }),
666
+ TrainerWrapper::UnigramTrainer(_) => *memoize!(RClass: {
667
+ let class: RClass = crate::trainers().const_get("UnigramTrainer").unwrap();
668
+ class.undef_alloc_func();
669
+ class
670
+ }),
671
+ TrainerWrapper::WordLevelTrainer(_) => *memoize!(RClass: {
672
+ let class: RClass = crate::trainers().const_get("WordLevelTrainer").unwrap();
673
+ class.undef_alloc_func();
674
+ class
675
+ }),
676
+ TrainerWrapper::WordPieceTrainer(_) => *memoize!(RClass: {
677
+ let class: RClass = crate::trainers().const_get("WordPieceTrainer").unwrap();
678
+ class.undef_alloc_func();
679
+ class
680
+ }),
681
+ }
682
+ }
683
+ }
684
+
685
+ pub fn trainers(module: &RModule) -> RbResult<()> {
686
+ let trainer = module.define_class("Trainer", Default::default())?;
687
+
688
+ let class = module.define_class("BpeTrainer", trainer)?;
689
+ class.define_singleton_method("_new", function!(RbBpeTrainer::new, 1))?;
690
+ class.define_method("vocab_size", method!(RbTrainer::bpe_trainer_vocab_size, 0))?;
691
+ class.define_method("vocab_size=", method!(RbTrainer::bpe_trainer_set_vocab_size, 1))?;
692
+ class.define_method("min_frequency", method!(RbTrainer::bpe_trainer_min_frequency, 0))?;
693
+ class.define_method("min_frequency=", method!(RbTrainer::bpe_trainer_set_min_frequency, 1))?;
694
+ class.define_method("show_progress", method!(RbTrainer::bpe_trainer_show_progress, 0))?;
695
+ class.define_method("show_progress=", method!(RbTrainer::bpe_trainer_set_show_progress, 1))?;
696
+ class.define_method("special_tokens", method!(RbTrainer::bpe_trainer_special_tokens, 0))?;
697
+ class.define_method("special_tokens=", method!(RbTrainer::bpe_trainer_set_special_tokens, 1))?;
698
+ class.define_method("limit_alphabet", method!(RbTrainer::bpe_trainer_limit_alphabet, 0))?;
699
+ class.define_method("limit_alphabet=", method!(RbTrainer::bpe_trainer_set_limit_alphabet, 1))?;
700
+ class.define_method("initial_alphabet", method!(RbTrainer::bpe_trainer_initial_alphabet, 0))?;
701
+ class.define_method("initial_alphabet=", method!(RbTrainer::bpe_trainer_set_initial_alphabet, 1))?;
702
+ class.define_method("continuing_subword_prefix", method!(RbTrainer::bpe_trainer_continuing_subword_prefix, 0))?;
703
+ class.define_method("continuing_subword_prefix=", method!(RbTrainer::bpe_trainer_set_continuing_subword_prefix, 1))?;
704
+ class.define_method("end_of_word_suffix", method!(RbTrainer::bpe_trainer_end_of_word_suffix, 0))?;
705
+ class.define_method("end_of_word_suffix=", method!(RbTrainer::bpe_trainer_set_end_of_word_suffix, 1))?;
706
+
707
+ let class = module.define_class("UnigramTrainer", trainer)?;
708
+ class.define_singleton_method("_new", function!(RbUnigramTrainer::new, 1))?;
709
+ class.define_method("vocab_size", method!(RbTrainer::unigram_trainer_vocab_size, 0))?;
710
+ class.define_method("vocab_size=", method!(RbTrainer::unigram_trainer_set_vocab_size, 1))?;
711
+ class.define_method("show_progress", method!(RbTrainer::unigram_trainer_show_progress, 0))?;
712
+ class.define_method("show_progress=", method!(RbTrainer::unigram_trainer_set_show_progress, 1))?;
713
+ class.define_method("special_tokens", method!(RbTrainer::unigram_trainer_special_tokens, 0))?;
714
+ class.define_method("special_tokens=", method!(RbTrainer::unigram_trainer_set_special_tokens, 1))?;
715
+ class.define_method("initial_alphabet", method!(RbTrainer::unigram_trainer_initial_alphabet, 0))?;
716
+ class.define_method("initial_alphabet=", method!(RbTrainer::unigram_trainer_set_initial_alphabet, 1))?;
717
+
718
+ let class = module.define_class("WordLevelTrainer", trainer)?;
719
+ class.define_singleton_method("_new", function!(RbWordLevelTrainer::new, 1))?;
720
+ class.define_method("vocab_size", method!(RbTrainer::word_level_trainer_vocab_size, 0))?;
721
+ class.define_method("vocab_size=", method!(RbTrainer::word_level_trainer_set_vocab_size, 1))?;
722
+ class.define_method("min_frequency", method!(RbTrainer::word_level_trainer_min_frequency, 0))?;
723
+ class.define_method("min_frequency=", method!(RbTrainer::word_level_trainer_set_min_frequency, 1))?;
724
+ class.define_method("show_progress", method!(RbTrainer::word_level_trainer_show_progress, 0))?;
725
+ class.define_method("show_progress=", method!(RbTrainer::word_level_trainer_set_show_progress, 1))?;
726
+ class.define_method("special_tokens", method!(RbTrainer::word_level_trainer_special_tokens, 0))?;
727
+ class.define_method("special_tokens=", method!(RbTrainer::word_level_trainer_set_special_tokens, 1))?;
728
+
729
+ let class = module.define_class("WordPieceTrainer", trainer)?;
730
+ class.define_singleton_method("_new", function!(RbWordPieceTrainer::new, 1))?;
731
+ class.define_method("vocab_size", method!(RbTrainer::word_piece_trainer_vocab_size, 0))?;
732
+ class.define_method("vocab_size=", method!(RbTrainer::word_piece_trainer_set_vocab_size, 1))?;
733
+ class.define_method("min_frequency", method!(RbTrainer::word_piece_trainer_min_frequency, 0))?;
734
+ class.define_method("min_frequency=", method!(RbTrainer::word_piece_trainer_set_min_frequency, 1))?;
735
+ class.define_method("show_progress", method!(RbTrainer::word_piece_trainer_show_progress, 0))?;
736
+ class.define_method("show_progress=", method!(RbTrainer::word_piece_trainer_set_show_progress, 1))?;
737
+ class.define_method("special_tokens", method!(RbTrainer::word_piece_trainer_special_tokens, 0))?;
738
+ class.define_method("special_tokens=", method!(RbTrainer::word_piece_trainer_set_special_tokens, 1))?;
739
+ class.define_method("limit_alphabet", method!(RbTrainer::word_piece_trainer_limit_alphabet, 0))?;
740
+ class.define_method("limit_alphabet=", method!(RbTrainer::word_piece_trainer_set_limit_alphabet, 1))?;
741
+ class.define_method("initial_alphabet", method!(RbTrainer::word_piece_trainer_initial_alphabet, 0))?;
742
+ class.define_method("initial_alphabet=", method!(RbTrainer::word_piece_trainer_set_initial_alphabet, 1))?;
743
+ class.define_method("continuing_subword_prefix", method!(RbTrainer::word_piece_trainer_continuing_subword_prefix, 0))?;
744
+ class.define_method("continuing_subword_prefix=", method!(RbTrainer::word_piece_trainer_set_continuing_subword_prefix, 1))?;
745
+ class.define_method("end_of_word_suffix", method!(RbTrainer::word_piece_trainer_end_of_word_suffix, 0))?;
746
+ class.define_method("end_of_word_suffix=", method!(RbTrainer::word_piece_trainer_set_end_of_word_suffix, 1))?;
747
+
748
+ Ok(())
749
+ }