tokenizers 0.5.5 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/Cargo.lock +124 -53
- data/ext/tokenizers/Cargo.toml +4 -3
- data/ext/tokenizers/src/encoding.rs +10 -8
- data/ext/tokenizers/src/models.rs +37 -24
- data/ext/tokenizers/src/normalizers.rs +1 -2
- data/ext/tokenizers/src/pre_tokenizers.rs +5 -5
- data/ext/tokenizers/src/tokenizer.rs +54 -44
- data/ext/tokenizers/src/trainers.rs +60 -50
- data/ext/tokenizers/src/utils/normalization.rs +3 -2
- data/ext/tokenizers/src/utils/regex.rs +5 -4
- data/lib/tokenizers/from_pretrained.rb +2 -2
- data/lib/tokenizers/trainers/unigram_trainer.rb +10 -9
- data/lib/tokenizers/trainers/word_piece_trainer.rb +10 -9
- data/lib/tokenizers/version.rb +1 -1
- metadata +3 -3
@@ -1,8 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
3
|
use magnus::{
|
4
|
-
data_type_builder,
|
5
|
-
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
|
5
|
+
Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
|
6
6
|
};
|
7
7
|
|
8
8
|
use serde::ser::SerializeStruct;
|
@@ -278,16 +278,16 @@ impl RbSequence {
|
|
278
278
|
}
|
279
279
|
|
280
280
|
pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
|
281
|
+
let ruby = Ruby::get().unwrap();
|
281
282
|
let scheme = match string.as_str() {
|
282
283
|
"first" => PrependScheme::First,
|
283
284
|
"never" => PrependScheme::Never,
|
284
285
|
"always" => PrependScheme::Always,
|
285
286
|
_ => {
|
286
287
|
return Err(Error::new(
|
287
|
-
|
288
|
+
ruby.exception_arg_error(),
|
288
289
|
format!(
|
289
|
-
"{} is an unknown variant, should be one of ['first', 'never', 'always']"
|
290
|
-
string
|
290
|
+
"{string} is an unknown variant, should be one of ['first', 'never', 'always']"
|
291
291
|
),
|
292
292
|
));
|
293
293
|
}
|
@@ -4,7 +4,7 @@ use std::path::PathBuf;
|
|
4
4
|
use std::str::FromStr;
|
5
5
|
|
6
6
|
use magnus::prelude::*;
|
7
|
-
use magnus::{
|
7
|
+
use magnus::{Error, RArray, RHash, RString, Ruby, TryConvert, Value};
|
8
8
|
use tk::tokenizer::{
|
9
9
|
Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
|
10
10
|
TruncationParams, TruncationStrategy,
|
@@ -78,37 +78,37 @@ impl From<tk::AddedToken> for RbAddedToken {
|
|
78
78
|
}
|
79
79
|
|
80
80
|
impl RbAddedToken {
|
81
|
-
pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
|
81
|
+
pub fn new(ruby: &Ruby, content: Option<String>, kwargs: RHash) -> RbResult<Self> {
|
82
82
|
let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
|
83
83
|
|
84
|
-
let value: Value = kwargs.delete(
|
84
|
+
let value: Value = kwargs.delete(ruby.to_symbol("single_word"))?;
|
85
85
|
if !value.is_nil() {
|
86
86
|
token.single_word = TryConvert::try_convert(value)?;
|
87
87
|
}
|
88
88
|
|
89
|
-
let value: Value = kwargs.delete(
|
89
|
+
let value: Value = kwargs.delete(ruby.to_symbol("lstrip"))?;
|
90
90
|
if !value.is_nil() {
|
91
91
|
token.lstrip = TryConvert::try_convert(value)?;
|
92
92
|
}
|
93
93
|
|
94
|
-
let value: Value = kwargs.delete(
|
94
|
+
let value: Value = kwargs.delete(ruby.to_symbol("rstrip"))?;
|
95
95
|
if !value.is_nil() {
|
96
96
|
token.rstrip = TryConvert::try_convert(value)?;
|
97
97
|
}
|
98
98
|
|
99
|
-
let value: Value = kwargs.delete(
|
99
|
+
let value: Value = kwargs.delete(ruby.to_symbol("normalized"))?;
|
100
100
|
if !value.is_nil() {
|
101
101
|
token.normalized = TryConvert::try_convert(value)?;
|
102
102
|
}
|
103
103
|
|
104
|
-
let value: Value = kwargs.delete(
|
104
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special"))?;
|
105
105
|
if !value.is_nil() {
|
106
106
|
token.special = TryConvert::try_convert(value)?;
|
107
107
|
}
|
108
108
|
|
109
109
|
if !kwargs.is_empty() {
|
110
110
|
// TODO improve message
|
111
|
-
return Err(Error::new(
|
111
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
112
112
|
}
|
113
113
|
|
114
114
|
Ok(token)
|
@@ -189,6 +189,8 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
189
189
|
|
190
190
|
impl TryConvert for TextEncodeInput<'_> {
|
191
191
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
192
|
+
let ruby = Ruby::get_with(ob);
|
193
|
+
|
192
194
|
if let Ok(i) = TextInputSequence::try_convert(ob) {
|
193
195
|
return Ok(Self(i.into()));
|
194
196
|
}
|
@@ -204,7 +206,7 @@ impl TryConvert for TextEncodeInput<'_> {
|
|
204
206
|
}
|
205
207
|
}
|
206
208
|
Err(Error::new(
|
207
|
-
|
209
|
+
ruby.exception_type_error(),
|
208
210
|
"TextEncodeInput must be a string or pair of strings",
|
209
211
|
))
|
210
212
|
}
|
@@ -220,6 +222,8 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
220
222
|
|
221
223
|
impl TryConvert for PreTokenizedEncodeInput<'_> {
|
222
224
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
225
|
+
let ruby = Ruby::get_with(ob);
|
226
|
+
|
223
227
|
if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
|
224
228
|
return Ok(Self(i.into()));
|
225
229
|
}
|
@@ -237,7 +241,7 @@ impl TryConvert for PreTokenizedEncodeInput<'_> {
|
|
237
241
|
}
|
238
242
|
}
|
239
243
|
Err(Error::new(
|
240
|
-
|
244
|
+
ruby.exception_type_error(),
|
241
245
|
"PreTokenizedEncodeInput must be an array of strings or pair of arrays",
|
242
246
|
))
|
243
247
|
}
|
@@ -351,7 +355,8 @@ impl RbTokenizer {
|
|
351
355
|
}
|
352
356
|
|
353
357
|
pub fn encode_batch(
|
354
|
-
&
|
358
|
+
ruby: &Ruby,
|
359
|
+
rb_self: &Self,
|
355
360
|
input: RArray,
|
356
361
|
is_pretokenized: bool,
|
357
362
|
add_special_tokens: bool,
|
@@ -367,14 +372,12 @@ impl RbTokenizer {
|
|
367
372
|
Ok(input)
|
368
373
|
})
|
369
374
|
.collect::<RbResult<Vec<tk::EncodeInput>>>()?;
|
370
|
-
|
375
|
+
rb_self
|
376
|
+
.tokenizer
|
371
377
|
.borrow()
|
372
378
|
.encode_batch_char_offsets(input, add_special_tokens)
|
373
379
|
.map(|encodings| {
|
374
|
-
encodings
|
375
|
-
.into_iter()
|
376
|
-
.map(Into::<RbEncoding>::into)
|
377
|
-
.collect()
|
380
|
+
ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into))
|
378
381
|
})
|
379
382
|
.map_err(RbError::from)
|
380
383
|
}
|
@@ -453,10 +456,10 @@ impl RbTokenizer {
|
|
453
456
|
}
|
454
457
|
|
455
458
|
// TODO support more kwargs
|
456
|
-
pub fn enable_padding(&
|
459
|
+
pub fn enable_padding(ruby: &Ruby, rb_self: &Self, kwargs: RHash) -> RbResult<()> {
|
457
460
|
let mut params = PaddingParams::default();
|
458
461
|
|
459
|
-
let value: Value = kwargs.delete(
|
462
|
+
let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
|
460
463
|
if !value.is_nil() {
|
461
464
|
let dir_str = String::try_convert(value)?;
|
462
465
|
params.direction = match dir_str.as_str() {
|
@@ -464,34 +467,34 @@ impl RbTokenizer {
|
|
464
467
|
"right" => PaddingDirection::Right,
|
465
468
|
_ => {
|
466
469
|
return Err(Error::new(
|
467
|
-
|
470
|
+
ruby.exception_arg_error(),
|
468
471
|
"The direction value must be 'left' or 'right'",
|
469
472
|
))
|
470
473
|
}
|
471
474
|
}
|
472
475
|
}
|
473
476
|
|
474
|
-
let value: Value = kwargs.delete(
|
477
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_to_multiple_of"))?;
|
475
478
|
if !value.is_nil() {
|
476
479
|
params.pad_to_multiple_of = TryConvert::try_convert(value)?;
|
477
480
|
}
|
478
481
|
|
479
|
-
let value: Value = kwargs.delete(
|
482
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_id"))?;
|
480
483
|
if !value.is_nil() {
|
481
484
|
params.pad_id = TryConvert::try_convert(value)?;
|
482
485
|
}
|
483
486
|
|
484
|
-
let value: Value = kwargs.delete(
|
487
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_type_id"))?;
|
485
488
|
if !value.is_nil() {
|
486
489
|
params.pad_type_id = TryConvert::try_convert(value)?;
|
487
490
|
}
|
488
491
|
|
489
|
-
let value: Value = kwargs.delete(
|
492
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_token"))?;
|
490
493
|
if !value.is_nil() {
|
491
494
|
params.pad_token = TryConvert::try_convert(value)?;
|
492
495
|
}
|
493
496
|
|
494
|
-
let value: Value = kwargs.delete(
|
497
|
+
let value: Value = kwargs.delete(ruby.to_symbol("length"))?;
|
495
498
|
if value.is_nil() {
|
496
499
|
params.strategy = PaddingStrategy::BatchLongest;
|
497
500
|
} else {
|
@@ -500,10 +503,10 @@ impl RbTokenizer {
|
|
500
503
|
|
501
504
|
if !kwargs.is_empty() {
|
502
505
|
// TODO improve message
|
503
|
-
return Err(Error::new(
|
506
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
504
507
|
}
|
505
508
|
|
506
|
-
|
509
|
+
rb_self.tokenizer.borrow_mut().with_padding(Some(params));
|
507
510
|
|
508
511
|
Ok(())
|
509
512
|
}
|
@@ -512,12 +515,13 @@ impl RbTokenizer {
|
|
512
515
|
self.tokenizer.borrow_mut().with_padding(None);
|
513
516
|
}
|
514
517
|
|
515
|
-
pub fn padding(&
|
516
|
-
|
518
|
+
pub fn padding(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
|
519
|
+
rb_self
|
520
|
+
.tokenizer
|
517
521
|
.borrow()
|
518
522
|
.get_padding()
|
519
523
|
.map_or(Ok(None), |params| {
|
520
|
-
let ret_hash =
|
524
|
+
let ret_hash = ruby.hash_new();
|
521
525
|
|
522
526
|
ret_hash.aset(
|
523
527
|
"length",
|
@@ -536,18 +540,23 @@ impl RbTokenizer {
|
|
536
540
|
})
|
537
541
|
}
|
538
542
|
|
539
|
-
pub fn enable_truncation(
|
543
|
+
pub fn enable_truncation(
|
544
|
+
ruby: &Ruby,
|
545
|
+
rb_self: &Self,
|
546
|
+
max_length: usize,
|
547
|
+
kwargs: RHash,
|
548
|
+
) -> RbResult<()> {
|
540
549
|
let mut params = TruncationParams {
|
541
550
|
max_length,
|
542
551
|
..Default::default()
|
543
552
|
};
|
544
553
|
|
545
|
-
let value: Value = kwargs.delete(
|
554
|
+
let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
|
546
555
|
if !value.is_nil() {
|
547
556
|
params.stride = TryConvert::try_convert(value)?;
|
548
557
|
}
|
549
558
|
|
550
|
-
let value: Value = kwargs.delete(
|
559
|
+
let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
|
551
560
|
if !value.is_nil() {
|
552
561
|
let strategy_str = String::try_convert(value)?;
|
553
562
|
params.strategy = match strategy_str.as_str() {
|
@@ -555,13 +564,13 @@ impl RbTokenizer {
|
|
555
564
|
"only_first" => TruncationStrategy::OnlyFirst,
|
556
565
|
"only_second" => TruncationStrategy::OnlySecond,
|
557
566
|
_ => return Err(Error::new(
|
558
|
-
|
567
|
+
ruby.exception_arg_error(),
|
559
568
|
"The strategy value must be 'longest_first', 'only_first', or 'only_second'",
|
560
569
|
)),
|
561
570
|
}
|
562
571
|
}
|
563
572
|
|
564
|
-
let value: Value = kwargs.delete(
|
573
|
+
let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
|
565
574
|
if !value.is_nil() {
|
566
575
|
let dir_str = String::try_convert(value)?;
|
567
576
|
params.direction = match dir_str.as_str() {
|
@@ -569,7 +578,7 @@ impl RbTokenizer {
|
|
569
578
|
"right" => TruncationDirection::Right,
|
570
579
|
_ => {
|
571
580
|
return Err(Error::new(
|
572
|
-
|
581
|
+
ruby.exception_arg_error(),
|
573
582
|
"The direction value must be 'left' or 'right'",
|
574
583
|
))
|
575
584
|
}
|
@@ -578,12 +587,12 @@ impl RbTokenizer {
|
|
578
587
|
|
579
588
|
if !kwargs.is_empty() {
|
580
589
|
// TODO improve message
|
581
|
-
return Err(Error::new(
|
590
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
582
591
|
}
|
583
592
|
|
584
|
-
if let Err(error_message) =
|
593
|
+
if let Err(error_message) = rb_self.tokenizer.borrow_mut().with_truncation(Some(params)) {
|
585
594
|
return Err(Error::new(
|
586
|
-
|
595
|
+
ruby.exception_arg_error(),
|
587
596
|
error_message.to_string(),
|
588
597
|
));
|
589
598
|
}
|
@@ -598,12 +607,13 @@ impl RbTokenizer {
|
|
598
607
|
.expect("Failed to set truncation to `None`! This should never happen");
|
599
608
|
}
|
600
609
|
|
601
|
-
pub fn truncation(&
|
602
|
-
|
610
|
+
pub fn truncation(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
|
611
|
+
rb_self
|
612
|
+
.tokenizer
|
603
613
|
.borrow()
|
604
614
|
.get_truncation()
|
605
615
|
.map_or(Ok(None), |params| {
|
606
|
-
let ret_hash =
|
616
|
+
let ret_hash = ruby.hash_new();
|
607
617
|
|
608
618
|
ret_hash.aset("max_length", params.max_length)?;
|
609
619
|
ret_hash.aset("stride", params.stride)?;
|
@@ -629,10 +639,10 @@ impl RbTokenizer {
|
|
629
639
|
self.tokenizer.borrow().get_vocab_size(with_added_tokens)
|
630
640
|
}
|
631
641
|
|
632
|
-
pub fn get_added_tokens_decoder(&
|
633
|
-
let sorted_map =
|
642
|
+
pub fn get_added_tokens_decoder(ruby: &Ruby, rb_self: &Self) -> RbResult<RHash> {
|
643
|
+
let sorted_map = ruby.hash_new();
|
634
644
|
|
635
|
-
for (key, value) in
|
645
|
+
for (key, value) in rb_self.tokenizer.borrow().get_added_tokens_decoder() {
|
636
646
|
sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
|
637
647
|
}
|
638
648
|
|
@@ -1,13 +1,11 @@
|
|
1
|
-
use std::collections::HashSet;
|
2
1
|
use std::sync::{Arc, RwLock};
|
3
2
|
|
4
3
|
use crate::models::RbModel;
|
5
4
|
use crate::tokenizer::RbAddedToken;
|
6
5
|
use magnus::prelude::*;
|
7
6
|
use magnus::{
|
8
|
-
data_type_builder,
|
9
|
-
|
10
|
-
TryConvert, TypedData, Value,
|
7
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
|
8
|
+
Module, Object, RArray, RClass, RHash, RModule, Ruby, TryConvert, TypedData, Value,
|
11
9
|
};
|
12
10
|
use serde::{Deserialize, Serialize};
|
13
11
|
use tk::models::TrainerWrapper;
|
@@ -391,10 +389,10 @@ where
|
|
391
389
|
pub struct RbBpeTrainer {}
|
392
390
|
|
393
391
|
impl RbBpeTrainer {
|
394
|
-
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
392
|
+
pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
|
395
393
|
let mut builder = tk::models::bpe::BpeTrainer::builder();
|
396
394
|
|
397
|
-
let value: Value = kwargs.delete(
|
395
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
|
398
396
|
if !value.is_nil() {
|
399
397
|
builder = builder.special_tokens(
|
400
398
|
RArray::try_convert(value)?
|
@@ -410,46 +408,50 @@ impl RbBpeTrainer {
|
|
410
408
|
);
|
411
409
|
}
|
412
410
|
|
413
|
-
let value: Value = kwargs.delete(
|
411
|
+
let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
|
414
412
|
if !value.is_nil() {
|
415
|
-
let
|
416
|
-
|
417
|
-
|
413
|
+
let alphabet = Vec::<String>::try_convert(value)?;
|
414
|
+
builder = builder.initial_alphabet(
|
415
|
+
alphabet
|
416
|
+
.into_iter()
|
417
|
+
.filter_map(|s| s.chars().next())
|
418
|
+
.collect(),
|
419
|
+
);
|
418
420
|
}
|
419
421
|
|
420
|
-
let value: Value = kwargs.delete(
|
422
|
+
let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
|
421
423
|
if !value.is_nil() {
|
422
424
|
builder = builder.vocab_size(TryConvert::try_convert(value)?);
|
423
425
|
}
|
424
426
|
|
425
|
-
let value: Value = kwargs.delete(
|
427
|
+
let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
|
426
428
|
if !value.is_nil() {
|
427
429
|
builder = builder.min_frequency(TryConvert::try_convert(value)?);
|
428
430
|
}
|
429
431
|
|
430
|
-
let value: Value = kwargs.delete(
|
432
|
+
let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
|
431
433
|
if !value.is_nil() {
|
432
434
|
builder = builder.show_progress(TryConvert::try_convert(value)?);
|
433
435
|
}
|
434
436
|
|
435
|
-
let value: Value = kwargs.delete(
|
437
|
+
let value: Value = kwargs.delete(ruby.to_symbol("limit_alphabet"))?;
|
436
438
|
if !value.is_nil() {
|
437
439
|
builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
|
438
440
|
}
|
439
441
|
|
440
|
-
let value: Value = kwargs.delete(
|
442
|
+
let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
|
441
443
|
if !value.is_nil() {
|
442
444
|
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
443
445
|
}
|
444
446
|
|
445
|
-
let value: Value = kwargs.delete(
|
447
|
+
let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
|
446
448
|
if !value.is_nil() {
|
447
449
|
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
448
450
|
}
|
449
451
|
|
450
452
|
if !kwargs.is_empty() {
|
451
453
|
// TODO improve message
|
452
|
-
return Err(Error::new(
|
454
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
453
455
|
}
|
454
456
|
|
455
457
|
Ok(builder.build().into())
|
@@ -459,10 +461,10 @@ impl RbBpeTrainer {
|
|
459
461
|
pub struct RbUnigramTrainer {}
|
460
462
|
|
461
463
|
impl RbUnigramTrainer {
|
462
|
-
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
464
|
+
pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
|
463
465
|
let mut builder = tk::models::unigram::UnigramTrainer::builder();
|
464
466
|
|
465
|
-
let value: Value = kwargs.delete(
|
467
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
|
466
468
|
if !value.is_nil() {
|
467
469
|
builder.special_tokens(
|
468
470
|
RArray::try_convert(value)?
|
@@ -478,56 +480,60 @@ impl RbUnigramTrainer {
|
|
478
480
|
);
|
479
481
|
}
|
480
482
|
|
481
|
-
let value: Value = kwargs.delete(
|
483
|
+
let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
|
482
484
|
if !value.is_nil() {
|
483
|
-
let
|
484
|
-
|
485
|
-
|
485
|
+
let alphabet = Vec::<String>::try_convert(value)?;
|
486
|
+
builder.initial_alphabet(
|
487
|
+
alphabet
|
488
|
+
.into_iter()
|
489
|
+
.filter_map(|s| s.chars().next())
|
490
|
+
.collect(),
|
491
|
+
);
|
486
492
|
}
|
487
493
|
|
488
|
-
let value: Value = kwargs.delete(
|
494
|
+
let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
|
489
495
|
if !value.is_nil() {
|
490
496
|
builder.vocab_size(TryConvert::try_convert(value)?);
|
491
497
|
}
|
492
498
|
|
493
|
-
let value: Value = kwargs.delete(
|
499
|
+
let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
|
494
500
|
if !value.is_nil() {
|
495
501
|
builder.show_progress(TryConvert::try_convert(value)?);
|
496
502
|
}
|
497
503
|
|
498
|
-
let value: Value = kwargs.delete(
|
504
|
+
let value: Value = kwargs.delete(ruby.to_symbol("n_sub_iterations"))?;
|
499
505
|
if !value.is_nil() {
|
500
506
|
builder.n_sub_iterations(TryConvert::try_convert(value)?);
|
501
507
|
}
|
502
508
|
|
503
|
-
let value: Value = kwargs.delete(
|
509
|
+
let value: Value = kwargs.delete(ruby.to_symbol("unk_token"))?;
|
504
510
|
if !value.is_nil() {
|
505
511
|
builder.unk_token(Some(TryConvert::try_convert(value)?));
|
506
512
|
}
|
507
513
|
|
508
|
-
let value: Value = kwargs.delete(
|
514
|
+
let value: Value = kwargs.delete(ruby.to_symbol("max_piece_length"))?;
|
509
515
|
if !value.is_nil() {
|
510
516
|
builder.max_piece_length(TryConvert::try_convert(value)?);
|
511
517
|
}
|
512
518
|
|
513
|
-
let value: Value = kwargs.delete(
|
519
|
+
let value: Value = kwargs.delete(ruby.to_symbol("seed_size"))?;
|
514
520
|
if !value.is_nil() {
|
515
521
|
builder.seed_size(TryConvert::try_convert(value)?);
|
516
522
|
}
|
517
523
|
|
518
|
-
let value: Value = kwargs.delete(
|
524
|
+
let value: Value = kwargs.delete(ruby.to_symbol("shrinking_factor"))?;
|
519
525
|
if !value.is_nil() {
|
520
526
|
builder.shrinking_factor(TryConvert::try_convert(value)?);
|
521
527
|
}
|
522
528
|
|
523
529
|
if !kwargs.is_empty() {
|
524
530
|
// TODO improve message
|
525
|
-
return Err(Error::new(
|
531
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
526
532
|
}
|
527
533
|
|
528
534
|
let trainer = builder
|
529
535
|
.build()
|
530
|
-
.map_err(|_| Error::new(
|
536
|
+
.map_err(|_| Error::new(ruby.exception_arg_error(), "Cannot build UnigramTrainer"))?;
|
531
537
|
Ok(trainer.into())
|
532
538
|
}
|
533
539
|
}
|
@@ -535,10 +541,10 @@ impl RbUnigramTrainer {
|
|
535
541
|
pub struct RbWordLevelTrainer {}
|
536
542
|
|
537
543
|
impl RbWordLevelTrainer {
|
538
|
-
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
544
|
+
pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
|
539
545
|
let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
|
540
546
|
|
541
|
-
let value: Value = kwargs.delete(
|
547
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
|
542
548
|
if !value.is_nil() {
|
543
549
|
builder.special_tokens(
|
544
550
|
RArray::try_convert(value)?
|
@@ -554,17 +560,17 @@ impl RbWordLevelTrainer {
|
|
554
560
|
);
|
555
561
|
}
|
556
562
|
|
557
|
-
let value: Value = kwargs.delete(
|
563
|
+
let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
|
558
564
|
if !value.is_nil() {
|
559
565
|
builder.vocab_size(TryConvert::try_convert(value)?);
|
560
566
|
}
|
561
567
|
|
562
|
-
let value: Value = kwargs.delete(
|
568
|
+
let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
|
563
569
|
if !value.is_nil() {
|
564
570
|
builder.min_frequency(TryConvert::try_convert(value)?);
|
565
571
|
}
|
566
572
|
|
567
|
-
let value: Value = kwargs.delete(
|
573
|
+
let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
|
568
574
|
if !value.is_nil() {
|
569
575
|
builder.show_progress(TryConvert::try_convert(value)?);
|
570
576
|
}
|
@@ -579,10 +585,10 @@ impl RbWordLevelTrainer {
|
|
579
585
|
pub struct RbWordPieceTrainer {}
|
580
586
|
|
581
587
|
impl RbWordPieceTrainer {
|
582
|
-
pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
|
588
|
+
pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
|
583
589
|
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
|
584
590
|
|
585
|
-
let value: Value = kwargs.delete(
|
591
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
|
586
592
|
if !value.is_nil() {
|
587
593
|
builder = builder.special_tokens(
|
588
594
|
RArray::try_convert(value)?
|
@@ -598,46 +604,50 @@ impl RbWordPieceTrainer {
|
|
598
604
|
);
|
599
605
|
}
|
600
606
|
|
601
|
-
let value: Value = kwargs.delete(
|
607
|
+
let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
|
602
608
|
if !value.is_nil() {
|
603
|
-
let
|
604
|
-
|
605
|
-
|
609
|
+
let alphabet = Vec::<String>::try_convert(value)?;
|
610
|
+
builder = builder.initial_alphabet(
|
611
|
+
alphabet
|
612
|
+
.into_iter()
|
613
|
+
.filter_map(|s| s.chars().next())
|
614
|
+
.collect(),
|
615
|
+
);
|
606
616
|
}
|
607
617
|
|
608
|
-
let value: Value = kwargs.delete(
|
618
|
+
let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
|
609
619
|
if !value.is_nil() {
|
610
620
|
builder = builder.vocab_size(TryConvert::try_convert(value)?);
|
611
621
|
}
|
612
622
|
|
613
|
-
let value: Value = kwargs.delete(
|
623
|
+
let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
|
614
624
|
if !value.is_nil() {
|
615
625
|
builder = builder.min_frequency(TryConvert::try_convert(value)?);
|
616
626
|
}
|
617
627
|
|
618
|
-
let value: Value = kwargs.delete(
|
628
|
+
let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
|
619
629
|
if !value.is_nil() {
|
620
630
|
builder = builder.show_progress(TryConvert::try_convert(value)?);
|
621
631
|
}
|
622
632
|
|
623
|
-
let value: Value = kwargs.delete(
|
633
|
+
let value: Value = kwargs.delete(ruby.to_symbol("limit_alphabet"))?;
|
624
634
|
if !value.is_nil() {
|
625
635
|
builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
|
626
636
|
}
|
627
637
|
|
628
|
-
let value: Value = kwargs.delete(
|
638
|
+
let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
|
629
639
|
if !value.is_nil() {
|
630
640
|
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
631
641
|
}
|
632
642
|
|
633
|
-
let value: Value = kwargs.delete(
|
643
|
+
let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
|
634
644
|
if !value.is_nil() {
|
635
645
|
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
636
646
|
}
|
637
647
|
|
638
648
|
if !kwargs.is_empty() {
|
639
649
|
// TODO improve message
|
640
|
-
return Err(Error::new(
|
650
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
641
651
|
}
|
642
652
|
|
643
653
|
Ok(builder.build().into())
|
@@ -1,7 +1,7 @@
|
|
1
1
|
use super::regex::{regex, RbRegex};
|
2
2
|
use crate::RbResult;
|
3
3
|
use magnus::prelude::*;
|
4
|
-
use magnus::{
|
4
|
+
use magnus::{Error, Ruby, TryConvert, Value};
|
5
5
|
use tk::normalizer::SplitDelimiterBehavior;
|
6
6
|
use tk::pattern::Pattern;
|
7
7
|
|
@@ -62,6 +62,7 @@ pub struct RbSplitDelimiterBehavior(pub SplitDelimiterBehavior);
|
|
62
62
|
|
63
63
|
impl TryConvert for RbSplitDelimiterBehavior {
|
64
64
|
fn try_convert(obj: Value) -> RbResult<Self> {
|
65
|
+
let ruby = Ruby::get_with(obj);
|
65
66
|
let s = String::try_convert(obj)?;
|
66
67
|
|
67
68
|
Ok(Self(match s.as_str() {
|
@@ -71,7 +72,7 @@ impl TryConvert for RbSplitDelimiterBehavior {
|
|
71
72
|
"merged_with_next" => Ok(SplitDelimiterBehavior::MergedWithNext),
|
72
73
|
"contiguous" => Ok(SplitDelimiterBehavior::Contiguous),
|
73
74
|
_ => Err(Error::new(
|
74
|
-
|
75
|
+
ruby.exception_arg_error(),
|
75
76
|
"Wrong value for SplitDelimiterBehavior, expected one of: \
|
76
77
|
`removed, isolated, merged_with_previous, merged_with_next, contiguous`",
|
77
78
|
)),
|
@@ -1,5 +1,5 @@
|
|
1
1
|
use crate::{RbResult, TOKENIZERS};
|
2
|
-
use magnus::{
|
2
|
+
use magnus::{prelude::*, value::Lazy, Error, RClass, Ruby};
|
3
3
|
use onig::Regex;
|
4
4
|
|
5
5
|
#[magnus::wrap(class = "Tokenizers::Regex")]
|
@@ -9,10 +9,11 @@ pub struct RbRegex {
|
|
9
9
|
}
|
10
10
|
|
11
11
|
impl RbRegex {
|
12
|
-
pub fn new(s: String) -> RbResult<Self> {
|
12
|
+
pub fn new(ruby: &Ruby, s: String) -> RbResult<Self> {
|
13
13
|
Ok(Self {
|
14
|
-
inner: Regex::new(&s)
|
15
|
-
|
14
|
+
inner: Regex::new(&s).map_err(|e| {
|
15
|
+
Error::new(ruby.exception_runtime_error(), e.description().to_owned())
|
16
|
+
})?,
|
16
17
|
pattern: s,
|
17
18
|
})
|
18
19
|
}
|