tokenizers 0.5.5 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,8 @@
1
1
  use std::sync::{Arc, RwLock};
2
2
 
3
3
  use magnus::{
4
- data_type_builder, exception, function, method, value::Lazy, Class, DataType,
5
- DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
4
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
5
+ Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
6
6
  };
7
7
 
8
8
  use serde::ser::SerializeStruct;
@@ -278,16 +278,16 @@ impl RbSequence {
278
278
  }
279
279
 
280
280
  pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
281
+ let ruby = Ruby::get().unwrap();
281
282
  let scheme = match string.as_str() {
282
283
  "first" => PrependScheme::First,
283
284
  "never" => PrependScheme::Never,
284
285
  "always" => PrependScheme::Always,
285
286
  _ => {
286
287
  return Err(Error::new(
287
- exception::arg_error(),
288
+ ruby.exception_arg_error(),
288
289
  format!(
289
- "{} is an unknown variant, should be one of ['first', 'never', 'always']",
290
- string
290
+ "{string} is an unknown variant, should be one of ['first', 'never', 'always']"
291
291
  ),
292
292
  ));
293
293
  }
@@ -4,7 +4,7 @@ use std::path::PathBuf;
4
4
  use std::str::FromStr;
5
5
 
6
6
  use magnus::prelude::*;
7
- use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
7
+ use magnus::{Error, RArray, RHash, RString, Ruby, TryConvert, Value};
8
8
  use tk::tokenizer::{
9
9
  Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
10
10
  TruncationParams, TruncationStrategy,
@@ -78,37 +78,37 @@ impl From<tk::AddedToken> for RbAddedToken {
78
78
  }
79
79
 
80
80
  impl RbAddedToken {
81
- pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
81
+ pub fn new(ruby: &Ruby, content: Option<String>, kwargs: RHash) -> RbResult<Self> {
82
82
  let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
83
83
 
84
- let value: Value = kwargs.delete(Symbol::new("single_word"))?;
84
+ let value: Value = kwargs.delete(ruby.to_symbol("single_word"))?;
85
85
  if !value.is_nil() {
86
86
  token.single_word = TryConvert::try_convert(value)?;
87
87
  }
88
88
 
89
- let value: Value = kwargs.delete(Symbol::new("lstrip"))?;
89
+ let value: Value = kwargs.delete(ruby.to_symbol("lstrip"))?;
90
90
  if !value.is_nil() {
91
91
  token.lstrip = TryConvert::try_convert(value)?;
92
92
  }
93
93
 
94
- let value: Value = kwargs.delete(Symbol::new("rstrip"))?;
94
+ let value: Value = kwargs.delete(ruby.to_symbol("rstrip"))?;
95
95
  if !value.is_nil() {
96
96
  token.rstrip = TryConvert::try_convert(value)?;
97
97
  }
98
98
 
99
- let value: Value = kwargs.delete(Symbol::new("normalized"))?;
99
+ let value: Value = kwargs.delete(ruby.to_symbol("normalized"))?;
100
100
  if !value.is_nil() {
101
101
  token.normalized = TryConvert::try_convert(value)?;
102
102
  }
103
103
 
104
- let value: Value = kwargs.delete(Symbol::new("special"))?;
104
+ let value: Value = kwargs.delete(ruby.to_symbol("special"))?;
105
105
  if !value.is_nil() {
106
106
  token.special = TryConvert::try_convert(value)?;
107
107
  }
108
108
 
109
109
  if !kwargs.is_empty() {
110
110
  // TODO improve message
111
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
111
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
112
112
  }
113
113
 
114
114
  Ok(token)
@@ -189,6 +189,8 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
189
189
 
190
190
  impl TryConvert for TextEncodeInput<'_> {
191
191
  fn try_convert(ob: Value) -> RbResult<Self> {
192
+ let ruby = Ruby::get_with(ob);
193
+
192
194
  if let Ok(i) = TextInputSequence::try_convert(ob) {
193
195
  return Ok(Self(i.into()));
194
196
  }
@@ -204,7 +206,7 @@ impl TryConvert for TextEncodeInput<'_> {
204
206
  }
205
207
  }
206
208
  Err(Error::new(
207
- exception::type_error(),
209
+ ruby.exception_type_error(),
208
210
  "TextEncodeInput must be a string or pair of strings",
209
211
  ))
210
212
  }
@@ -220,6 +222,8 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
220
222
 
221
223
  impl TryConvert for PreTokenizedEncodeInput<'_> {
222
224
  fn try_convert(ob: Value) -> RbResult<Self> {
225
+ let ruby = Ruby::get_with(ob);
226
+
223
227
  if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
224
228
  return Ok(Self(i.into()));
225
229
  }
@@ -237,7 +241,7 @@ impl TryConvert for PreTokenizedEncodeInput<'_> {
237
241
  }
238
242
  }
