tokenizers 0.5.5 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,8 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
3
  use magnus::{
4
- data_type_builder, exception, function, method, value::Lazy, Class, DataType,
5
- DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
5
+ Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
6
6
  };
7
7
 
8
8
  use serde::ser::SerializeStruct;
@@ -278,16 +278,16 @@ impl RbSequence {
278
278
  }
279
279
 
280
280
  pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
281
+ let ruby = Ruby::get().unwrap();
281
282
  let scheme = match string.as_str() {
282
283
  "first" => PrependScheme::First,
283
284
  "never" => PrependScheme::Never,
284
285
  "always" => PrependScheme::Always,
285
286
  _ => {
286
287
  return Err(Error::new(
287
- exception::arg_error(),
288
+ ruby.exception_arg_error(),
288
289
  format!(
289
- "{} is an unknown variant, should be one of ['first', 'never', 'always']",
290
- string
290
+ "{string} is an unknown variant, should be one of ['first', 'never', 'always']"
291
291
  ),
292
292
  ));
293
293
  }
@@ -4,7 +4,7 @@ use std::path::PathBuf;
4
4
  use std::str::FromStr;
5
5
 
6
6
  use magnus::prelude::*;
7
- use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
7
+ use magnus::{Error, RArray, RHash, RString, Ruby, TryConvert, Value};
8
8
  use tk::tokenizer::{
9
9
  Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
10
10
  TruncationParams, TruncationStrategy,
@@ -78,37 +78,37 @@ impl From<tk::AddedToken> for RbAddedToken {
78
78
  }
79
79
 
80
80
  impl RbAddedToken {
81
- pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
81
+ pub fn new(ruby: &Ruby, content: Option<String>, kwargs: RHash) -> RbResult<Self> {
82
82
  let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
83
83
 
84
- let value: Value = kwargs.delete(Symbol::new("single_word"))?;
84
+ let value: Value = kwargs.delete(ruby.to_symbol("single_word"))?;
85
85
  if !value.is_nil() {
86
86
  token.single_word = TryConvert::try_convert(value)?;
87
87
  }
88
88
 
89
- let value: Value = kwargs.delete(Symbol::new("lstrip"))?;
89
+ let value: Value = kwargs.delete(ruby.to_symbol("lstrip"))?;
90
90
  if !value.is_nil() {
91
91
  token.lstrip = TryConvert::try_convert(value)?;
92
92
  }
93
93
 
94
- let value: Value = kwargs.delete(Symbol::new("rstrip"))?;
94
+ let value: Value = kwargs.delete(ruby.to_symbol("rstrip"))?;
95
95
  if !value.is_nil() {
96
96
  token.rstrip = TryConvert::try_convert(value)?;
97
97
  }
98
98
 
99
- let value: Value = kwargs.delete(Symbol::new("normalized"))?;
99
+ let value: Value = kwargs.delete(ruby.to_symbol("normalized"))?;
100
100
  if !value.is_nil() {
101
101
  token.normalized = TryConvert::try_convert(value)?;
102
102
  }
103
103
 
104
- let value: Value = kwargs.delete(Symbol::new("special"))?;
104
+ let value: Value = kwargs.delete(ruby.to_symbol("special"))?;
105
105
  if !value.is_nil() {
106
106
  token.special = TryConvert::try_convert(value)?;
107
107
  }
108
108
 
109
109
  if !kwargs.is_empty() {
110
110
  // TODO improve message
111
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
111
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
112
112
  }
113
113
 
114
114
  Ok(token)
@@ -189,6 +189,8 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
189
189
 
190
190
  impl TryConvert for TextEncodeInput<'_> {
191
191
  fn try_convert(ob: Value) -> RbResult<Self> {
192
+ let ruby = Ruby::get_with(ob);
193
+
192
194
  if let Ok(i) = TextInputSequence::try_convert(ob) {
193
195
  return Ok(Self(i.into()));
194
196
  }
@@ -204,7 +206,7 @@ impl TryConvert for TextEncodeInput<'_> {
204
206
  }
205
207
  }
206
208
  Err(Error::new(
207
- exception::type_error(),
209
+ ruby.exception_type_error(),
208
210
  "TextEncodeInput must be a string or pair of strings",
209
211
  ))
210
212
  }
@@ -220,6 +222,8 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
220
222
 
221
223
  impl TryConvert for PreTokenizedEncodeInput<'_> {
222
224
  fn try_convert(ob: Value) -> RbResult<Self> {
225
+ let ruby = Ruby::get_with(ob);
226
+
223
227
  if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
224
228
  return Ok(Self(i.into()));
225
229
  }
@@ -237,7 +241,7 @@ impl TryConvert for PreTokenizedEncodeInput<'_> {
237
241
  }
238
242
  }
