tokenizers 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ use std::cell::RefCell;
2
2
  use std::collections::HashMap;
3
3
  use std::path::PathBuf;
4
4
 
5
+ use magnus::prelude::*;
5
6
  use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
6
7
  use tk::tokenizer::{
7
8
  Model, PaddingDirection, PaddingParams, PaddingStrategy,
@@ -78,7 +79,7 @@ struct TextInputSequence<'s>(tk::InputSequence<'s>);
78
79
 
79
80
  impl<'s> TryConvert for TextInputSequence<'s> {
80
81
  fn try_convert(ob: Value) -> RbResult<Self> {
81
- Ok(Self(ob.try_convert::<String>()?.into()))
82
+ Ok(Self(String::try_convert(ob)?.into()))
82
83
  }
83
84
  }
84
85
 
@@ -92,7 +93,7 @@ struct RbArrayStr(Vec<String>);
92
93
 
93
94
  impl TryConvert for RbArrayStr {
94
95
  fn try_convert(ob: Value) -> RbResult<Self> {
95
- let seq = ob.try_convert::<Vec<String>>()?;
96
+ let seq = <Vec<String>>::try_convert(ob)?;
96
97
  Ok(Self(seq))
97
98
  }
98
99
  }
@@ -107,7 +108,7 @@ struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
107
108
 
108
109
  impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