239
243
  Err(Error::new(
240
- exception::type_error(),
244
+ ruby.exception_type_error(),
241
245
  "PreTokenizedEncodeInput must be an array of strings or pair of arrays",
242
246
  ))
243
247
  }
@@ -351,7 +355,8 @@ impl RbTokenizer {
351
355
  }
352
356
 
353
357
  pub fn encode_batch(
354
- &self,
358
+ ruby: &Ruby,
359
+ rb_self: &Self,
355
360
  input: RArray,
356
361
  is_pretokenized: bool,
357
362
  add_special_tokens: bool,
@@ -367,16 +372,16 @@ impl RbTokenizer {
367
372
  Ok(input)
368
373
  })
369
374
  .collect::<RbResult<Vec<tk::EncodeInput>>>()?;
370
- self.tokenizer
371
- .borrow()
372
- .encode_batch_char_offsets(input, add_special_tokens)
373
- .map(|encodings| {
374
- encodings
375
- .into_iter()
376
- .map(Into::<RbEncoding>::into)
377
- .collect()
378
- })
379
- .map_err(RbError::from)
375
+ Ok(ruby.ary_from_iter(
376
+ rb_self
377
+ .tokenizer
378
+ .borrow()
379
+ .encode_batch_char_offsets(input, add_special_tokens)
380
+ .map(|encodings| {
381
+ ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into))
382
+ })
383
+ .map_err(RbError::from),
384
+ ))
380
385
  }
381
386
 
382
387
  pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
@@ -453,10 +458,10 @@ impl RbTokenizer {
453
458
  }
454
459
 
455
460
  // TODO support more kwargs