239
243
  Err(Error::new(
240
- exception::type_error(),
244
+ ruby.exception_type_error(),
241
245
  "PreTokenizedEncodeInput must be an array of strings or pair of arrays",
242
246
  ))
243
247
  }
@@ -351,7 +355,8 @@ impl RbTokenizer {
351
355
  }
352
356
 
353
357
  pub fn encode_batch(
354
- &self,
358
+ ruby: &Ruby,
359
+ rb_self: &Self,
355
360
  input: RArray,
356
361
  is_pretokenized: bool,
357
362
  add_special_tokens: bool,
@@ -367,14 +372,12 @@ impl RbTokenizer {
367
372
  Ok(input)
368
373
  })
369
374
  .collect::<RbResult<Vec<tk::EncodeInput>>>()?;
370
- self.tokenizer
375
+ rb_self
376
+ .tokenizer
371
377
  .borrow()
372
378
  .encode_batch_char_offsets(input, add_special_tokens)
373
379
  .map(|encodings| {
374
- encodings
375
- .into_iter()
376
- .map(Into::<RbEncoding>::into)
377
- .collect()
380
+ ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into))
378
381
  })
379
382
  .map_err(RbError::from)
380
383
  }
@@ -453,10 +456,10 @@ impl RbTokenizer {
453
456
  }
454
457
 
455
458
  // TODO support more kwargs
