tokenizers 0.3.2 → 0.4.0

@@ -2,6 +2,7 @@ use std::cell::RefCell;
 use std::collections::HashMap;
 use std::path::PathBuf;
 
+use magnus::prelude::*;
 use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
 use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy,
@@ -78,7 +79,7 @@ struct TextInputSequence<'s>(tk::InputSequence<'s>);
 
 impl<'s> TryConvert for TextInputSequence<'s> {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        Ok(Self(ob.try_convert::<String>()?.into()))
+        Ok(Self(String::try_convert(ob)?.into()))
     }
 }
 
@@ -92,7 +93,7 @@ struct RbArrayStr(Vec<String>);
 
 impl TryConvert for RbArrayStr {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        let seq = ob.try_convert::<Vec<String>>()?;
+        let seq = <Vec<String>>::try_convert(ob)?;
        Ok(Self(seq))
     }
 }
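
Note on the conversion impls above and below: the recurring change swaps the receiver-style call `ob.try_convert::<T>()` for the trait-style call `T::try_convert(ob)`, which matches how newer magnus versions expose conversion (through the TryConvert trait rather than a method on Value). A minimal standalone sketch of the two trait-style spellings used in this diff (the helper functions are illustrative, not part of the gem):

    use magnus::{Error, TryConvert, Value};

    // Target type named explicitly at the call site.
    fn to_string(val: Value) -> Result<String, Error> {
        String::try_convert(val)
    }

    // Target type left to inference from the return type.
    fn to_len(val: Value) -> Result<usize, Error> {
        TryConvert::try_convert(val)
    }
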
@@ -107,7 +108,7 @@ struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
 
 impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        if let Ok(seq) = ob.try_convert::<RbArrayStr>() {
+        if let Ok(seq) = RbArrayStr::try_convert(ob) {
             return Ok(Self(seq.into()));
         }
         todo!()
@@ -124,14 +125,14 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
 
 impl<'s> TryConvert for TextEncodeInput<'s> {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        if let Ok(i) = ob.try_convert::<TextInputSequence>() {
+        if let Ok(i) = TextInputSequence::try_convert(ob) {
             return Ok(Self(i.into()));
         }
-        if let Ok((i1, i2)) = ob.try_convert::<(TextInputSequence, TextInputSequence)>() {
+        if let Ok((i1, i2)) = <(TextInputSequence, TextInputSequence)>::try_convert(ob) {
             return Ok(Self((i1, i2).into()));
         }
         // TODO check if this branch is needed
-        if let Ok(arr) = ob.try_convert::<RArray>() {
+        if let Ok(arr) = RArray::try_convert(ob) {
             if arr.len() == 2 {
                 let first = arr.entry::<TextInputSequence>(0).unwrap();
                 let second = arr.entry::<TextInputSequence>(1).unwrap();
@@ -155,16 +156,16 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
 
 impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        if let Ok(i) = ob.try_convert::<PreTokenizedInputSequence>() {
+        if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
             return Ok(Self(i.into()));
         }
         if let Ok((i1, i2)) =
-            ob.try_convert::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
+            <(PreTokenizedInputSequence, PreTokenizedInputSequence)>::try_convert(ob)
         {
             return Ok(Self((i1, i2).into()));
         }
         // TODO check if this branch is needed
-        if let Ok(arr) = ob.try_convert::<RArray>() {
+        if let Ok(arr) = RArray::try_convert(ob) {
             if arr.len() == 2 {
                 let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
                 let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
@@ -251,16 +252,16 @@ impl RbTokenizer {
         add_special_tokens: bool,
     ) -> RbResult<RbEncoding> {
         let sequence: tk::InputSequence = if is_pretokenized {
-            sequence.try_convert::<PreTokenizedInputSequence>()?.into()
+            PreTokenizedInputSequence::try_convert(sequence)?.into()
         } else {
-            sequence.try_convert::<TextInputSequence>()?.into()
+            TextInputSequence::try_convert(sequence)?.into()
         };
         let input = match pair {
             Some(pair) => {
                 let pair: tk::InputSequence = if is_pretokenized {
-                    pair.try_convert::<PreTokenizedInputSequence>()?.into()
+                    PreTokenizedInputSequence::try_convert(pair)?.into()
                 } else {
-                    pair.try_convert::<TextInputSequence>()?.into()
+                    TextInputSequence::try_convert(pair)?.into()
                 };
                 tk::EncodeInput::Dual(sequence, pair)
             }
@@ -284,9 +285,9 @@ impl RbTokenizer {
             .each()
             .map(|o| {
                 let input: tk::EncodeInput = if is_pretokenized {
-                    o?.try_convert::<PreTokenizedEncodeInput>()?.into()
+                    PreTokenizedEncodeInput::try_convert(o?)?.into()
                 } else {
-                    o?.try_convert::<TextEncodeInput>()?.into()
+                    TextEncodeInput::try_convert(o?)?.into()
                 };
                 Ok(input)
             })
@@ -306,14 +307,15 @@ impl RbTokenizer {
     pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
         self.tokenizer
             .borrow()
-            .decode(ids, skip_special_tokens)
+            .decode(&ids, skip_special_tokens)
             .map_err(RbError::from)
     }
 
     pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
+        let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         self.tokenizer
             .borrow()
-            .decode_batch(sequences, skip_special_tokens)
+            .decode_batch(&slices, skip_special_tokens)
             .map_err(RbError::from)
     }
 
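Note: decode and decode_batch now hand the underlying tokenizer borrowed ids (`&ids` and a vector of `&[u32]` slices), so the batch path reborrows each inner Vec before the call. The reborrow itself is ordinary Rust and can be sketched in isolation (as_slices is an illustrative name, not part of the gem):

    // Borrow a Vec<Vec<u32>> as slices without copying any ids.
    fn as_slices(sequences: &[Vec<u32>]) -> Vec<&[u32]> {
        sequences.iter().map(|v| &v[..]).collect()
    }

    fn main() {
        let ids = vec![vec![101, 2023, 102], vec![101, 102]];
        let slices = as_slices(&ids);
        assert_eq!(slices.len(), 2);
        assert_eq!(slices[1], &[101, 102][..]);
    }
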
@@ -353,7 +355,7 @@ impl RbTokenizer {
 
         let value: Value = kwargs.delete(Symbol::new("direction"))?;
         if !value.is_nil() {
-            let dir_str: String = value.try_convert()?;
+            let dir_str = String::try_convert(value)?;
             params.direction = match dir_str.as_str() {
                 "left" => PaddingDirection::Left,
                 "right" => PaddingDirection::Right,
@@ -363,29 +365,29 @@ impl RbTokenizer {
 
         let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
         if !value.is_nil() {
-            params.pad_to_multiple_of = value.try_convert()?;
+            params.pad_to_multiple_of = TryConvert::try_convert(value)?;
         }
 
         let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
         if !value.is_nil() {
-            params.pad_id = value.try_convert()?;
+            params.pad_id = TryConvert::try_convert(value)?;
         }
 
         let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
         if !value.is_nil() {
-            params.pad_type_id = value.try_convert()?;
+            params.pad_type_id = TryConvert::try_convert(value)?;
         }
 
         let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
         if !value.is_nil() {
-            params.pad_token = value.try_convert()?;
+            params.pad_token = TryConvert::try_convert(value)?;
         }
 
         let value: Value = kwargs.delete(Symbol::new("length"))?;
         if value.is_nil() {
             params.strategy = PaddingStrategy::BatchLongest;
         } else {
-            params.strategy = PaddingStrategy::Fixed(value.try_convert()?);
+            params.strategy = PaddingStrategy::Fixed(TryConvert::try_convert(value)?);
         }
 
         if !kwargs.is_empty() {
@@ -431,12 +433,12 @@ impl RbTokenizer {
 
         let value: Value = kwargs.delete(Symbol::new("stride"))?;
         if !value.is_nil() {
-            params.stride = value.try_convert()?;
+            params.stride = TryConvert::try_convert(value)?;
         }
 
         let value: Value = kwargs.delete(Symbol::new("strategy"))?;
         if !value.is_nil() {
-            let strategy_str: String = value.try_convert()?;
+            let strategy_str = String::try_convert(value)?;
             params.strategy = match strategy_str.as_str() {
                 "longest_first" => TruncationStrategy::LongestFirst,
                 "only_first" => TruncationStrategy::OnlyFirst,
@@ -447,7 +449,7 @@ impl RbTokenizer {
 
         let value: Value = kwargs.delete(Symbol::new("direction"))?;
         if !value.is_nil() {
-            let dir_str: String = value.try_convert()?;
+            let dir_str = String::try_convert(value)?;
             params.direction = match dir_str.as_str() {
                 "left" => TruncationDirection::Left,
                 "right" => TruncationDirection::Right,
@@ -460,13 +462,18 @@ impl RbTokenizer {
             return Err(Error::new(exception::arg_error(), "unknown keyword"));
         }
 
-        self.tokenizer.borrow_mut().with_truncation(Some(params));
+        if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
+            return Err(Error::new(exception::arg_error(), error_message.to_string()));
+        }
 
         Ok(())
     }
 
     pub fn no_truncation(&self) {
-        self.tokenizer.borrow_mut().with_truncation(None);
+        self.tokenizer
+            .borrow_mut()
+            .with_truncation(None)
+            .expect("Failed to set truncation to `None`! This should never happen");
     }
 
     pub fn truncation(&self) -> RbResult<Option<RHash>> {
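
Note: with_truncation now returns a Result, and the binding surfaces a failure as a Ruby ArgumentError instead of discarding it. The error-to-exception mapping used above could be factored into a small helper; a sketch under that assumption (to_arg_error is an illustrative name, not part of the gem):

    use magnus::{exception, Error};

    // Wrap any displayable Rust error in a Ruby ArgumentError.
    fn to_arg_error<E: std::fmt::Display>(err: E) -> Error {
        Error::new(exception::arg_error(), err.to_string())
    }
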
@@ -3,16 +3,16 @@ use std::sync::{Arc, RwLock};
 
 use crate::models::RbModel;
 use crate::tokenizer::RbAddedToken;
-use magnus::typed_data::DataTypeBuilder;
+use magnus::prelude::*;
 use magnus::{
-    exception, function, memoize, method, Class, DataType, DataTypeFunctions, Error, Module, Object,
-    RArray, RClass, RHash, RModule, Symbol, TypedData, Value,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
+    RArray, RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::models::TrainerWrapper;
 use tk::Trainer;
 
-use super::RbResult;
+use super::{RbResult, TRAINERS};
 
 #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
 pub struct RbTrainer {
@@ -112,7 +112,7 @@ impl RbTrainer {
             special_tokens
                 .each()
                 .map(|token| {
-                    if let Ok(content) = token?.try_convert::<String>() {
+                    if let Ok(content) = String::try_convert(token?) {
                         Ok(RbAddedToken::from(content, Some(true)).get_token())
                     } else {
                         todo!()
@@ -144,7 +144,7 @@ impl RbTrainer {
             self,
             BpeTrainer,
             initial_alphabet,
-            alphabet.into_iter().map(|c| c).collect()
+            alphabet.into_iter().collect()
         );
     }
 
@@ -199,7 +199,7 @@ impl RbTrainer {
             special_tokens
                 .each()
                 .map(|token| {
-                    if let Ok(content) = token?.try_convert::<String>() {
+                    if let Ok(content) = String::try_convert(token?) {
                         Ok(RbAddedToken::from(content, Some(true)).get_token())
                     } else {
                         todo!()
@@ -223,7 +223,7 @@ impl RbTrainer {
             self,
             UnigramTrainer,
             initial_alphabet,
-            alphabet.into_iter().map(|c| c).collect()
+            alphabet.into_iter().collect()
         );
     }
 
@@ -270,7 +270,7 @@ impl RbTrainer {
             special_tokens
                 .each()
                 .map(|token| {
-                    if let Ok(content) = token?.try_convert::<String>() {
+                    if let Ok(content) = String::try_convert(token?) {
                         Ok(RbAddedToken::from(content, Some(true)).get_token())
                     } else {
                         todo!()
@@ -324,7 +324,7 @@ impl RbTrainer {
             special_tokens
                 .each()
                 .map(|token| {
-                    if let Ok(content) = token?.try_convert::<String>() {
+                    if let Ok(content) = String::try_convert(token?) {
                         Ok(RbAddedToken::from(content, Some(true)).get_token())
                     } else {
                         todo!()
@@ -356,7 +356,7 @@ impl RbTrainer {
             self,
             WordPieceTrainer,
             @set_initial_alphabet,
-            alphabet.into_iter().map(|c| c).collect()
+            alphabet.into_iter().collect()
         );
     }
 
@@ -397,11 +397,10 @@ impl RbBpeTrainer {
         let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
         if !value.is_nil() {
             builder = builder.special_tokens(
-                value
-                    .try_convert::<RArray>()?
+                RArray::try_convert(value)?
                     .each()
                     .map(|token| {
-                        if let Ok(content) = token?.try_convert::<String>() {
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -413,39 +412,39 @@ impl RbBpeTrainer {
 
         let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
         if !value.is_nil() {
-            let arr = value.try_convert::<Vec<char>>()?;
+            let arr = <Vec<char>>::try_convert(value)?;
             let set: HashSet<char> = HashSet::from_iter(arr);
             builder = builder.initial_alphabet(set);
         }
 
         let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
         if !value.is_nil() {
-            builder = builder.vocab_size(value.try_convert()?);
+            builder = builder.vocab_size(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
         if !value.is_nil() {
-            builder = builder.min_frequency(value.try_convert()?);
+            builder = builder.min_frequency(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
         if !value.is_nil() {
-            builder = builder.show_progress(value.try_convert()?);
+            builder = builder.show_progress(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
         if !value.is_nil() {
-            builder = builder.limit_alphabet(value.try_convert()?);
+            builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
         if !value.is_nil() {
-            builder = builder.continuing_subword_prefix(value.try_convert()?);
+            builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
         if !value.is_nil() {
-            builder = builder.end_of_word_suffix(value.try_convert()?);
+            builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
         }
 
         if !kwargs.is_empty() {
@@ -466,11 +465,10 @@ impl RbUnigramTrainer {
         let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
         if !value.is_nil() {
             builder.special_tokens(
-                value
-                    .try_convert::<RArray>()?
+                RArray::try_convert(value)?
                     .each()
                     .map(|token| {
-                        if let Ok(content) = token?.try_convert::<String>() {
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -482,44 +480,44 @@ impl RbUnigramTrainer {
 
         let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
         if !value.is_nil() {
-            let arr = value.try_convert::<Vec<char>>()?;
+            let arr = <Vec<char>>::try_convert(value)?;
             let set: HashSet<char> = HashSet::from_iter(arr);
             builder.initial_alphabet(set);
         }
 
         let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
         if !value.is_nil() {
-            builder.vocab_size(value.try_convert()?);
+            builder.vocab_size(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
         if !value.is_nil() {
-            builder.show_progress(value.try_convert()?);
+            builder.show_progress(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("n_sub_iterations"))?;
         if !value.is_nil() {
-            builder.n_sub_iterations(value.try_convert()?);
+            builder.n_sub_iterations(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
         if !value.is_nil() {
-            builder.unk_token(Some(value.try_convert()?));
+            builder.unk_token(Some(TryConvert::try_convert(value)?));
         }
 
         let value: Value = kwargs.delete(Symbol::new("max_piece_length"))?;
         if !value.is_nil() {
-            builder.max_piece_length(value.try_convert()?);
+            builder.max_piece_length(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("seed_size"))?;
         if !value.is_nil() {
-            builder.seed_size(value.try_convert()?);
+            builder.seed_size(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("shrinking_factor"))?;
         if !value.is_nil() {
-            builder.shrinking_factor(value.try_convert()?);
+            builder.shrinking_factor(TryConvert::try_convert(value)?);
         }
 
         if !kwargs.is_empty() {
@@ -541,11 +539,10 @@ impl RbWordLevelTrainer {
         let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
         if !value.is_nil() {
             builder.special_tokens(
-                value
-                    .try_convert::<RArray>()?
+                RArray::try_convert(value)?
                     .each()
                     .map(|token| {
-                        if let Ok(content) = token?.try_convert::<String>() {
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -557,17 +554,17 @@ impl RbWordLevelTrainer {
 
         let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
         if !value.is_nil() {
-            builder.vocab_size(value.try_convert()?);
+            builder.vocab_size(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
         if !value.is_nil() {
-            builder.min_frequency(value.try_convert()?);
+            builder.min_frequency(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
         if !value.is_nil() {
-            builder.show_progress(value.try_convert()?);
+            builder.show_progress(TryConvert::try_convert(value)?);
         }
 
         Ok(builder.build().expect("WordLevelTrainerBuilder cannot fail").into())
@@ -583,11 +580,10 @@ impl RbWordPieceTrainer {
         let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
         if !value.is_nil() {
             builder = builder.special_tokens(
-                value
-                    .try_convert::<RArray>()?
+                RArray::try_convert(value)?
                     .each()
                     .map(|token| {
-                        if let Ok(content) = token?.try_convert::<String>() {
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -599,39 +595,39 @@ impl RbWordPieceTrainer {
 
         let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
         if !value.is_nil() {
-            let arr = value.try_convert::<Vec<char>>()?;
+            let arr = <Vec<char>>::try_convert(value)?;
             let set: HashSet<char> = HashSet::from_iter(arr);
             builder = builder.initial_alphabet(set);
         }
 
         let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
         if !value.is_nil() {
-            builder = builder.vocab_size(value.try_convert()?);
+            builder = builder.vocab_size(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
         if !value.is_nil() {
-            builder = builder.min_frequency(value.try_convert()?);
+            builder = builder.min_frequency(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
         if !value.is_nil() {
-            builder = builder.show_progress(value.try_convert()?);
+            builder = builder.show_progress(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
         if !value.is_nil() {
-            builder = builder.limit_alphabet(value.try_convert()?);
+            builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
         if !value.is_nil() {
-            builder = builder.continuing_subword_prefix(value.try_convert()?);
+            builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
         }
 
         let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
         if !value.is_nil() {
-            builder = builder.end_of_word_suffix(value.try_convert()?);
+            builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
         }
 
         if !kwargs.is_empty() {
@@ -644,46 +640,52 @@ impl RbWordPieceTrainer {
     }
 
 unsafe impl TypedData for RbTrainer {
-    fn class() -> RClass {
-        *memoize!(RClass: {
-            let class: RClass = crate::trainers().const_get("Trainer").unwrap();
-            class.undef_alloc_func();
-            class
-        })
+    fn class(ruby: &Ruby) -> RClass {
+        static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("Trainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
+        ruby.get_inner(&CLASS)
     }
 
     fn data_type() -> &'static DataType {
-        memoize!(DataType: DataTypeBuilder::<RbTrainer>::new("Tokenizers::Trainers::Trainer").build())
-    }
-
-    fn class_for(value: &Self) -> RClass {
+        static DATA_TYPE: DataType = data_type_builder!(RbTrainer, "Tokenizers::Trainers::Trainer").build();
+        &DATA_TYPE
+    }
+
+    fn class_for(ruby: &Ruby, value: &Self) -> RClass {
+        static BPE_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("BpeTrainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
+        static UNIGRAM_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("UnigramTrainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
+        static WORD_LEVEL_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("WordLevelTrainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
+        static WORD_PIECE_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("WordPieceTrainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
         match *value.trainer.read().unwrap() {
-            TrainerWrapper::BpeTrainer(_) => *memoize!(RClass: {
-                let class: RClass = crate::trainers().const_get("BpeTrainer").unwrap();
-                class.undef_alloc_func();
-                class
-            }),
-            TrainerWrapper::UnigramTrainer(_) => *memoize!(RClass: {
-                let class: RClass = crate::trainers().const_get("UnigramTrainer").unwrap();
-                class.undef_alloc_func();
-                class
-            }),
-            TrainerWrapper::WordLevelTrainer(_) => *memoize!(RClass: {
-                let class: RClass = crate::trainers().const_get("WordLevelTrainer").unwrap();
-                class.undef_alloc_func();
-                class
-            }),
-            TrainerWrapper::WordPieceTrainer(_) => *memoize!(RClass: {
-                let class: RClass = crate::trainers().const_get("WordPieceTrainer").unwrap();
-                class.undef_alloc_func();
-                class
-            }),
+            TrainerWrapper::BpeTrainer(_) => ruby.get_inner(&BPE_TRAINER),
+            TrainerWrapper::UnigramTrainer(_) => ruby.get_inner(&UNIGRAM_TRAINER),
+            TrainerWrapper::WordLevelTrainer(_) => ruby.get_inner(&WORD_LEVEL_TRAINER),
+            TrainerWrapper::WordPieceTrainer(_) => ruby.get_inner(&WORD_PIECE_TRAINER),
         }
     }
 }
 
-pub fn trainers(module: &RModule) -> RbResult<()> {
-    let trainer = module.define_class("Trainer", Default::default())?;
+pub fn init_trainers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
+    let trainer = module.define_class("Trainer", ruby.class_object())?;
 
     let class = module.define_class("BpeTrainer", trainer)?;
     class.define_singleton_method("_new", function!(RbBpeTrainer::new, 1))?;
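
Note on the TypedData rewrite above: the old memoize! macro is replaced with magnus's value::Lazy statics, resolved through a &Ruby handle. A minimal sketch of that pattern, built only from calls that already appear in this diff (CACHED_CLASS and cached_class are illustrative names, not part of the gem):

    use magnus::{value::Lazy, RClass, Ruby};

    // Resolved once on first access, then cached for the life of the process.
    static CACHED_CLASS: Lazy<RClass> = Lazy::new(|ruby| ruby.class_object());

    fn cached_class(ruby: &Ruby) -> RClass {
        ruby.get_inner(&CACHED_CLASS)
    }
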
@@ -1,5 +1,6 @@
 use super::regex::{regex, RbRegex};
 use crate::RbResult;
+use magnus::prelude::*;
 use magnus::{exception, Error, TryConvert, Value};
 use tk::normalizer::SplitDelimiterBehavior;
 use tk::pattern::Pattern;
@@ -13,9 +14,9 @@ pub enum RbPattern<'p> {
 impl TryConvert for RbPattern<'_> {
     fn try_convert(obj: Value) -> RbResult<Self> {
         if obj.is_kind_of(regex()) {
-            Ok(RbPattern::Regex(obj.try_convert()?))
+            Ok(RbPattern::Regex(TryConvert::try_convert(obj)?))
         } else {
-            Ok(RbPattern::Str(obj.try_convert()?))
+            Ok(RbPattern::Str(TryConvert::try_convert(obj)?))
         }
     }
 }
@@ -61,7 +62,7 @@ pub struct RbSplitDelimiterBehavior(pub SplitDelimiterBehavior);
 
 impl TryConvert for RbSplitDelimiterBehavior {
     fn try_convert(obj: Value) -> RbResult<Self> {
-        let s = obj.try_convert::<String>()?;
+        let s = String::try_convert(obj)?;
 
         Ok(Self(match s.as_str() {
             "removed" => Ok(SplitDelimiterBehavior::Removed),
@@ -1,6 +1,6 @@
 use onig::Regex;
-use magnus::{exception, memoize, Error, Module, RClass};
-use crate::{module, RbResult};
+use magnus::{exception, prelude::*, value::Lazy, Error, RClass, Ruby};
+use crate::{RbResult, TOKENIZERS};
 
 #[magnus::wrap(class = "Tokenizers::Regex")]
 pub struct RbRegex {
@@ -17,6 +17,8 @@ impl RbRegex {
     }
 }
 
+static REGEX: Lazy<RClass> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Regex").unwrap());
+
 pub fn regex() -> RClass {
-    *memoize!(RClass: module().const_get("Regex").unwrap())
+    Ruby::get().unwrap().get_inner(&REGEX)
 }
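
Note: the Rust files above import module handles (TRAINERS, TOKENIZERS) from elsewhere in the crate; their definitions are not part of this diff. A hypothetical sketch of what such a lazily resolved module handle could look like (TOKENIZERS_MODULE and tokenizers_module are illustrative names, and the define_module call is an assumption, not taken from this diff):

    use magnus::{value::Lazy, RModule, Ruby};

    // Hypothetical: resolve the Tokenizers module once and reuse the handle.
    static TOKENIZERS_MODULE: Lazy<RModule> = Lazy::new(|ruby| {
        ruby.define_module("Tokenizers").unwrap()
    });

    fn tokenizers_module(ruby: &Ruby) -> RModule {
        ruby.get_inner(&TOKENIZERS_MODULE)
    }
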
@@ -0,0 +1,9 @@
+module Tokenizers
+  module Decoders
+    class Strip
+      def self.new(content: " ", start: 0, stop: 0)
+        _new(content, start, stop)
+      end
+    end
+  end
+end
@@ -1,7 +1,7 @@
 module Tokenizers
   module FromPretrained
     # for user agent
-    TOKENIZERS_VERSION = "0.13.2"
+    TOKENIZERS_VERSION = "0.14.0"
 
     # use Ruby for downloads
     # this avoids the need to vendor OpenSSL on Linux