109
110
  fn try_convert(ob: Value) -> RbResult<Self> {
110
- if let Ok(seq) = ob.try_convert::<RbArrayStr>() {
111
+ if let Ok(seq) = RbArrayStr::try_convert(ob) {
111
112
  return Ok(Self(seq.into()));
112
113
  }
113
114
  todo!()
@@ -124,14 +125,14 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
124
125
 
125
126
  impl<'s> TryConvert for TextEncodeInput<'s> {
126
127
  fn try_convert(ob: Value) -> RbResult<Self> {
127
- if let Ok(i) = ob.try_convert::<TextInputSequence>() {
128
+ if let Ok(i) = TextInputSequence::try_convert(ob) {
128
129
  return Ok(Self(i.into()));
129
130
  }
130
- if let Ok((i1, i2)) = ob.try_convert::<(TextInputSequence, TextInputSequence)>() {
131
+ if let Ok((i1, i2)) = <(TextInputSequence, TextInputSequence)>::try_convert(ob) {
131
132
  return Ok(Self((i1, i2).into()));
132
133
  }
133
134
  // TODO check if this branch is needed
134
- if let Ok(arr) = ob.try_convert::<RArray>() {
135
+ if let Ok(arr) = RArray::try_convert(ob) {
135
136
  if arr.len() == 2 {
136
137
  let first = arr.entry::<TextInputSequence>(0).unwrap();
137
138
  let second = arr.entry::<TextInputSequence>(1).unwrap();
@@ -155,16 +156,16 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
155
156
 
156
157
  impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
157
158
  fn try_convert(ob: Value) -> RbResult<Self> {
158
- if let Ok(i) = ob.try_convert::<PreTokenizedInputSequence>() {
159
+ if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
159
160
  return Ok(Self(i.into()));
160
161
  }
161
162
  if let Ok((i1, i2)) =
162
- ob.try_convert::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
163
+ <(PreTokenizedInputSequence, PreTokenizedInputSequence)>::try_convert(ob)
163
164
  {
164
165
  return Ok(Self((i1, i2).into()));
165
166
  }
166
167
  // TODO check if this branch is needed
167
- if let Ok(arr) = ob.try_convert::<RArray>() {
168
+ if let Ok(arr) = RArray::try_convert(ob) {
168
169
  if arr.len() == 2 {
169
170
  let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
170
171
  let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
@@ -251,16 +252,16 @@ impl RbTokenizer {
251
252
  add_special_tokens: bool,
252
253
  ) -> RbResult<RbEncoding> {
253
254
  let sequence: tk::InputSequence = if is_pretokenized {
254
- sequence.try_convert::<PreTokenizedInputSequence>()?.into()
255
+ PreTokenizedInputSequence::try_convert(sequence)?.into()
255
256
  } else {
256
- sequence.try_convert::<TextInputSequence>()?.into()
257
+ TextInputSequence::try_convert(sequence)?.into()
257
258
  };
258
259
  let input = match pair {
259
260
  Some(pair) => {
260
261
  let pair: tk::InputSequence = if is_pretokenized {
261
- pair.try_convert::<PreTokenizedInputSequence>()?.into()
262
+ PreTokenizedInputSequence::try_convert(pair)?.into()
262
263
  } else {
263
- pair.try_convert::<TextInputSequence>()?.into()
264
+ TextInputSequence::try_convert(pair)?.into()
264
265
  };
265
266
  tk::EncodeInput::Dual(sequence, pair)
266
267
  }
@@ -284,9 +285,9 @@ impl RbTokenizer {
284
285
  .each()
285
286
  .map(|o| {
286
287
  let input: tk::EncodeInput = if is_pretokenized {
287
- o?.try_convert::<PreTokenizedEncodeInput>()?.into()
288
+ PreTokenizedEncodeInput::try_convert(o?)?.into()
288
289
  } else {
289
- o?.try_convert::<TextEncodeInput>()?.into()
290
+ TextEncodeInput::try_convert(o?)?.into()
290
291
  };
291
292
  Ok(input)
292
293
  })
@@ -306,14 +307,15 @@ impl RbTokenizer {
306
307
  pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
307
308
  self.tokenizer
308
309
  .borrow()
309
- .decode(ids, skip_special_tokens)
310
+ .decode(&ids, skip_special_tokens)
310
311
  .map_err(RbError::from)
311
312
  }
312
313
 
313
314
  pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
315
+ let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
314
316
  self.tokenizer
315
317
  .borrow()
316
- .decode_batch(sequences, skip_special_tokens)
318
+ .decode_batch(&slices, skip_special_tokens)
317
319
  .map_err(RbError::from)
318
320
  }
319
321
 
@@ -353,7 +355,7 @@ impl RbTokenizer {
353
355
 
354
356
  let value: Value = kwargs.delete(Symbol::new("direction"))?;
355
357
  if !value.is_nil() {
356
- let dir_str: String = value.try_convert()?;
358
+ let dir_str = String::try_convert(value)?;
357
359
  params.direction = match dir_str.as_str() {
358
360
  "left" => PaddingDirection::Left,
359
361
  "right" => PaddingDirection::Right,
@@ -363,29 +365,29 @@ impl RbTokenizer {
363
365
 
364
366
  let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
365
367
  if !value.is_nil() {
366
- params.pad_to_multiple_of = value.try_convert()?;
368
+ params.pad_to_multiple_of = TryConvert::try_convert(value)?;
367
369
  }
368
370
 
369
371
  let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
370
372
  if !value.is_nil() {
371
- params.pad_id = value.try_convert()?;
373
+ params.pad_id = TryConvert::try_convert(value)?;
372
374
  }
373
375
 
374
376
  let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
375
377
  if !value.is_nil() {
376
- params.pad_type_id = value.try_convert()?;
378
+ params.pad_type_id = TryConvert::try_convert(value)?;
377
379
  }
378
380
 
379
381
  let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
380
382
  if !value.is_nil() {
381
- params.pad_token = value.try_convert()?;
383
+ params.pad_token = TryConvert::try_convert(value)?;
382
384
  }
383
385
 
384
386
  let value: Value = kwargs.delete(Symbol::new("length"))?;
385
387
  if value.is_nil() {
386
388
  params.strategy = PaddingStrategy::BatchLongest;
387
389
  } else {
388
- params.strategy = PaddingStrategy::Fixed(value.try_convert()?);
390
+ params.strategy = PaddingStrategy::Fixed(TryConvert::try_convert(value)?);
389
391
  }
390
392
 
391
393
  if !kwargs.is_empty() {
@@ -431,12 +433,12 @@ impl RbTokenizer {
431
433
 
432
434
  let value: Value = kwargs.delete(Symbol::new("stride"))?;
433
435
  if !value.is_nil() {
434
- params.stride = value.try_convert()?;
436
+ params.stride = TryConvert::try_convert(value)?;
435
437
  }
436
438
 
437
439
  let value: Value = kwargs.delete(Symbol::new("strategy"))?;
438
440
  if !value.is_nil() {
439
- let strategy_str: String = value.try_convert()?;
441
+ let strategy_str = String::try_convert(value)?;
440
442
  params.strategy = match strategy_str.as_str() {
441
443
  "longest_first" => TruncationStrategy::LongestFirst,
442
444
  "only_first" => TruncationStrategy::OnlyFirst,
@@ -447,7 +449,7 @@ impl RbTokenizer {
447
449
 
448
450
  let value: Value = kwargs.delete(Symbol::new("direction"))?;
449
451
  if !value.is_nil() {
450
- let dir_str: String = value.try_convert()?;
452
+ let dir_str = String::try_convert(value)?;
451
453
  params.direction = match dir_str.as_str() {
452
454
  "left" => TruncationDirection::Left,
453
455
  "right" => TruncationDirection::Right,
@@ -460,13 +462,18 @@ impl RbTokenizer {
460
462
  return Err(Error::new(exception::arg_error(), "unknown keyword"));
461
463
  }
462
464
 
463
- self.tokenizer.borrow_mut().with_truncation(Some(params));
465
+ if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
466
+ return Err(Error::new(exception::arg_error(), error_message.to_string()));
467
+ }
464
468
 
465
469
  Ok(())
466
470
  }
467
471
 
468
472
  pub fn no_truncation(&self) {
469
- self.tokenizer.borrow_mut().with_truncation(None);
473
+ self.tokenizer
474
+ .borrow_mut()
475
+ .with_truncation(None)
476
+ .expect("Failed to set truncation to `None`! This should never happen");
470
477
  }
471
478
 
472
479
  pub fn truncation(&self) -> RbResult<Option<RHash>> {
@@ -3,16 +3,16 @@ use std::sync::{Arc, RwLock};
3
3
 
4
4
  use crate::models::RbModel;
5
5
  use crate::tokenizer::RbAddedToken;
6
- use magnus::typed_data::DataTypeBuilder;
6
+ use magnus::prelude::*;
7
7
  use magnus::{
8
- exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
- RArray, RClass, RHash, RModule, Symbol, TypedData, Value,
8
+ data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
9
+ RArray, RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
10
10
  };
11
11
  use serde::{Deserialize, Serialize};
12
12
  use tk::models::TrainerWrapper;
13
13
  use tk::Trainer;
14
14
 
15
- use super::RbResult;
15
+ use super::{RbResult, TRAINERS};
16
16
 
17
17
  #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
18
18
  pub struct RbTrainer {
@@ -112,7 +112,7 @@ impl RbTrainer {
112
112
  special_tokens
113
113
  .each()
114
114
  .map(|token| {
115
- if let Ok(content) = token?.try_convert::<String>() {
115
+ if let Ok(content) = String::try_convert(token?) {
116
116
  Ok(RbAddedToken::from(content, Some(true)).get_token())
117
117
  } else {
118
118
  todo!()
@@ -144,7 +144,7 @@ impl RbTrainer {
144
144
  self,
145
145
  BpeTrainer,
146
146
  initial_alphabet,
147
- alphabet.into_iter().map(|c| c).collect()
147
+ alphabet.into_iter().collect()
148
148
  );
149
149
  }
150
150
 
@@ -199,7 +199,7 @@ impl RbTrainer {
199
199
  special_tokens
200
200
  .each()
201
201
  .map(|token| {
202
- if let Ok(content) = token?.try_convert::<String>() {
202
+ if let Ok(content) = String::try_convert(token?) {
203
203
  Ok(RbAddedToken::from(content, Some(true)).get_token())
204
204
  } else {
205
205
  todo!()
@@ -223,7 +223,7 @@ impl RbTrainer {
223
223
  self,
224
224
  UnigramTrainer,
225
225
  initial_alphabet,
226
- alphabet.into_iter().map(|c| c).collect()
226
+ alphabet.into_iter().collect()
227
227
  );
228
228
  }
229
229
 
@@ -270,7 +270,7 @@ impl RbTrainer {
270
270
  special_tokens
271
271
  .each()
272
272
  .map(|token| {
273
- if let Ok(content) = token?.try_convert::<String>() {
273
+ if let Ok(content) = String::try_convert(token?) {
274
274
  Ok(RbAddedToken::from(content, Some(true)).get_token())
275
275
  } else {
276
276
  todo!()
@@ -324,7 +324,7 @@ impl RbTrainer {
324
324
  special_tokens
325
325
  .each()
326
326
  .map(|token| {
327
- if let Ok(content) = token?.try_convert::<String>() {
327
+ if let Ok(content) = String::try_convert(token?) {
328
328
  Ok(RbAddedToken::from(content, Some(true)).get_token())
329
329
  } else {
330
330
  todo!()
@@ -356,7 +356,7 @@ impl RbTrainer {
356
356
  self,
357
357
  WordPieceTrainer,
358
358
  @set_initial_alphabet,
359
- alphabet.into_iter().map(|c| c).collect()
359
+ alphabet.into_iter().collect()
360
360
  );
361
361
  }
362
362
 
@@ -397,11 +397,10 @@ impl RbBpeTrainer {
397
397
  let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
398
398
  if !value.is_nil() {
399
399
  builder = builder.special_tokens(
400
- value
401
- .try_convert::<RArray>()?
400
+ RArray::try_convert(value)?
402
401
  .each()
403
402
  .map(|token| {
404
- if let Ok(content) = token?.try_convert::<String>() {
403
+ if let Ok(content) = String::try_convert(token?) {
405
404
  Ok(RbAddedToken::from(content, Some(true)).get_token())
406
405
  } else {
407
406
  todo!()
@@ -413,39 +412,39 @@ impl RbBpeTrainer {
413
412
 
414
413
  let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
415
414
  if !value.is_nil() {
416
- let arr = value.try_convert::<Vec<char>>()?;
415
+ let arr = <Vec<char>>::try_convert(value)?;
417
416
  let set: HashSet<char> = HashSet::from_iter(arr);
418
417
  builder = builder.initial_alphabet(set);
419
418
  }
420
419
 
421
420
  let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
422
421
  if !value.is_nil() {
423
- builder = builder.vocab_size(value.try_convert()?);
422
+ builder = builder.vocab_size(TryConvert::try_convert(value)?);
424
423
  }
425
424
 
426
425
  let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
427
426
  if !value.is_nil() {
428
- builder = builder.min_frequency(value.try_convert()?);
427
+ builder = builder.min_frequency(TryConvert::try_convert(value)?);
429
428
  }
430
429
 
431
430
  let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
432
431
  if !value.is_nil() {
433
- builder = builder.show_progress(value.try_convert()?);
432
+ builder = builder.show_progress(TryConvert::try_convert(value)?);
434
433
  }
435
434
 
436
435
  let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
437
436
  if !value.is_nil() {
438
- builder = builder.limit_alphabet(value.try_convert()?);
437
+ builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
439
438
  }
440
439
 
441
440
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
442
441
  if !value.is_nil() {
443
- builder = builder.continuing_subword_prefix(value.try_convert()?);
442
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
444
443
  }
445
444
 
446
445
  let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
447
446
  if !value.is_nil() {
448
- builder = builder.end_of_word_suffix(value.try_convert()?);
447
+ builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
449
448
  }
450
449
 
451
450
  if !kwargs.is_empty() {
@@ -466,11 +465,10 @@ impl RbUnigramTrainer {
466
465
  let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
467
466
  if !value.is_nil() {
468
467
  builder.special_tokens(
469
- value
470
- .try_convert::<RArray>()?
468
+ RArray::try_convert(value)?
471
469
  .each()
472
470
  .map(|token| {
473
- if let Ok(content) = token?.try_convert::<String>() {
471
+ if let Ok(content) = String::try_convert(token?) {
474
472
  Ok(RbAddedToken::from(content, Some(true)).get_token())
475
473
  } else {
476
474
  todo!()
@@ -482,44 +480,44 @@ impl RbUnigramTrainer {
482
480
 
483
481
  let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
484
482
  if !value.is_nil() {
485
- let arr = value.try_convert::<Vec<char>>()?;
483
+ let arr = <Vec<char>>::try_convert(value)?;
486
484
  let set: HashSet<char> = HashSet::from_iter(arr);
487
485
  builder.initial_alphabet(set);
488
486
  }
489
487
 
490
488
  let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
491
489
  if !value.is_nil() {
492
- builder.vocab_size(value.try_convert()?);
490
+ builder.vocab_size(TryConvert::try_convert(value)?);
493
491
  }
494
492
 
495
493
  let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
496
494
  if !value.is_nil() {
497
- builder.show_progress(value.try_convert()?);
495
+ builder.show_progress(TryConvert::try_convert(value)?);
498
496
  }
499
497
 
500
498
  let value: Value = kwargs.delete(Symbol::new("n_sub_iterations"))?;
501
499
  if !value.is_nil() {
502
- builder.n_sub_iterations(value.try_convert()?);
500
+ builder.n_sub_iterations(TryConvert::try_convert(value)?);
503
501
  }
504
502
 
505
503
  let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
506
504
  if !value.is_nil() {
507
- builder.unk_token(Some(value.try_convert()?));
505
+ builder.unk_token(Some(TryConvert::try_convert(value)?));
508
506
  }
509
507
 
510
508
  let value: Value = kwargs.delete(Symbol::new("max_piece_length"))?;
511
509
  if !value.is_nil() {
512
- builder.max_piece_length(value.try_convert()?);
510
+ builder.max_piece_length(TryConvert::try_convert(value)?);
513
511
  }
514
512
 
515
513
  let value: Value = kwargs.delete(Symbol::new("seed_size"))?;
516
514
  if !value.is_nil() {
517
- builder.seed_size(value.try_convert()?);
515
+ builder.seed_size(TryConvert::try_convert(value)?);
518
516
  }
519
517
 
520
518
  let value: Value = kwargs.delete(Symbol::new("shrinking_factor"))?;
521
519
  if !value.is_nil() {
522
- builder.shrinking_factor(value.try_convert()?);
520
+ builder.shrinking_factor(TryConvert::try_convert(value)?);
523
521
  }
524
522
 
525
523
  if !kwargs.is_empty() {
@@ -541,11 +539,10 @@ impl RbWordLevelTrainer {
541
539
  let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
542
540
  if !value.is_nil() {
543
541
  builder.special_tokens(
544
- value
545
- .try_convert::<RArray>()?
542
+ RArray::try_convert(value)?
546
543
  .each()
547
544
  .map(|token| {
548
- if let Ok(content) = token?.try_convert::<String>() {
545
+ if let Ok(content) = String::try_convert(token?) {
549
546
  Ok(RbAddedToken::from(content, Some(true)).get_token())
550
547
  } else {
551
548
  todo!()
@@ -557,17 +554,17 @@ impl RbWordLevelTrainer {
557
554
 
558
555
  let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
559
556
  if !value.is_nil() {
560
- builder.vocab_size(value.try_convert()?);
557
+ builder.vocab_size(TryConvert::try_convert(value)?);
561
558
  }
562
559
 
563
560
  let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
564
561
  if !value.is_nil() {
565
- builder.min_frequency(value.try_convert()?);
562
+ builder.min_frequency(TryConvert::try_convert(value)?);
566
563
  }
567
564
 
568
565
  let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
569
566
  if !value.is_nil() {
570
- builder.show_progress(value.try_convert()?);
567
+ builder.show_progress(TryConvert::try_convert(value)?);
571
568
  }
572
569
 
573
570
  Ok(builder.build().expect("WordLevelTrainerBuilder cannot fail").into())
@@ -583,11 +580,10 @@ impl RbWordPieceTrainer {
583
580
  let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
584
581
  if !value.is_nil() {
585
582
  builder = builder.special_tokens(
586
- value
587
- .try_convert::<RArray>()?
583
+ RArray::try_convert(value)?
588
584
  .each()
589
585
  .map(|token| {
590
- if let Ok(content) = token?.try_convert::<String>() {
586
+ if let Ok(content) = String::try_convert(token?) {
591
587
  Ok(RbAddedToken::from(content, Some(true)).get_token())
592
588
  } else {
593
589
  todo!()
@@ -599,39 +595,39 @@ impl RbWordPieceTrainer {
599
595
 
600
596
  let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
601
597
  if !value.is_nil() {
602
- let arr = value.try_convert::<Vec<char>>()?;
598
+ let arr = <Vec<char>>::try_convert(value)?;
603
599
  let set: HashSet<char> = HashSet::from_iter(arr);
604
600
  builder = builder.initial_alphabet(set);
605
601
  }
606
602
 
607
603
  let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
608
604
  if !value.is_nil() {
609
- builder = builder.vocab_size(value.try_convert()?);
605
+ builder = builder.vocab_size(TryConvert::try_convert(value)?);
610
606
  }
611
607
 
612
608
  let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
613
609
  if !value.is_nil() {
614
- builder = builder.min_frequency(value.try_convert()?);
610
+ builder = builder.min_frequency(TryConvert::try_convert(value)?);
615
611
  }
616
612
 
617
613
  let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
618
614
  if !value.is_nil() {
619
- builder = builder.show_progress(value.try_convert()?);
615
+ builder = builder.show_progress(TryConvert::try_convert(value)?);
620
616
  }
621
617
 
622
618
  let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
623
619
  if !value.is_nil() {
624
- builder = builder.limit_alphabet(value.try_convert()?);
620
+ builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
625
621
  }
626
622
 
627
623
  let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
628
624
  if !value.is_nil() {
629
- builder = builder.continuing_subword_prefix(value.try_convert()?);
625
+ builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
630
626
  }
631
627
 
632
628
  let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
633
629
  if !value.is_nil() {
634
- builder = builder.end_of_word_suffix(value.try_convert()?);
630
+ builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
635
631
  }
636
632
 
637
633
  if !kwargs.is_empty() {
@@ -644,46 +640,52 @@ impl RbWordPieceTrainer {
644
640
  }
645
641
 
646
642
  unsafe impl TypedData for RbTrainer {
647
- fn class() -> RClass {
648
- *memoize!(RClass: {
649
- let class: RClass = crate::trainers().const_get("Trainer").unwrap();
650
- class.undef_alloc_func();
651
- class
652
- })
643
+ fn class(ruby: &Ruby) -> RClass {
644
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
645
+ let class: RClass = ruby.get_inner(&TRAINERS).const_get("Trainer").unwrap();
646
+ class.undef_default_alloc_func();
647
+ class
648
+ });
649
+ ruby.get_inner(&CLASS)
653
650
  }
654
651
 
655
652
  fn data_type() -> &'static DataType {
656
- memoize!(DataType: DataTypeBuilder::<RbTrainer>::new("Tokenizers::Trainers::Trainer").build())
657
- }
658
-
659
- fn class_for(value: &Self) -> RClass {
653
+ static DATA_TYPE: DataType = data_type_builder!(RbTrainer, "Tokenizers::Trainers::Trainer").build();
654
+ &DATA_TYPE
655
+ }
656
+
657
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
658
+ static BPE_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
659
+ let class: RClass = ruby.get_inner(&TRAINERS).const_get("BpeTrainer").unwrap();
660
+ class.undef_default_alloc_func();
661
+ class
662
+ });
663
+ static UNIGRAM_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
664
+ let class: RClass = ruby.get_inner(&TRAINERS).const_get("UnigramTrainer").unwrap();
665
+ class.undef_default_alloc_func();
666
+ class
667
+ });
668
+ static WORD_LEVEL_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
669
+ let class: RClass = ruby.get_inner(&TRAINERS).const_get("WordLevelTrainer").unwrap();
670
+ class.undef_default_alloc_func();
671
+ class
672
+ });
673
+ static WORD_PIECE_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
674
+ let class: RClass = ruby.get_inner(&TRAINERS).const_get("WordPieceTrainer").unwrap();
675
+ class.undef_default_alloc_func();
676
+ class
677
+ });
660
678
  match *value.trainer.read().unwrap() {
661
- TrainerWrapper::BpeTrainer(_) => *memoize!(RClass: {
662
- let class: RClass = crate::trainers().const_get("BpeTrainer").unwrap();
663
- class.undef_alloc_func();
664
- class
665
- }),
666
- TrainerWrapper::UnigramTrainer(_) => *memoize!(RClass: {
667
- let class: RClass = crate::trainers().const_get("UnigramTrainer").unwrap();
668
- class.undef_alloc_func();
669
- class
670
- }),
671
- TrainerWrapper::WordLevelTrainer(_) => *memoize!(RClass: {
672
- let class: RClass = crate::trainers().const_get("WordLevelTrainer").unwrap();
673
- class.undef_alloc_func();
674
- class
675
- }),
676
- TrainerWrapper::WordPieceTrainer(_) => *memoize!(RClass: {
677
- let class: RClass = crate::trainers().const_get("WordPieceTrainer").unwrap();
678
- class.undef_alloc_func();
679
- class
680
- }),
679
+ TrainerWrapper::BpeTrainer(_) => ruby.get_inner(&BPE_TRAINER),
680
+ TrainerWrapper::UnigramTrainer(_) => ruby.get_inner(&UNIGRAM_TRAINER),
681
+ TrainerWrapper::WordLevelTrainer(_) => ruby.get_inner(&WORD_LEVEL_TRAINER),
682
+ TrainerWrapper::WordPieceTrainer(_) => ruby.get_inner(&WORD_PIECE_TRAINER),
681
683
  }
682
684
  }
683
685
  }
684
686
 
685
- pub fn trainers(module: &RModule) -> RbResult<()> {
686
- let trainer = module.define_class("Trainer", Default::default())?;
687
+ pub fn init_trainers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
688
+ let trainer = module.define_class("Trainer", ruby.class_object())?;
687
689
 
688
690
  let class = module.define_class("BpeTrainer", trainer)?;
689
691
  class.define_singleton_method("_new", function!(RbBpeTrainer::new, 1))?;
@@ -1,5 +1,6 @@
1
1
  use super::regex::{regex, RbRegex};
2
2
  use crate::RbResult;
3
+ use magnus::prelude::*;
3
4
  use magnus::{exception, Error, TryConvert, Value};
4
5
  use tk::normalizer::SplitDelimiterBehavior;
5
6
  use tk::pattern::Pattern;
@@ -13,9 +14,9 @@ pub enum RbPattern<'p> {
13
14
  impl TryConvert for RbPattern<'_> {
14
15
  fn try_convert(obj: Value) -> RbResult<Self> {
15
16
  if obj.is_kind_of(regex()) {
16
- Ok(RbPattern::Regex(obj.try_convert()?))
17
+ Ok(RbPattern::Regex(TryConvert::try_convert(obj)?))
17
18
  } else {
18
- Ok(RbPattern::Str(obj.try_convert()?))
19
+ Ok(RbPattern::Str(TryConvert::try_convert(obj)?))
19
20
  }
20
21
  }
21
22
  }
@@ -61,7 +62,7 @@ pub struct RbSplitDelimiterBehavior(pub SplitDelimiterBehavior);
61
62
 
62
63
  impl TryConvert for RbSplitDelimiterBehavior {
63
64
  fn try_convert(obj: Value) -> RbResult<Self> {
64
- let s = obj.try_convert::<String>()?;
65
+ let s = String::try_convert(obj)?;
65
66
 
66
67
  Ok(Self(match s.as_str() {
67
68
  "removed" => Ok(SplitDelimiterBehavior::Removed),
@@ -1,6 +1,6 @@
1
1
  use onig::Regex;
2
- use magnus::{exception, memoize, Error, Module, RClass};
3
- use crate::{module, RbResult};
2
+ use magnus::{exception, prelude::*, value::Lazy, Error, RClass, Ruby};
3
+ use crate::{RbResult, TOKENIZERS};
4
4
 
5
5
  #[magnus::wrap(class = "Tokenizers::Regex")]
6
6
  pub struct RbRegex {
@@ -17,6 +17,8 @@ impl RbRegex {
17
17
  }
18
18
  }
19
19
 
20
+ static REGEX: Lazy<RClass> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Regex").unwrap());
21
+
20
22
  pub fn regex() -> RClass {
21
- *memoize!(RClass: module().const_get("Regex").unwrap())
23
+ Ruby::get().unwrap().get_inner(&REGEX)
22
24
  }
@@ -0,0 +1,9 @@
1
+ module Tokenizers
2
+ module Decoders
3
+ class Strip
4
+ def self.new(content: " ", start: 0, stop: 0)
5
+ _new(content, start, stop)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -1,7 +1,7 @@
1
1
  module Tokenizers
2
2
  module FromPretrained
3
3
  # for user agent
4
- TOKENIZERS_VERSION = "0.13.2"
4
+ TOKENIZERS_VERSION = "0.14.0"
5
5
 
6
6
  # use Ruby for downloads
7
7
  # this avoids the need to vendor OpenSSL on Linux