456
- pub fn enable_padding(&self, kwargs: RHash) -> RbResult<()> {
459
+ pub fn enable_padding(ruby: &Ruby, rb_self: &Self, kwargs: RHash) -> RbResult<()> {
457
460
  let mut params = PaddingParams::default();
458
461
 
459
- let value: Value = kwargs.delete(Symbol::new("direction"))?;
462
+ let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
460
463
  if !value.is_nil() {
461
464
  let dir_str = String::try_convert(value)?;
462
465
  params.direction = match dir_str.as_str() {
@@ -464,34 +467,34 @@ impl RbTokenizer {
464
467
  "right" => PaddingDirection::Right,
465
468
  _ => {
466
469
  return Err(Error::new(
467
- exception::arg_error(),
470
+ ruby.exception_arg_error(),
468
471
  "The direction value must be 'left' or 'right'",
469
472
  ))
470
473
  }
471
474
  }
472
475
  }
473
476
 
474
- let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
477
+ let value: Value = kwargs.delete(ruby.to_symbol("pad_to_multiple_of"))?;
475
478
  if !value.is_nil() {
476
479
  params.pad_to_multiple_of = TryConvert::try_convert(value)?;
477
480
  }
478
481
 
479
- let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
482
+ let value: Value = kwargs.delete(ruby.to_symbol("pad_id"))?;
480
483
  if !value.is_nil() {
481
484
  params.pad_id = TryConvert::try_convert(value)?;
482
485
  }
483
486
 
484
- let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
487
+ let value: Value = kwargs.delete(ruby.to_symbol("pad_type_id"))?;
485
488
  if !value.is_nil() {
486
489
  params.pad_type_id = TryConvert::try_convert(value)?;
487
490
  }
488
491
 
489
- let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
492
+ let value: Value = kwargs.delete(ruby.to_symbol("pad_token"))?;
490
493
  if !value.is_nil() {
491
494
  params.pad_token = TryConvert::try_convert(value)?;
492
495
  }
493
496
 
494
- let value: Value = kwargs.delete(Symbol::new("length"))?;
497
+ let value: Value = kwargs.delete(ruby.to_symbol("length"))?;
495
498
  if value.is_nil() {
496
499
  params.strategy = PaddingStrategy::BatchLongest;
497
500
  } else {
@@ -500,10 +503,10 @@ impl RbTokenizer {
500
503
 
501
504
  if !kwargs.is_empty() {
502
505
  // TODO improve message
503
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
506
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
504
507
  }
505
508
 
506
- self.tokenizer.borrow_mut().with_padding(Some(params));
509
+ rb_self.tokenizer.borrow_mut().with_padding(Some(params));
507
510
 
508
511
  Ok(())
509
512
  }
@@ -512,12 +515,13 @@ impl RbTokenizer {
512
515
  self.tokenizer.borrow_mut().with_padding(None);
513
516
  }
514
517
 
515
- pub fn padding(&self) -> RbResult<Option<RHash>> {
516
- self.tokenizer
518
+ pub fn padding(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
519
+ rb_self
520
+ .tokenizer
517
521
  .borrow()
518
522
  .get_padding()
519
523
  .map_or(Ok(None), |params| {
520
- let ret_hash = RHash::new();
524
+ let ret_hash = ruby.hash_new();
521
525
 
522
526
  ret_hash.aset(
523
527
  "length",
@@ -536,18 +540,23 @@ impl RbTokenizer {
536
540
  })
537
541
  }
538
542
 
539
- pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
543
+ pub fn enable_truncation(
544
+ ruby: &Ruby,
545
+ rb_self: &Self,
546
+ max_length: usize,
547
+ kwargs: RHash,
548
+ ) -> RbResult<()> {
540
549
  let mut params = TruncationParams {
541
550
  max_length,
542
551
  ..Default::default()
543
552
  };
544
553
 
545
- let value: Value = kwargs.delete(Symbol::new("stride"))?;
554
+ let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
546
555
  if !value.is_nil() {
547
556
  params.stride = TryConvert::try_convert(value)?;
548
557
  }
549
558
 
550
- let value: Value = kwargs.delete(Symbol::new("strategy"))?;
559
+ let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
551
560
  if !value.is_nil() {
552
561
  let strategy_str = String::try_convert(value)?;
553
562
  params.strategy = match strategy_str.as_str() {
@@ -555,13 +564,13 @@ impl RbTokenizer {
555
564
  "only_first" => TruncationStrategy::OnlyFirst,
556
565
  "only_second" => TruncationStrategy::OnlySecond,
557
566
  _ => return Err(Error::new(
558
- exception::arg_error(),
567
+ ruby.exception_arg_error(),
559
568
  "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
560
569
  )),
561
570
  }
562
571
  }
563
572
 
564
- let value: Value = kwargs.delete(Symbol::new("direction"))?;
573
+ let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
565
574
  if !value.is_nil() {
566
575
  let dir_str = String::try_convert(value)?;
567
576
  params.direction = match dir_str.as_str() {
@@ -569,7 +578,7 @@ impl RbTokenizer {
569
578
  "right" => TruncationDirection::Right,
570
579
  _ => {
571
580
  return Err(Error::new(
572
- exception::arg_error(),
581
+ ruby.exception_arg_error(),
573
582
  "The direction value must be 'left' or 'right'",
574
583
  ))
575
584
  }
@@ -578,12 +587,12 @@ impl RbTokenizer {
578
587
 
579
588
  if !kwargs.is_empty() {
580
589
  // TODO improve message
581
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
590
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
582
591
  }
583
592
 
584
- if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
593
+ if let Err(error_message) = rb_self.tokenizer.borrow_mut().with_truncation(Some(params)) {
585
594
  return Err(Error::new(
586
- exception::arg_error(),
595
+ ruby.exception_arg_error(),
587
596
  error_message.to_string(),
588
597
  ));
589
598
  }
@@ -598,12 +607,13 @@ impl RbTokenizer {
598
607
  .expect("Failed to set truncation to `None`! This should never happen");
599
608
  }
600
609
 
601
- pub fn truncation(&self) -> RbResult<Option<RHash>> {
602
- self.tokenizer
610
+ pub fn truncation(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
611
+ rb_self
612
+ .tokenizer
603
613
  .borrow()
604
614
  .get_truncation()
605
615
  .map_or(Ok(None), |params| {
606
- let ret_hash = RHash::new();
616
+ let ret_hash = ruby.hash_new();
607
617
 
608
618
  ret_hash.aset("max_length", params.max_length)?;
609
619
  ret_hash.aset("stride", params.stride)?;
@@ -629,10 +639,10 @@ impl RbTokenizer {
629
639
  self.tokenizer.borrow().get_vocab_size(with_added_tokens)
630
640
  }
631
641
 
632
- pub fn get_added_tokens_decoder(&self) -> RbResult<RHash> {
633
- let sorted_map = RHash::new();
642
+ pub fn get_added_tokens_decoder(ruby: &Ruby, rb_self: &Self) -> RbResult<RHash> {
643
+ let sorted_map = ruby.hash_new();
634
644
 
635
- for (key, value) in self.tokenizer.borrow().get_added_tokens_decoder() {
645
+ for (key, value) in rb_self.tokenizer.borrow().get_added_tokens_decoder() {
636
646
  sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
637
647
  }
638
648
 
@@ -1,13 +1,11 @@
1
- use std::collections::HashSet;
2
1
  use std::sync::{Arc, RwLock};
3
2
 
4
3
  use crate::models::RbModel;
5
4
  use crate::tokenizer::RbAddedToken;
6
5
  use magnus::prelude::*;
7
6
  use magnus::{
8
- data_type_builder, exception, function, method, value::Lazy, Class, DataType,
9
- DataTypeFunctions, Error, Module, Object, RArray, RClass, RHash, RModule, Ruby, Symbol,
10
- TryConvert, TypedData, Value,
7
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
8
+ Module, Object, RArray, RClass, RHash, RModule, Ruby, TryConvert, TypedData, Value,
11
9
  };
12
10
  use serde::{Deserialize, Serialize};
13
11
  use tk::models::TrainerWrapper;
@@ -391,10 +389,10 @@ where
391
389
  pub struct RbBpeTrainer {}
392
390
 
393
391
  impl RbBpeTrainer {
394
- pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
392
+ pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
395
393
  let mut builder = tk::models::bpe::BpeTrainer::builder();
396
394
 
397
- let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
395
+ let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
398
396
  if !value.is_nil() {
399
397
  builder = builder.special_tokens(
400
398
  RArray::try_convert(value)?
@@ -410,46 +408,50 @@ impl RbBpeTrainer {
410
408
  );
411
409
  }
412
410
 
413
- let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
411
+ let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
414
412
  if !value.is_nil() {
415
- let arr = <Vec<char>>::try_convert(value)?;
416
- let set: HashSet<char> = HashSet::from_iter(arr);
417
- builder = builder.initial_alphabet(set);
413
+ let alphabet = Vec::<String>::try_convert(value)?;
414
+ builder = builder.initial_alphabet(
415
+ alphabet
416
+ .into_iter()
417
+ .filter_map(|s| s.chars().next())
418
+ .collect(),
419
+ );
418
420
  }
419
421
 
420
- let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
422
+ let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
421
423
  if !value.is_nil() {
422
424
  builder = builder.vocab_size(TryConvert::try_convert(value)?);
423
425
  }
424
426
 
425
- let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
427
+ let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
426
428
  if !value.is_nil() {
427
429
  builder = builder.min_frequency(TryConvert::try_convert(value)?);
428
430
  }
429
431
 
430
- let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
432
+ let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
431
433
  if !value.is_nil() {
432
434
  builder = builder.show_progress(TryConvert::try_convert(value)?);
433
435
  }
434
436
 
435
- let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
437
+ let value: Value = kwargs.delete(ruby.to_symbol("limit_alphabet"))?;
436
438
  if !value.is_nil() {
437
439
  builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
438
440
  }
439
441
 
440
- let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
442
+ let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
441
443
  if !value.is_nil() {
442
444
  builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
443
445
  }
444
446
 
445
- let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
447
+ let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
446
448
  if !value.is_nil() {
447
449
  builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
448
450
  }
449
451
 
450
452
  if !kwargs.is_empty() {
451
453
  // TODO improve message
452
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
454
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
453
455
  }
454
456
 
455
457
  Ok(builder.build().into())
@@ -459,10 +461,10 @@ impl RbBpeTrainer {
459
461
  pub struct RbUnigramTrainer {}
460
462
 
461
463
  impl RbUnigramTrainer {
462
- pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
464
+ pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
463
465
  let mut builder = tk::models::unigram::UnigramTrainer::builder();
464
466
 
465
- let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
467
+ let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
466
468
  if !value.is_nil() {
467
469
  builder.special_tokens(
468
470
  RArray::try_convert(value)?
@@ -478,56 +480,60 @@ impl RbUnigramTrainer {
478
480
  );
479
481
  }
480
482
 
481
- let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
483
+ let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
482
484
  if !value.is_nil() {
483
- let arr = <Vec<char>>::try_convert(value)?;
484
- let set: HashSet<char> = HashSet::from_iter(arr);
485
- builder.initial_alphabet(set);
485
+ let alphabet = Vec::<String>::try_convert(value)?;
486
+ builder.initial_alphabet(
487
+ alphabet
488
+ .into_iter()
489
+ .filter_map(|s| s.chars().next())
490
+ .collect(),
491
+ );
486
492
  }
487
493
 
488
- let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
494
+ let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
489
495
  if !value.is_nil() {
490
496
  builder.vocab_size(TryConvert::try_convert(value)?);
491
497
  }
492
498
 
493
- let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
499
+ let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
494
500
  if !value.is_nil() {
495
501
  builder.show_progress(TryConvert::try_convert(value)?);
496
502
  }
497
503
 
498
- let value: Value = kwargs.delete(Symbol::new("n_sub_iterations"))?;
504
+ let value: Value = kwargs.delete(ruby.to_symbol("n_sub_iterations"))?;
499
505
  if !value.is_nil() {
500
506
  builder.n_sub_iterations(TryConvert::try_convert(value)?);
501
507
  }
502
508
 
503
- let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
509
+ let value: Value = kwargs.delete(ruby.to_symbol("unk_token"))?;
504
510
  if !value.is_nil() {
505
511
  builder.unk_token(Some(TryConvert::try_convert(value)?));
506
512
  }
507
513
 
508
- let value: Value = kwargs.delete(Symbol::new("max_piece_length"))?;
514
+ let value: Value = kwargs.delete(ruby.to_symbol("max_piece_length"))?;
509
515
  if !value.is_nil() {
510
516
  builder.max_piece_length(TryConvert::try_convert(value)?);
511
517
  }
512
518
 
513
- let value: Value = kwargs.delete(Symbol::new("seed_size"))?;
519
+ let value: Value = kwargs.delete(ruby.to_symbol("seed_size"))?;
514
520
  if !value.is_nil() {
515
521
  builder.seed_size(TryConvert::try_convert(value)?);
516
522
  }
517
523
 
518
- let value: Value = kwargs.delete(Symbol::new("shrinking_factor"))?;
524
+ let value: Value = kwargs.delete(ruby.to_symbol("shrinking_factor"))?;
519
525
  if !value.is_nil() {
520
526
  builder.shrinking_factor(TryConvert::try_convert(value)?);
521
527
  }
522
528
 
523
529
  if !kwargs.is_empty() {
524
530
  // TODO improve message
525
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
531
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
526
532
  }
527
533
 
528
534
  let trainer = builder
529
535
  .build()
530
- .map_err(|_| Error::new(exception::arg_error(), "Cannot build UnigramTrainer"))?;
536
+ .map_err(|_| Error::new(ruby.exception_arg_error(), "Cannot build UnigramTrainer"))?;
531
537
  Ok(trainer.into())
532
538
  }
533
539
  }
@@ -535,10 +541,10 @@ impl RbUnigramTrainer {
535
541
  pub struct RbWordLevelTrainer {}
536
542
 
537
543
  impl RbWordLevelTrainer {
538
- pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
544
+ pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
539
545
  let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
540
546
 
541
- let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
547
+ let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
542
548
  if !value.is_nil() {
543
549
  builder.special_tokens(
544
550
  RArray::try_convert(value)?
@@ -554,17 +560,17 @@ impl RbWordLevelTrainer {
554
560
  );
555
561
  }
556
562
 
557
- let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
563
+ let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
558
564
  if !value.is_nil() {
559
565
  builder.vocab_size(TryConvert::try_convert(value)?);
560
566
  }
561
567
 
562
- let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
568
+ let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
563
569
  if !value.is_nil() {
564
570
  builder.min_frequency(TryConvert::try_convert(value)?);
565
571
  }
566
572
 
567
- let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
573
+ let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
568
574
  if !value.is_nil() {
569
575
  builder.show_progress(TryConvert::try_convert(value)?);
570
576
  }
@@ -579,10 +585,10 @@ impl RbWordLevelTrainer {
579
585
  pub struct RbWordPieceTrainer {}
580
586
 
581
587
  impl RbWordPieceTrainer {
582
- pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
588
+ pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
583
589
  let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
584
590
 
585
- let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
591
+ let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
586
592
  if !value.is_nil() {
587
593
  builder = builder.special_tokens(
588
594
  RArray::try_convert(value)?
@@ -598,46 +604,50 @@ impl RbWordPieceTrainer {
598
604
  );
599
605
  }
600
606
 
601
- let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
607
+ let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
602
608
  if !value.is_nil() {
603
- let arr = <Vec<char>>::try_convert(value)?;
604
- let set: HashSet<char> = HashSet::from_iter(arr);
605
- builder = builder.initial_alphabet(set);
609
+ let alphabet = Vec::<String>::try_convert(value)?;
610
+ builder = builder.initial_alphabet(
611
+ alphabet
612
+ .into_iter()
613
+ .filter_map(|s| s.chars().next())
614
+ .collect(),
615
+ );
606
616
  }
607
617
 
608
- let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
618
+ let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
609
619
  if !value.is_nil() {
610
620
  builder = builder.vocab_size(TryConvert::try_convert(value)?);
611
621
  }
612
622
 
613
- let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
623
+ let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
614
624
  if !value.is_nil() {
615
625
  builder = builder.min_frequency(TryConvert::try_convert(value)?);
616
626
  }
617
627
 
618
- let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
628
+ let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
619
629
  if !value.is_nil() {
620
630
  builder = builder.show_progress(TryConvert::try_convert(value)?);
621
631
  }
622
632
 
623
- let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
633
+ let value: Value = kwargs.delete(ruby.to_symbol("limit_alphabet"))?;
624
634
  if !value.is_nil() {
625
635
  builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
626
636
  }
627
637
 
628
- let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
638
+ let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
629
639
  if !value.is_nil() {
630
640
  builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
631
641
  }
632
642
 
633
- let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
643
+ let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
634
644
  if !value.is_nil() {
635
645
  builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
636
646
  }
637
647
 
638
648
  if !kwargs.is_empty() {
639
649
  // TODO improve message
640
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
650
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
641
651
  }
642
652
 
643
653
  Ok(builder.build().into())
@@ -1,7 +1,7 @@
1
1
  use super::regex::{regex, RbRegex};
2
2
  use crate::RbResult;
3
3
  use magnus::prelude::*;
4
- use magnus::{exception, Error, TryConvert, Value};
4
+ use magnus::{Error, Ruby, TryConvert, Value};
5
5
  use tk::normalizer::SplitDelimiterBehavior;
6
6
  use tk::pattern::Pattern;
7
7
 
@@ -62,6 +62,7 @@ pub struct RbSplitDelimiterBehavior(pub SplitDelimiterBehavior);
62
62
 
63
63
  impl TryConvert for RbSplitDelimiterBehavior {
64
64
  fn try_convert(obj: Value) -> RbResult<Self> {
65
+ let ruby = Ruby::get_with(obj);
65
66
  let s = String::try_convert(obj)?;
66
67
 
67
68
  Ok(Self(match s.as_str() {
@@ -71,7 +72,7 @@ impl TryConvert for RbSplitDelimiterBehavior {
71
72
  "merged_with_next" => Ok(SplitDelimiterBehavior::MergedWithNext),
72
73
  "contiguous" => Ok(SplitDelimiterBehavior::Contiguous),
73
74
  _ => Err(Error::new(
74
- exception::arg_error(),
75
+ ruby.exception_arg_error(),
75
76
  "Wrong value for SplitDelimiterBehavior, expected one of: \
76
77
  `removed, isolated, merged_with_previous, merged_with_next, contiguous`",
77
78
  )),
@@ -1,5 +1,5 @@
1
1
  use crate::{RbResult, TOKENIZERS};
2
- use magnus::{exception, prelude::*, value::Lazy, Error, RClass, Ruby};
2
+ use magnus::{prelude::*, value::Lazy, Error, RClass, Ruby};
3
3
  use onig::Regex;
4
4
 
5
5
  #[magnus::wrap(class = "Tokenizers::Regex")]
@@ -9,10 +9,11 @@ pub struct RbRegex {
9
9
  }
10
10
 
11
11
  impl RbRegex {
12
- pub fn new(s: String) -> RbResult<Self> {
12
+ pub fn new(ruby: &Ruby, s: String) -> RbResult<Self> {
13
13
  Ok(Self {
14
- inner: Regex::new(&s)
15
- .map_err(|e| Error::new(exception::runtime_error(), e.description().to_owned()))?,
14
+ inner: Regex::new(&s).map_err(|e| {
15
+ Error::new(ruby.exception_runtime_error(), e.description().to_owned())
16
+ })?,
16
17
  pattern: s,
17
18
  })
18
19
  }