456
- pub fn enable_padding(&self, kwargs: RHash) -> RbResult<()> {
461
+ pub fn enable_padding(ruby: &Ruby, rb_self: &Self, kwargs: RHash) -> RbResult<()> {
457
462
  let mut params = PaddingParams::default();
458
463
 
459
- let value: Value = kwargs.delete(Symbol::new("direction"))?;
464
+ let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
460
465
  if !value.is_nil() {
461
466
  let dir_str = String::try_convert(value)?;
462
467
  params.direction = match dir_str.as_str() {
@@ -464,34 +469,34 @@ impl RbTokenizer {
464
469
  "right" => PaddingDirection::Right,
465
470
  _ => {
466
471
  return Err(Error::new(
467
- exception::arg_error(),
472
+ ruby.exception_arg_error(),
468
473
  "The direction value must be 'left' or 'right'",
469
474
  ))
470
475
  }
471
476
  }
472
477
  }
473
478
 
474
- let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
479
+ let value: Value = kwargs.delete(ruby.to_symbol("pad_to_multiple_of"))?;
475
480
  if !value.is_nil() {
476
481
  params.pad_to_multiple_of = TryConvert::try_convert(value)?;
477
482
  }
478
483
 
479
- let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
484
+ let value: Value = kwargs.delete(ruby.to_symbol("pad_id"))?;
480
485
  if !value.is_nil() {
481
486
  params.pad_id = TryConvert::try_convert(value)?;
482
487
  }
483
488
 
484
- let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
489
+ let value: Value = kwargs.delete(ruby.to_symbol("pad_type_id"))?;
485
490
  if !value.is_nil() {
486
491
  params.pad_type_id = TryConvert::try_convert(value)?;
487
492
  }
488
493
 
489
- let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
494
+ let value: Value = kwargs.delete(ruby.to_symbol("pad_token"))?;
490
495
  if !value.is_nil() {
491
496
  params.pad_token = TryConvert::try_convert(value)?;
492
497
  }
493
498
 
494
- let value: Value = kwargs.delete(Symbol::new("length"))?;
499
+ let value: Value = kwargs.delete(ruby.to_symbol("length"))?;
495
500
  if value.is_nil() {
496
501
  params.strategy = PaddingStrategy::BatchLongest;
497
502
  } else {
@@ -500,10 +505,10 @@ impl RbTokenizer {
500
505
 
501
506
  if !kwargs.is_empty() {
502
507
  // TODO improve message
503
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
508
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
504
509
  }
505
510
 
506
- self.tokenizer.borrow_mut().with_padding(Some(params));
511
+ rb_self.tokenizer.borrow_mut().with_padding(Some(params));
507
512
 
508
513
  Ok(())
509
514
  }
@@ -512,12 +517,13 @@ impl RbTokenizer {
512
517
  self.tokenizer.borrow_mut().with_padding(None);
513
518
  }
514
519
 
515
- pub fn padding(&self) -> RbResult<Option<RHash>> {
516
- self.tokenizer
520
+ pub fn padding(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
521
+ rb_self
522
+ .tokenizer
517
523
  .borrow()
518
524
  .get_padding()
519
525
  .map_or(Ok(None), |params| {
520
- let ret_hash = RHash::new();
526
+ let ret_hash = ruby.hash_new();
521
527
 
522
528
  ret_hash.aset(
523
529
  "length",
@@ -536,18 +542,23 @@ impl RbTokenizer {
536
542
  })
537
543
  }
538
544
 
539
- pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
545
+ pub fn enable_truncation(
546
+ ruby: &Ruby,
547
+ rb_self: &Self,
548
+ max_length: usize,
549
+ kwargs: RHash,
550
+ ) -> RbResult<()> {
540
551
  let mut params = TruncationParams {
541
552
  max_length,
542
553
  ..Default::default()
543
554
  };
544
555
 
545
- let value: Value = kwargs.delete(Symbol::new("stride"))?;
556
+ let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
546
557
  if !value.is_nil() {
547
558
  params.stride = TryConvert::try_convert(value)?;
548
559
  }
549
560
 
550
- let value: Value = kwargs.delete(Symbol::new("strategy"))?;
561
+ let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
551
562
  if !value.is_nil() {
552
563
  let strategy_str = String::try_convert(value)?;
553
564
  params.strategy = match strategy_str.as_str() {
@@ -555,13 +566,13 @@ impl RbTokenizer {
555
566
  "only_first" => TruncationStrategy::OnlyFirst,
556
567
  "only_second" => TruncationStrategy::OnlySecond,
557
568
  _ => return Err(Error::new(
558
- exception::arg_error(),
569
+ ruby.exception_arg_error(),
559
570
  "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
560
571
  )),
561
572
  }
562
573
  }
563
574
 
564
- let value: Value = kwargs.delete(Symbol::new("direction"))?;
575
+ let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
565
576
  if !value.is_nil() {
566
577
  let dir_str = String::try_convert(value)?;
567
578
  params.direction = match dir_str.as_str() {
@@ -569,7 +580,7 @@ impl RbTokenizer {
569
580
  "right" => TruncationDirection::Right,
570
581
  _ => {
571
582
  return Err(Error::new(
572
- exception::arg_error(),
583
+ ruby.exception_arg_error(),
573
584
  "The direction value must be 'left' or 'right'",
574
585
  ))
575
586
  }
@@ -578,12 +589,12 @@ impl RbTokenizer {
578
589
 
579
590
  if !kwargs.is_empty() {
580
591
  // TODO improve message
581
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
592
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
582
593
  }
583
594
 
584
- if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
595
+ if let Err(error_message) = rb_self.tokenizer.borrow_mut().with_truncation(Some(params)) {
585
596
  return Err(Error::new(
586
- exception::arg_error(),
597
+ ruby.exception_arg_error(),
587
598
  error_message.to_string(),
588
599
  ));
589
600
  }
@@ -598,12 +609,13 @@ impl RbTokenizer {
598
609
  .expect("Failed to set truncation to `None`! This should never happen");
599
610
  }
600
611
 
601
- pub fn truncation(&self) -> RbResult<Option<RHash>> {
602
- self.tokenizer
612
+ pub fn truncation(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
613
+ rb_self
614
+ .tokenizer
603
615
  .borrow()
604
616
  .get_truncation()
605
617
  .map_or(Ok(None), |params| {
606
- let ret_hash = RHash::new();
618
+ let ret_hash = ruby.hash_new();
607
619
 
608
620
  ret_hash.aset("max_length", params.max_length)?;
609
621
  ret_hash.aset("stride", params.stride)?;
@@ -629,10 +641,10 @@ impl RbTokenizer {
629
641
  self.tokenizer.borrow().get_vocab_size(with_added_tokens)
630
642
  }
631
643
 
632
- pub fn get_added_tokens_decoder(&self) -> RbResult<RHash> {
633
- let sorted_map = RHash::new();
644
+ pub fn get_added_tokens_decoder(ruby: &Ruby, rb_self: &Self) -> RbResult<RHash> {
645
+ let sorted_map = ruby.hash_new();
634
646
 
635
- for (key, value) in self.tokenizer.borrow().get_added_tokens_decoder() {
647
+ for (key, value) in rb_self.tokenizer.borrow().get_added_tokens_decoder() {
636
648
  sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
637
649
  }
638
650
 
@@ -1,13 +1,11 @@
1
- use std::collections::HashSet;
2
1
  use std::sync::{Arc, RwLock};
3
2
 
4
3
  use crate::models::RbModel;
5
4
  use crate::tokenizer::RbAddedToken;
6
5
  use magnus::prelude::*;
7
6
  use magnus::{
8
- data_type_builder, exception, function, method, value::Lazy, Class, DataType,
9
- DataTypeFunctions, Error, Module, Object, RArray, RClass, RHash, RModule, Ruby, Symbol,
10
- TryConvert, TypedData, Value,
7
+ data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
8
+ Module, Object, RArray, RClass, RHash, RModule, Ruby, TryConvert, TypedData, Value,
11
9
  };
12
10
  use serde::{Deserialize, Serialize};
13
11
  use tk::models::TrainerWrapper;
@@ -391,10 +389,10 @@ where
391
389
  pub struct RbBpeTrainer {}
392
390
 
393
391
  impl RbBpeTrainer {
394
- pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
392
+ pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
395
393
  let mut builder = tk::models::bpe::BpeTrainer::builder();
396
394
 
397
- let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
395
+ let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
398
396
  if !value.is_nil() {
399
397
  builder = builder.special_tokens(
400
398
  RArray::try_convert(value)?
@@ -410,46 +408,50 @@ impl RbBpeTrainer {
410
408
  );
411
409
  }
412
410
 
413
- let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
411
+ let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
414
412
  if !value.is_nil() {
415
- let arr = <Vec<char>>::try_convert(value)?;
416
- let set: HashSet<char> = HashSet::from_iter(arr);
417
- builder = builder.initial_alphabet(set);
413
+ let alphabet = Vec::<String>::try_convert(value)?;
414
+ builder = builder.initial_alphabet(
415
+ alphabet
416
+ .into_iter()
417
+ .filter_map(|s| s.chars().next())
418
+ .collect(),
419
+ );
418
420
  }
419
421
 
420
- let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
422
+ let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
421
423
  if !value.is_nil() {
422
424
  builder = builder.vocab_size(TryConvert::try_convert(value)?);
423
425
  }
424
426
 
425
- let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
427
+ let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
426
428
  if !value.is_nil() {
427
429
  builder = builder.min_frequency(TryConvert::try_convert(value)?);
428
430
  }
429
431
 
430
- let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
432
+ let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
431
433
  if !value.is_nil() {
432
434
  builder = builder.show_progress(TryConvert::try_convert(value)?);
433
435
  }
434
436
 
435
- let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
437
+ let value: Value = kwargs.delete(ruby.to_symbol("limit_alphabet"))?;
436
438
  if !value.is_nil() {
437
439
  builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
438
440
  }
439
441
 
440
- let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
442
+ let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
441
443
  if !value.is_nil() {
442
444
  builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
443
445
  }
444
446
 
445
- let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
447
+ let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
446
448
  if !value.is_nil() {
447
449
  builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
448
450
  }
449
451
 
450
452
  if !kwargs.is_empty() {
451
453
  // TODO improve message
452
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
454
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
453
455
  }
454
456
 
455
457
  Ok(builder.build().into())
@@ -459,10 +461,10 @@ impl RbBpeTrainer {
459
461
  pub struct RbUnigramTrainer {}
460
462
 
461
463
  impl RbUnigramTrainer {
462
- pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
464
+ pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
463
465
  let mut builder = tk::models::unigram::UnigramTrainer::builder();
464
466
 
465
- let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
467
+ let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
466
468
  if !value.is_nil() {
467
469
  builder.special_tokens(
468
470
  RArray::try_convert(value)?
@@ -478,56 +480,60 @@ impl RbUnigramTrainer {
478
480
  );
479
481
  }
480
482
 
481
- let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
483
+ let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
482
484
  if !value.is_nil() {
483
- let arr = <Vec<char>>::try_convert(value)?;
484
- let set: HashSet<char> = HashSet::from_iter(arr);
485
- builder.initial_alphabet(set);
485
+ let alphabet = Vec::<String>::try_convert(value)?;
486
+ builder.initial_alphabet(
487
+ alphabet
488
+ .into_iter()
489
+ .filter_map(|s| s.chars().next())
490
+ .collect(),
491
+ );
486
492
  }
487
493
 
488
- let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
494
+ let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
489
495
  if !value.is_nil() {
490
496
  builder.vocab_size(TryConvert::try_convert(value)?);
491
497
  }
492
498
 
493
- let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
499
+ let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
494
500
  if !value.is_nil() {
495
501
  builder.show_progress(TryConvert::try_convert(value)?);
496
502
  }
497
503
 
498
- let value: Value = kwargs.delete(Symbol::new("n_sub_iterations"))?;
504
+ let value: Value = kwargs.delete(ruby.to_symbol("n_sub_iterations"))?;
499
505
  if !value.is_nil() {
500
506
  builder.n_sub_iterations(TryConvert::try_convert(value)?);
501
507
  }
502
508
 
503
- let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
509
+ let value: Value = kwargs.delete(ruby.to_symbol("unk_token"))?;
504
510
  if !value.is_nil() {
505
511
  builder.unk_token(Some(TryConvert::try_convert(value)?));
506
512
  }
507
513
 
508
- let value: Value = kwargs.delete(Symbol::new("max_piece_length"))?;
514
+ let value: Value = kwargs.delete(ruby.to_symbol("max_piece_length"))?;
509
515
  if !value.is_nil() {
510
516
  builder.max_piece_length(TryConvert::try_convert(value)?);
511
517
  }
512
518
 
513
- let value: Value = kwargs.delete(Symbol::new("seed_size"))?;
519
+ let value: Value = kwargs.delete(ruby.to_symbol("seed_size"))?;
514
520
  if !value.is_nil() {
515
521
  builder.seed_size(TryConvert::try_convert(value)?);
516
522
  }
517
523
 
518
- let value: Value = kwargs.delete(Symbol::new("shrinking_factor"))?;
524
+ let value: Value = kwargs.delete(ruby.to_symbol("shrinking_factor"))?;
519
525
  if !value.is_nil() {
520
526
  builder.shrinking_factor(TryConvert::try_convert(value)?);
521
527
  }
522
528
 
523
529
  if !kwargs.is_empty() {
524
530
  // TODO improve message
525
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
531
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
526
532
  }
527
533
 
528
534
  let trainer = builder
529
535
  .build()
530
- .map_err(|_| Error::new(exception::arg_error(), "Cannot build UnigramTrainer"))?;
536
+ .map_err(|_| Error::new(ruby.exception_arg_error(), "Cannot build UnigramTrainer"))?;
531
537
  Ok(trainer.into())
532
538
  }
533
539
  }
@@ -535,10 +541,10 @@ impl RbUnigramTrainer {
535
541
  pub struct RbWordLevelTrainer {}
536
542
 
537
543
  impl RbWordLevelTrainer {
538
- pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
544
+ pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
539
545
  let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
540
546
 
541
- let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
547
+ let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
542
548
  if !value.is_nil() {
543
549
  builder.special_tokens(
544
550
  RArray::try_convert(value)?
@@ -554,17 +560,17 @@ impl RbWordLevelTrainer {
554
560
  );
555
561
  }
556
562
 
557
- let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
563
+ let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
558
564
  if !value.is_nil() {
559
565
  builder.vocab_size(TryConvert::try_convert(value)?);
560
566
  }
561
567
 
562
- let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
568
+ let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
563
569
  if !value.is_nil() {
564
570
  builder.min_frequency(TryConvert::try_convert(value)?);
565
571
  }
566
572
 
567
- let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
573
+ let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
568
574
  if !value.is_nil() {
569
575
  builder.show_progress(TryConvert::try_convert(value)?);
570
576
  }
@@ -579,10 +585,10 @@ impl RbWordLevelTrainer {
579
585
  pub struct RbWordPieceTrainer {}
580
586
 
581
587
  impl RbWordPieceTrainer {
582
- pub fn new(kwargs: RHash) -> RbResult<RbTrainer> {
588
+ pub fn new(ruby: &Ruby, kwargs: RHash) -> RbResult<RbTrainer> {
583
589
  let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
584
590
 
585
- let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
591
+ let value: Value = kwargs.delete(ruby.to_symbol("special_tokens"))?;
586
592
  if !value.is_nil() {
587
593
  builder = builder.special_tokens(
588
594
  RArray::try_convert(value)?
@@ -598,46 +604,50 @@ impl RbWordPieceTrainer {
598
604
  );
599
605
  }
600
606
 
601
- let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
607
+ let value: Value = kwargs.delete(ruby.to_symbol("initial_alphabet"))?;
602
608
  if !value.is_nil() {
603
- let arr = <Vec<char>>::try_convert(value)?;
604
- let set: HashSet<char> = HashSet::from_iter(arr);
605
- builder = builder.initial_alphabet(set);
609
+ let alphabet = Vec::<String>::try_convert(value)?;
610
+ builder = builder.initial_alphabet(
611
+ alphabet
612
+ .into_iter()
613
+ .filter_map(|s| s.chars().next())
614
+ .collect(),
615
+ );
606
616
  }
607
617
 
608
- let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
618
+ let value: Value = kwargs.delete(ruby.to_symbol("vocab_size"))?;
609
619
  if !value.is_nil() {
610
620
  builder = builder.vocab_size(TryConvert::try_convert(value)?);
611
621
  }
612
622
 
613
- let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
623
+ let value: Value = kwargs.delete(ruby.to_symbol("min_frequency"))?;
614
624
  if !value.is_nil() {
615
625
  builder = builder.min_frequency(TryConvert::try_convert(value)?);
616
626
  }
617
627
 
618
- let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
628
+ let value: Value = kwargs.delete(ruby.to_symbol("show_progress"))?;
619
629
  if !value.is_nil() {
620
630
  builder = builder.show_progress(TryConvert::try_convert(value)?);
621
631
  }
622
632
 
623
- let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
633
+ let value: Value = kwargs.delete(ruby.to_symbol("limit_alphabet"))?;
624
634
  if !value.is_nil() {
625
635
  builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
626
636
  }
627
637
 
628
- let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
638
+ let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
629
639
  if !value.is_nil() {
630
640
  builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
631
641
  }
632
642
 
633
- let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
643
+ let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
634
644
  if !value.is_nil() {
635
645
  builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
636
646
  }
637
647
 
638
648
  if !kwargs.is_empty() {
639
649
  // TODO improve message
640
- return Err(Error::new(exception::arg_error(), "unknown keyword"));
650
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
641
651
  }
642
652
 
643
653
  Ok(builder.build().into())
@@ -1,7 +1,7 @@
1
1
  use super::regex::{regex, RbRegex};
2
2
  use crate::RbResult;
3
3
  use magnus::prelude::*;
4
- use magnus::{exception, Error, TryConvert, Value};
4
+ use magnus::{Error, Ruby, TryConvert, Value};
5
5
  use tk::normalizer::SplitDelimiterBehavior;
6
6
  use tk::pattern::Pattern;
7
7
 
@@ -62,6 +62,7 @@ pub struct RbSplitDelimiterBehavior(pub SplitDelimiterBehavior);
62
62
 
63
63
  impl TryConvert for RbSplitDelimiterBehavior {
64
64
  fn try_convert(obj: Value) -> RbResult<Self> {
65
+ let ruby = Ruby::get_with(obj);
65
66
  let s = String::try_convert(obj)?;
66
67
 
67
68
  Ok(Self(match s.as_str() {
@@ -71,7 +72,7 @@ impl TryConvert for RbSplitDelimiterBehavior {
71
72
  "merged_with_next" => Ok(SplitDelimiterBehavior::MergedWithNext),
72
73
  "contiguous" => Ok(SplitDelimiterBehavior::Contiguous),
73
74
  _ => Err(Error::new(
74
- exception::arg_error(),
75
+ ruby.exception_arg_error(),
75
76
  "Wrong value for SplitDelimiterBehavior, expected one of: \
76
77
  `removed, isolated, merged_with_previous, merged_with_next, contiguous`",
77
78
  )),
@@ -1,5 +1,5 @@
1
1
  use crate::{RbResult, TOKENIZERS};
2
- use magnus::{exception, prelude::*, value::Lazy, Error, RClass, Ruby};
2
+ use magnus::{prelude::*, value::Lazy, Error, RClass, Ruby};
3
3
  use onig::Regex;
4
4
 
5
5
  #[magnus::wrap(class = "Tokenizers::Regex")]
@@ -9,10 +9,11 @@ pub struct RbRegex {
9
9
  }
10
10
 
11
11
  impl RbRegex {
12
- pub fn new(s: String) -> RbResult<Self> {
12
+ pub fn new(ruby: &Ruby, s: String) -> RbResult<Self> {
13
13
  Ok(Self {
14
- inner: Regex::new(&s)
15
- .map_err(|e| Error::new(exception::runtime_error(), e.description().to_owned()))?,
14
+ inner: Regex::new(&s).map_err(|e| {
15
+ Error::new(ruby.exception_runtime_error(), e.description().to_owned())
16
+ })?,
16
17
  pattern: s,
17
18
  })
18
19
  }