tokenizers 0.5.5 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Cargo.lock +124 -53
- data/ext/tokenizers/Cargo.toml +4 -3
- data/ext/tokenizers/src/encoding.rs +10 -8
- data/ext/tokenizers/src/models.rs +37 -24
- data/ext/tokenizers/src/normalizers.rs +1 -2
- data/ext/tokenizers/src/pre_tokenizers.rs +5 -5
- data/ext/tokenizers/src/tokenizer.rs +61 -49
- data/ext/tokenizers/src/trainers.rs +60 -50
- data/ext/tokenizers/src/utils/normalization.rs +3 -2
- data/ext/tokenizers/src/utils/regex.rs +5 -4
- data/lib/tokenizers/from_pretrained.rb +2 -2
- data/lib/tokenizers/trainers/unigram_trainer.rb +10 -9
- data/lib/tokenizers/trainers/word_piece_trainer.rb +10 -9
- data/lib/tokenizers/version.rb +1 -1
- metadata +3 -3
@@ -1,8 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
3
|
use magnus::{
|
4
|
-
data_type_builder,
|
5
|
-
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
|
5
|
+
Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
|
6
6
|
};
|
7
7
|
|
8
8
|
use serde::ser::SerializeStruct;
|
@@ -278,16 +278,16 @@ impl RbSequence {
|
|
278
278
|
}
|
279
279
|
|
280
280
|
pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
|
281
|
+
let ruby = Ruby::get().unwrap();
|
281
282
|
let scheme = match string.as_str() {
|
282
283
|
"first" => PrependScheme::First,
|
283
284
|
"never" => PrependScheme::Never,
|
284
285
|
"always" => PrependScheme::Always,
|
285
286
|
_ => {
|
286
287
|
return Err(Error::new(
|
287
|
-
|
288
|
+
ruby.exception_arg_error(),
|
288
289
|
format!(
|
289
|
-
"{} is an unknown variant, should be one of ['first', 'never', 'always']"
|
290
|
-
string
|
290
|
+
"{string} is an unknown variant, should be one of ['first', 'never', 'always']"
|
291
291
|
),
|
292
292
|
));
|
293
293
|
}
|
@@ -4,7 +4,7 @@ use std::path::PathBuf;
|
|
4
4
|
use std::str::FromStr;
|
5
5
|
|
6
6
|
use magnus::prelude::*;
|
7
|
-
use magnus::{
|
7
|
+
use magnus::{Error, RArray, RHash, RString, Ruby, TryConvert, Value};
|
8
8
|
use tk::tokenizer::{
|
9
9
|
Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
|
10
10
|
TruncationParams, TruncationStrategy,
|
@@ -78,37 +78,37 @@ impl From<tk::AddedToken> for RbAddedToken {
|
|
78
78
|
}
|
79
79
|
|
80
80
|
impl RbAddedToken {
|
81
|
-
pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
|
81
|
+
pub fn new(ruby: &Ruby, content: Option<String>, kwargs: RHash) -> RbResult<Self> {
|
82
82
|
let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
|
83
83
|
|
84
|
-
let value: Value = kwargs.delete(
|
84
|
+
let value: Value = kwargs.delete(ruby.to_symbol("single_word"))?;
|
85
85
|
if !value.is_nil() {
|
86
86
|
token.single_word = TryConvert::try_convert(value)?;
|
87
87
|
}
|
88
88
|
|
89
|
-
let value: Value = kwargs.delete(
|
89
|
+
let value: Value = kwargs.delete(ruby.to_symbol("lstrip"))?;
|
90
90
|
if !value.is_nil() {
|
91
91
|
token.lstrip = TryConvert::try_convert(value)?;
|
92
92
|
}
|
93
93
|
|
94
|
-
let value: Value = kwargs.delete(
|
94
|
+
let value: Value = kwargs.delete(ruby.to_symbol("rstrip"))?;
|
95
95
|
if !value.is_nil() {
|
96
96
|
token.rstrip = TryConvert::try_convert(value)?;
|
97
97
|
}
|
98
98
|
|
99
|
-
let value: Value = kwargs.delete(
|
99
|
+
let value: Value = kwargs.delete(ruby.to_symbol("normalized"))?;
|
100
100
|
if !value.is_nil() {
|
101
101
|
token.normalized = TryConvert::try_convert(value)?;
|
102
102
|
}
|
103
103
|
|
104
|
-
let value: Value = kwargs.delete(
|
104
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special"))?;
|
105
105
|
if !value.is_nil() {
|
106
106
|
token.special = TryConvert::try_convert(value)?;
|
107
107
|
}
|
108
108
|
|
109
109
|
if !kwargs.is_empty() {
|
110
110
|
// TODO improve message
|
111
|
-
return Err(Error::new(
|
111
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
112
112
|
}
|
113
113
|
|
114
114
|
Ok(token)
|
@@ -189,6 +189,8 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
189
189
|
|
190
190
|
impl TryConvert for TextEncodeInput<'_> {
|
191
191
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
192
|
+
let ruby = Ruby::get_with(ob);
|
193
|
+
|
192
194
|
if let Ok(i) = TextInputSequence::try_convert(ob) {
|
193
195
|
return Ok(Self(i.into()));
|
194
196
|
}
|
@@ -204,7 +206,7 @@ impl TryConvert for TextEncodeInput<'_> {
|
|
204
206
|
}
|
205
207
|
}
|
206
208
|
Err(Error::new(
|
207
|
-
|
209
|
+
ruby.exception_type_error(),
|
208
210
|
"TextEncodeInput must be a string or pair of strings",
|
209
211
|
))
|
210
212
|
}
|
@@ -220,6 +222,8 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
220
222
|
|
221
223
|
impl TryConvert for PreTokenizedEncodeInput<'_> {
|
222
224
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
225
|
+
let ruby = Ruby::get_with(ob);
|
226
|
+
|
223
227
|
if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
|
224
228
|
return Ok(Self(i.into()));
|
225
229
|
}
|
@@ -237,7 +241,7 @@ impl TryConvert for PreTokenizedEncodeInput<'_> {
|
|
237
241
|
}
|
238
242
|
}
|
239
243
|
Err(Error::new(
|
240
|
-
|
244
|
+
ruby.exception_type_error(),
|
241
245
|
"PreTokenizedEncodeInput must be an array of strings or pair of arrays",
|
242
246
|
))
|
243
247
|
}
|
@@ -351,7 +355,8 @@ impl RbTokenizer {
|
|
351
355
|
}
|
352
356
|
|
353
357
|
pub fn encode_batch(
|
354
|
-
&
|
358
|
+
ruby: &Ruby,
|
359
|
+
rb_self: &Self,
|
355
360
|
input: RArray,
|
356
361
|
is_pretokenized: bool,
|
357
362
|
add_special_tokens: bool,
|
@@ -367,16 +372,16 @@ impl RbTokenizer {
|
|
367
372
|
Ok(input)
|
368
373
|
})
|
369
374
|
.collect::<RbResult<Vec<tk::EncodeInput>>>()?;
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
.map(Into::<RbEncoding>::into)
|
377
|
-
|
378
|
-
|
379
|
-
|
375
|
+
Ok(ruby.ary_from_iter(
|
376
|
+
rb_self
|
377
|
+
.tokenizer
|
378
|
+
.borrow()
|
379
|
+
.encode_batch_char_offsets(input, add_special_tokens)
|
380
|
+
.map(|encodings| {
|
381
|
+
ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into))
|
382
|
+
})
|
383
|
+
.map_err(RbError::from),
|
384
|
+
))
|
380
385
|
}
|
381
386
|
|
382
387
|
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
|
@@ -453,10 +458,10 @@ impl RbTokenizer {
|
|
453
458
|
}
|
454
459
|
|
455
460
|
// TODO support more kwargs
|
456
|
-
pub fn enable_padding(&
|
461
|
+
pub fn enable_padding(ruby: &Ruby, rb_self: &Self, kwargs: RHash) -> RbResult<()> {
|
457
462
|
let mut params = PaddingParams::default();
|
458
463
|
|
459
|
-
let value: Value = kwargs.delete(
|
464
|
+
let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
|
460
465
|
if !value.is_nil() {
|
461
466
|
let dir_str = String::try_convert(value)?;
|
462
467
|
params.direction = match dir_str.as_str() {
|
@@ -464,34 +469,34 @@ impl RbTokenizer {
|
|
464
469
|
"right" => PaddingDirection::Right,
|
465
470
|
_ => {
|
466
471
|
return Err(Error::new(
|
467
|
-
|
472
|
+
ruby.exception_arg_error(),
|
468
473
|
"The direction value must be 'left' or 'right'",
|
469
474
|
))
|
470
475
|
}
|
471
476
|
}
|
472
477
|
}
|
473
478
|
|
474
|
-
let value: Value = kwargs.delete(
|
479
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_to_multiple_of"))?;
|
475
480
|
if !value.is_nil() {
|
476
481
|
params.pad_to_multiple_of = TryConvert::try_convert(value)?;
|
477
482
|
}
|
478
483
|
|
479
|
-
let value: Value = kwargs.delete(
|
484
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_id"))?;
|
480
485
|
if !value.is_nil() {
|
481
486
|
params.pad_id = TryConvert::try_convert(value)?;
|
482
487
|
}
|
483
488
|
|
484
|
-
let value: Value = kwargs.delete(
|
489
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_type_id"))?;
|
485
490
|
if !value.is_nil() {
|
486
491
|
params.pad_type_id = TryConvert::try_convert(value)?;
|
487
492
|
}
|
488
493
|
|
489
|
-
let value: Value = kwargs.delete(
|
494
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_token"))?;
|
490
495
|
if !value.is_nil() {
|
491
496
|
params.pad_token = TryConvert::try_convert(value)?;
|
492
497
|
}
|
493
498
|
|
494
|
-
let value: Value = kwargs.delete(
|
499
|
+
let value: Value = kwargs.delete(ruby.to_symbol("length"))?;
|
495
500
|
if value.is_nil() {
|
496
501
|
params.strategy = PaddingStrategy::BatchLongest;
|
497
502
|
} else {
|
@@ -500,10 +505,10 @@ impl RbTokenizer {
|
|
500
505
|
|
501
506
|
if !kwargs.is_empty() {
|
502
507
|
// TODO improve message
|
503
|
-
return Err(Error::new(
|
508
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
504
509
|
}
|
505
510
|
|
506
|
-
|
511
|
+
rb_self.tokenizer.borrow_mut().with_padding(Some(params));
|
507
512
|
|
508
513
|
Ok(())
|
509
514
|
}
|
@@ -512,12 +517,13 @@ impl RbTokenizer {
|
|
512
517
|
self.tokenizer.borrow_mut().with_padding(None);
|
513
518
|
}
|
514
519
|
|
515
|
-
pub fn padding(&
|
516
|
-
|
520
|
+
pub fn padding(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
|
521
|
+
rb_self
|
522
|
+
.tokenizer
|
517
523
|
.borrow()
|
518
524
|
.get_padding()
|
519
525
|
.map_or(Ok(None), |params| {
|
520
|
-
let ret_hash =
|
526
|
+
let ret_hash = ruby.hash_new();
|
521
527
|
|
522
528
|
ret_hash.aset(
|
523
529
|
"length",
|
@@ -536,18 +542,23 @@ impl RbTokenizer {
|
|
536
542
|
})
|
537
543
|
}
|
538
544
|
|
539
|
-
pub fn enable_truncation(
|
545
|
+
pub fn enable_truncation(
|
546
|
+
ruby: &Ruby,
|
547
|
+
rb_self: &Self,
|
548
|
+
max_length: usize,
|
549
|
+
kwargs: RHash,
|
550
|
+
) -> RbResult<()> {
|
540
551
|
let mut params = TruncationParams {
|
541
552
|
max_length,
|
542
553
|
..Default::default()
|
543
554
|
};
|
544
555
|
|
545
|
-
let value: Value = kwargs.delete(
|
556
|
+
let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
|
546
557
|
if !value.is_nil() {
|
547
558
|
params.stride = TryConvert::try_convert(value)?;
|
548
559
|
}
|
549
560
|
|
550
|
-
let value: Value = kwargs.delete(
|
561
|
+
let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
|
551
562
|
if !value.is_nil() {
|
552
563
|
let strategy_str = String::try_convert(value)?;
|
553
564
|
params.strategy = match strategy_str.as_str() {
|
@@ -555,13 +566,13 @@ impl RbTokenizer {
|
|
555
566
|
"only_first" => TruncationStrategy::OnlyFirst,
|
556
567
|
"only_second" => TruncationStrategy::OnlySecond,
|
557
568
|
_ => return Err(Error::new(
|
558
|
-
|
569
|
+
ruby.exception_arg_error(),
|
559
570
|
"The strategy value must be 'longest_first', 'only_first', or 'only_second'",
|
560
571
|
)),
|
561
572
|
}
|
562
573
|
}
|
563
574
|
|
564
|
-
let value: Value = kwargs.delete(
|
575
|
+
let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
|
565
576
|
if !value.is_nil() {
|
566
577
|
let dir_str = String::try_convert(value)?;
|
567
578
|
params.direction = match dir_str.as_str() {
|
@@ -569,7 +580,7 @@ impl RbTokenizer {
|
|
569
580
|
"right" => TruncationDirection::Right,
|
570
581
|
_ => {
|
571
582
|
return Err(Error::new(
|
572
|
-
|
583
|
+
ruby.exception_arg_error(),
|
573
584
|
"The direction value must be 'left' or 'right'",
|
574
585
|
))
|
575
586
|
}
|
@@ -578,12 +589,12 @@ impl RbTokenizer {
|
|
578
589
|
|
579
590
|
if !kwargs.is_empty() {
|
580
591
|
// TODO improve message
|
581
|
-
return Err(Error::new(
|
592
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
582
593
|
}
|
583
594
|
|
584
|
-
if let Err(error_message) =
|
595
|
+
if let Err(error_message) = rb_self.tokenizer.borrow_mut().with_truncation(Some(params)) {
|
585
596
|
return Err(Error::new(
|
586
|
-
|
597
|
+
ruby.exception_arg_error(),
|
587
598
|
error_message.to_string(),
|
588
599
|
));
|
589
600
|
}
|
@@ -598,12 +609,13 @@ impl RbTokenizer {
|
|
598
609
|
.expect("Failed to set truncation to `None`! This should never happen");
|
599
610
|
}
|
600
611
|
|
601
|
-
pub fn truncation(&
|
602
|
-
|
612
|
+
pub fn truncation(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
|
613
|
+
rb_self
|
614
|
+
.tokenizer
|
603
615
|
.borrow()
|
604
616
|
.get_truncation()
|
605
617
|
.map_or(Ok(None), |params| {
|
606
|
-
let ret_hash =
|
618
|
+
let ret_hash = ruby.hash_new();
|
607
619
|
|
608
620
|
ret_hash.aset("max_length", params.max_length)?;
|
609
621
|
ret_hash.aset("stride", params.stride)?;
|
@@ -629,10 +641,10 @@ impl RbTokenizer {
|
|
629
641
|
self.tokenizer.borrow().get_vocab_size(with_added_tokens)
|
630
642
|
}
|
631
643
|
|
632
|
-
pub fn get_added_tokens_decoder(&
|
633
|
-
let sorted_map =
|
644
|
+
pub fn get_added_tokens_decoder(ruby: &Ruby, rb_self: &Self) -> RbResult<RHash> {
|
645
|
+
let sorted_map = ruby.hash_new();
|
634
646
|
|
635
|
-
for (key, value) in
|
647
|
+
for (key, value) in rb_self.tokenizer.borrow().get_added_tokens_decoder() {
|
636
648
|
sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
|
637
649
|
}
|
638
650
|
|
@@ -1,13 +1,11 @@
|
|
1
|
-
use std::collections::HashSet;
|
2
1
|
use std::sync::{Arc, RwLock};
|
3
2
|
|
4
3
|
use crate::models::RbModel;
|
5
4
|
use crate::tokenizer::RbAddedToken;
|
6
5
|
use magnus::prelude::*;
|
7
6
|
use magnus::{
|
8
|
-
data_type_builder,
|
9
|
-
|
10
|
-
TryConvert, TypedData, Value,
|
7
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
|
8
|
+
Module, Object, RArray, RClass, RHash, RModule, Ruby, TryConvert, TypedData, Value,
|
11
9
|
};
|
12
10
|
use serde::{Deserialize, Serialize};
|
13
11
|
use tk::models::TrainerWrapper;
|
@@ -391,10 +389,10 @@ where
|
|
391
389
|
pub struct RbBpeTrainer {}
|
392
390
|
|
393
391
|
impl RbBpeTrainer {
|
394
|
-
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
392
|
+
pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
|
395
393
|
let mut builder = tk::models::bpe::BpeTrainer::builder();
|
396
394
|
|
397
|
-
let value: Value = kwargs.delete(
|
395
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
|
398
396
|
if !value.is_nil() {
|
399
397
|
builder = builder.special_tokens(
|
400
398
|
RArray::try_convert(value)?
|
@@ -410,46 +408,50 @@ impl RbBpeTrainer {
|
|
410
408
|
);
|
411
409
|
}
|
412
410
|
|
413
|
-
let value: Value = kwargs.delete(
|
411
|
+
let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
|
414
412
|
if !value.is_nil() {
|
415
|
-
let
|
416
|
-
|
417
|
-
|
413
|
+
let alphabet = Vec::<String>::try_convert(value)?;
|
414
|
+
builder = builder.initial_alphabet(
|
415
|
+
alphabet
|
416
|
+
.into_iter()
|
417
|
+
.filter_map(|s| s.chars().next())
|
418
|
+
.collect(),
|
419
|
+
);
|
418
420
|
}
|
419
421
|
|
420
|
-
let value: Value = kwargs.delete(
|
422
|
+
let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
|
421
423
|
if !value.is_nil() {
|
422
424
|
builder = builder.vocab_size(TryConvert::try_convert(value)?);
|
423
425
|
}
|
424
426
|
|
425
|
-
let value: Value = kwargs.delete(
|
427
|
+
let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
|
426
428
|
if !value.is_nil() {
|
427
429
|
builder = builder.min_frequency(TryConvert::try_convert(value)?);
|
428
430
|
}
|
429
431
|
|
430
|
-
let value: Value = kwargs.delete(
|
432
|
+
let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
|
431
433
|
if !value.is_nil() {
|
432
434
|
builder = builder.show_progress(TryConvert::try_convert(value)?);
|
433
435
|
}
|
434
436
|
|
435
|
-
let value: Value = kwargs.delete(
|
437
|
+
let value: Value = kwargs.delete(ruby.to_symbol("limit_alphabet"))?;
|
436
438
|
if !value.is_nil() {
|
437
439
|
builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
|
438
440
|
}
|
439
441
|
|
440
|
-
let value: Value = kwargs.delete(
|
442
|
+
let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
|
441
443
|
if !value.is_nil() {
|
442
444
|
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
443
445
|
}
|
444
446
|
|
445
|
-
let value: Value = kwargs.delete(
|
447
|
+
let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
|
446
448
|
if !value.is_nil() {
|
447
449
|
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
448
450
|
}
|
449
451
|
|
450
452
|
if !kwargs.is_empty() {
|
451
453
|
// TODO improve message
|
452
|
-
return Err(Error::new(
|
454
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
453
455
|
}
|
454
456
|
|
455
457
|
Ok(builder.build().into())
|
@@ -459,10 +461,10 @@ impl RbBpeTrainer {
|
|
459
461
|
pub struct RbUnigramTrainer {}
|
460
462
|
|
461
463
|
impl RbUnigramTrainer {
|
462
|
-
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
464
|
+
pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
|
463
465
|
let mut builder = tk::models::unigram::UnigramTrainer::builder();
|
464
466
|
|
465
|
-
let value: Value = kwargs.delete(
|
467
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
|
466
468
|
if !value.is_nil() {
|
467
469
|
builder.special_tokens(
|
468
470
|
RArray::try_convert(value)?
|
@@ -478,56 +480,60 @@ impl RbUnigramTrainer {
|
|
478
480
|
);
|
479
481
|
}
|
480
482
|
|
481
|
-
let value: Value = kwargs.delete(
|
483
|
+
let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
|
482
484
|
if !value.is_nil() {
|
483
|
-
let
|
484
|
-
|
485
|
-
|
485
|
+
let alphabet = Vec::<String>::try_convert(value)?;
|
486
|
+
builder.initial_alphabet(
|
487
|
+
alphabet
|
488
|
+
.into_iter()
|
489
|
+
.filter_map(|s| s.chars().next())
|
490
|
+
.collect(),
|
491
|
+
);
|
486
492
|
}
|
487
493
|
|
488
|
-
let value: Value = kwargs.delete(
|
494
|
+
let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
|
489
495
|
if !value.is_nil() {
|
490
496
|
builder.vocab_size(TryConvert::try_convert(value)?);
|
491
497
|
}
|
492
498
|
|
493
|
-
let value: Value = kwargs.delete(
|
499
|
+
let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
|
494
500
|
if !value.is_nil() {
|
495
501
|
builder.show_progress(TryConvert::try_convert(value)?);
|
496
502
|
}
|
497
503
|
|
498
|
-
let value: Value = kwargs.delete(
|
504
|
+
let value: Value = kwargs.delete(ruby.to_symbol("n_sub_iterations"))?;
|
499
505
|
if !value.is_nil() {
|
500
506
|
builder.n_sub_iterations(TryConvert::try_convert(value)?);
|
501
507
|
}
|
502
508
|
|
503
|
-
let value: Value = kwargs.delete(
|
509
|
+
let value: Value = kwargs.delete(ruby.to_symbol("unk_token"))?;
|
504
510
|
if !value.is_nil() {
|
505
511
|
builder.unk_token(Some(TryConvert::try_convert(value)?));
|
506
512
|
}
|
507
513
|
|
508
|
-
let value: Value = kwargs.delete(
|
514
|
+
let value: Value = kwargs.delete(ruby.to_symbol("max_piece_length"))?;
|
509
515
|
if !value.is_nil() {
|
510
516
|
builder.max_piece_length(TryConvert::try_convert(value)?);
|
511
517
|
}
|
512
518
|
|
513
|
-
let value: Value = kwargs.delete(
|
519
|
+
let value: Value = kwargs.delete(ruby.to_symbol("seed_size"))?;
|
514
520
|
if !value.is_nil() {
|
515
521
|
builder.seed_size(TryConvert::try_convert(value)?);
|
516
522
|
}
|
517
523
|
|
518
|
-
let value: Value = kwargs.delete(
|
524
|
+
let value: Value = kwargs.delete(ruby.to_symbol("shrinking_factor"))?;
|
519
525
|
if !value.is_nil() {
|
520
526
|
builder.shrinking_factor(TryConvert::try_convert(value)?);
|
521
527
|
}
|
522
528
|
|
523
529
|
if !kwargs.is_empty() {
|
524
530
|
// TODO improve message
|
525
|
-
return Err(Error::new(
|
531
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
526
532
|
}
|
527
533
|
|
528
534
|
let trainer = builder
|
529
535
|
.build()
|
530
|
-
.map_err(|_| Error::new(
|
536
|
+
.map_err(|_| Error::new(ruby.exception_arg_error(), "Cannot build UnigramTrainer"))?;
|
531
537
|
Ok(trainer.into())
|
532
538
|
}
|
533
539
|
}
|
@@ -535,10 +541,10 @@ impl RbUnigramTrainer {
|
|
535
541
|
pub struct RbWordLevelTrainer {}
|
536
542
|
|
537
543
|
impl RbWordLevelTrainer {
|
538
|
-
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
544
|
+
pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
|
539
545
|
let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
|
540
546
|
|
541
|
-
let value: Value = kwargs.delete(
|
547
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
|
542
548
|
if !value.is_nil() {
|
543
549
|
builder.special_tokens(
|
544
550
|
RArray::try_convert(value)?
|
@@ -554,17 +560,17 @@ impl RbWordLevelTrainer {
|
|
554
560
|
);
|
555
561
|
}
|
556
562
|
|
557
|
-
let value: Value = kwargs.delete(
|
563
|
+
let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
|
558
564
|
if !value.is_nil() {
|
559
565
|
builder.vocab_size(TryConvert::try_convert(value)?);
|
560
566
|
}
|
561
567
|
|
562
|
-
let value: Value = kwargs.delete(
|
568
|
+
let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
|
563
569
|
if !value.is_nil() {
|
564
570
|
builder.min_frequency(TryConvert::try_convert(value)?);
|
565
571
|
}
|
566
572
|
|
567
|
-
let value: Value = kwargs.delete(
|
573
|
+
let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
|
568
574
|
if !value.is_nil() {
|
569
575
|
builder.show_progress(TryConvert::try_convert(value)?);
|
570
576
|
}
|
@@ -579,10 +585,10 @@ impl RbWordLevelTrainer {
|
|
579
585
|
pub struct RbWordPieceTrainer {}
|
580
586
|
|
581
587
|
impl RbWordPieceTrainer {
|
582
|
-
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
588
|
+
pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
|
583
589
|
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
|
584
590
|
|
585
|
-
let value: Value = kwargs.delete(
|
591
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
|
586
592
|
if !value.is_nil() {
|
587
593
|
builder = builder.special_tokens(
|
588
594
|
RArray::try_convert(value)?
|
@@ -598,46 +604,50 @@ impl RbWordPieceTrainer {
|
|
598
604
|
);
|
599
605
|
}
|
600
606
|
|
601
|
-
let value: Value = kwargs.delete(
|
607
|
+
let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
|
602
608
|
if !value.is_nil() {
|
603
|
-
let
|
604
|
-
|
605
|
-
|
609
|
+
let alphabet = Vec::<String>::try_convert(value)?;
|
610
|
+
builder = builder.initial_alphabet(
|
611
|
+
alphabet
|
612
|
+
.into_iter()
|
613
|
+
.filter_map(|s| s.chars().next())
|
614
|
+
.collect(),
|
615
|
+
);
|
606
616
|
}
|
607
617
|
|
608
|
-
let value: Value = kwargs.delete(
|
618
|
+
let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
|
609
619
|
if !value.is_nil() {
|
610
620
|
builder = builder.vocab_size(TryConvert::try_convert(value)?);
|
611
621
|
}
|
612
622
|
|
613
|
-
let value: Value = kwargs.delete(
|
623
|
+
let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
|
614
624
|
if !value.is_nil() {
|
615
625
|
builder = builder.min_frequency(TryConvert::try_convert(value)?);
|
616
626
|
}
|
617
627
|
|
618
|
-
let value: Value = kwargs.delete(
|
628
|
+
let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
|
619
629
|
if !value.is_nil() {
|
620
630
|
builder = builder.show_progress(TryConvert::try_convert(value)?);
|
621
631
|
}
|
622
632
|
|
623
|
-
let value: Value = kwargs.delete(
|
633
|
+
let value: Value = kwargs.delete(ruby.to_symbol("limit_alphabet"))?;
|
624
634
|
if !value.is_nil() {
|
625
635
|
builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
|
626
636
|
}
|
627
637
|
|
628
|
-
let value: Value = kwargs.delete(
|
638
|
+
let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
|
629
639
|
if !value.is_nil() {
|
630
640
|
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
631
641
|
}
|
632
642
|
|
633
|
-
let value: Value = kwargs.delete(
|
643
|
+
let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
|
634
644
|
if !value.is_nil() {
|
635
645
|
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
636
646
|
}
|
637
647
|
|
638
648
|
if !kwargs.is_empty() {
|
639
649
|
// TODO improve message
|
640
|
-
return Err(Error::new(
|
650
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
641
651
|
}
|
642
652
|
|
643
653
|
Ok(builder.build().into())
|
@@ -1,7 +1,7 @@
|
|
1
1
|
use super::regex::{regex, RbRegex};
|
2
2
|
use crate::RbResult;
|
3
3
|
use magnus::prelude::*;
|
4
|
-
use magnus::{
|
4
|
+
use magnus::{Error, Ruby, TryConvert, Value};
|
5
5
|
use tk::normalizer::SplitDelimiterBehavior;
|
6
6
|
use tk::pattern::Pattern;
|
7
7
|
|
@@ -62,6 +62,7 @@ pub struct RbSplitDelimiterBehavior(pub SplitDelimiterBehavior);
|
|
62
62
|
|
63
63
|
impl TryConvert for RbSplitDelimiterBehavior {
|
64
64
|
fn try_convert(obj: Value) -> RbResult<Self> {
|
65
|
+
let ruby = Ruby::get_with(obj);
|
65
66
|
let s = String::try_convert(obj)?;
|
66
67
|
|
67
68
|
Ok(Self(match s.as_str() {
|
@@ -71,7 +72,7 @@ impl TryConvert for RbSplitDelimiterBehavior {
|
|
71
72
|
"merged_with_next" => Ok(SplitDelimiterBehavior::MergedWithNext),
|
72
73
|
"contiguous" => Ok(SplitDelimiterBehavior::Contiguous),
|
73
74
|
_ => Err(Error::new(
|
74
|
-
|
75
|
+
ruby.exception_arg_error(),
|
75
76
|
"Wrong value for SplitDelimiterBehavior, expected one of: \
|
76
77
|
`removed, isolated, merged_with_previous, merged_with_next, contiguous`",
|
77
78
|
)),
|
@@ -1,5 +1,5 @@
|
|
1
1
|
use crate::{RbResult, TOKENIZERS};
|
2
|
-
use magnus::{
|
2
|
+
use magnus::{prelude::*, value::Lazy, Error, RClass, Ruby};
|
3
3
|
use onig::Regex;
|
4
4
|
|
5
5
|
#[magnus::wrap(class = "Tokenizers::Regex")]
|
@@ -9,10 +9,11 @@ pub struct RbRegex {
|
|
9
9
|
}
|
10
10
|
|
11
11
|
impl RbRegex {
|
12
|
-
pub fn new(s: String) -> RbResult<Self> {
|
12
|
+
pub fn new(ruby: &Ruby, s: String) -> RbResult<Self> {
|
13
13
|
Ok(Self {
|
14
|
-
inner: Regex::new(&s)
|
15
|
-
|
14
|
+
inner: Regex::new(&s).map_err(|e| {
|
15
|
+
Error::new(ruby.exception_runtime_error(), e.description().to_owned())
|
16
|
+
})?,
|
16
17
|
pattern: s,
|
17
18
|
})
|
18
19
|
}
|