tokenizers 0.6.3 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,10 @@
1
- use std::cell::RefCell;
2
1
  use std::collections::HashMap;
3
2
  use std::path::PathBuf;
4
3
  use std::str::FromStr;
4
+ use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
5
5
 
6
6
  use magnus::prelude::*;
7
- use magnus::{Error, RArray, RHash, RString, Ruby, TryConvert, Value};
7
+ use magnus::{function, method, Error, RArray, RHash, RModule, RString, Ruby, TryConvert, Value};
8
8
  use tk::tokenizer::{
9
9
  Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
10
10
  TruncationParams, TruncationStrategy,
@@ -19,6 +19,7 @@ use super::models::RbModel;
19
19
  use super::normalizers::RbNormalizer;
20
20
  use super::pre_tokenizers::RbPreTokenizer;
21
21
  use super::processors::RbPostProcessor;
22
+ use super::ruby::GvlExt;
22
23
  use super::trainers::RbTrainer;
23
24
  use super::{RbError, RbResult};
24
25
 
@@ -72,7 +73,7 @@ impl From<tk::AddedToken> for RbAddedToken {
72
73
  lstrip: Some(token.lstrip),
73
74
  rstrip: Some(token.rstrip),
74
75
  normalized: Some(token.normalized),
75
- special: !token.normalized,
76
+ special: token.special,
76
77
  }
77
78
  }
78
79
  }
@@ -200,8 +201,8 @@ impl TryConvert for TextEncodeInput<'_> {
200
201
  // TODO check if this branch is needed
201
202
  if let Ok(arr) = RArray::try_convert(ob) {
202
203
  if arr.len() == 2 {
203
- let first = arr.entry::<TextInputSequence>(0).unwrap();
204
- let second = arr.entry::<TextInputSequence>(1).unwrap();
204
+ let first = arr.entry::<TextInputSequence>(0)?;
205
+ let second = arr.entry::<TextInputSequence>(1)?;
205
206
  return Ok(Self((first, second).into()));
206
207
  }
207
208
  }
@@ -235,8 +236,8 @@ impl TryConvert for PreTokenizedEncodeInput<'_> {
235
236
  // TODO check if this branch is needed
236
237
  if let Ok(arr) = RArray::try_convert(ob) {
237
238
  if arr.len() == 2 {
238
- let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
239
- let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
239
+ let first = arr.entry::<PreTokenizedInputSequence>(0)?;
240
+ let second = arr.entry::<PreTokenizedInputSequence>(1)?;
240
241
  return Ok(Self((first, second).into()));
241
242
  }
242
243
  }
@@ -257,202 +258,170 @@ type Tokenizer = TokenizerImpl<RbModel, RbNormalizer, RbPreTokenizer, RbPostProc
257
258
 
258
259
  #[magnus::wrap(class = "Tokenizers::Tokenizer")]
259
260
  pub struct RbTokenizer {
260
- tokenizer: RefCell<Tokenizer>,
261
+ tokenizer: Arc<RwLock<Tokenizer>>,
262
+ }
263
+
264
+ impl Clone for RbTokenizer {
265
+ fn clone(&self) -> Self {
266
+ RbTokenizer {
267
+ tokenizer: Arc::clone(&self.tokenizer),
268
+ }
269
+ }
261
270
  }
262
271
 
263
272
  impl RbTokenizer {
264
273
  pub fn new(tokenizer: Tokenizer) -> Self {
265
274
  Self {
266
- tokenizer: RefCell::new(tokenizer),
275
+ tokenizer: Arc::new(RwLock::new(tokenizer)),
267
276
  }
268
277
  }
269
278
 
279
+ /// Acquire the inner tokenizer for reading; surfaces lock poisoning as a
280
+ /// Ruby exception (via `RbError`) instead of panicking.
281
+ pub(crate) fn read_inner(&self) -> RbResult<RwLockReadGuard<'_, Tokenizer>> {
282
+ self.tokenizer
283
+ .read()
284
+ .map_err(|_| RbError::new_err("Tokenizer RwLock is poisoned"))
285
+ }
286
+
287
+ /// Acquire the inner tokenizer for writing; surfaces lock poisoning as a
288
+ /// Ruby exception (via `RbError`) instead of panicking.
289
+ pub(crate) fn write_inner(&self) -> RbResult<RwLockWriteGuard<'_, Tokenizer>> {
290
+ self.tokenizer
291
+ .write()
292
+ .map_err(|_| RbError::new_err("Tokenizer RwLock is poisoned"))
293
+ }
294
+
270
295
  pub fn from_model(model: &RbModel) -> Self {
271
296
  RbTokenizer::new(TokenizerImpl::new(model.clone()))
272
297
  }
273
298
 
274
299
  pub fn from_str(json: RString) -> RbResult<Self> {
275
- Tokenizer::from_str(unsafe { json.as_str()? })
276
- .map(|v| RbTokenizer {
277
- tokenizer: RefCell::new(v),
278
- })
279
- .map_err(RbError::from)
300
+ let tokenizer = Tokenizer::from_str(unsafe { json.as_str()? }).map_err(RbError::from);
301
+ Ok(Self::new(tokenizer?))
280
302
  }
281
303
 
282
304
  pub fn from_file(path: PathBuf) -> RbResult<Self> {
283
- Tokenizer::from_file(path)
284
- .map(|v| RbTokenizer {
285
- tokenizer: RefCell::new(v),
286
- })
287
- .map_err(RbError::from)
305
+ let tokenizer = Tokenizer::from_file(path).map_err(RbError::from);
306
+ Ok(Self::new(tokenizer?))
288
307
  }
289
308
 
290
309
  pub fn to_str(&self, pretty: bool) -> RbResult<String> {
291
- self.tokenizer
292
- .borrow()
293
- .to_string(pretty)
294
- .map_err(RbError::from)
295
- }
296
-
297
- pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
298
- let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
299
- self.tokenizer.borrow_mut().add_special_tokens(&tokens)
300
- }
301
-
302
- pub fn train(&self, files: Vec<String>, trainer: Option<&RbTrainer>) -> RbResult<()> {
303
- let mut trainer = trainer.map_or_else(
304
- || self.tokenizer.borrow().get_model().get_trainer(),
305
- |t| t.clone(),
306
- );
307
- self.tokenizer
308
- .borrow_mut()
309
- .train_from_files(&mut trainer, files)
310
- .map(|_| {})
311
- .map_err(RbError::from)
310
+ self.read_inner()?.to_string(pretty).map_err(RbError::from)
312
311
  }
313
312
 
314
313
  pub fn save(&self, path: String, pretty: bool) -> RbResult<()> {
315
- self.tokenizer
316
- .borrow()
314
+ self.read_inner()?
317
315
  .save(&path, pretty)
318
316
  .map_err(RbError::from)
319
317
  }
320
318
 
321
- pub fn add_tokens(&self, tokens: Vec<String>) -> usize {
322
- let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
323
- self.tokenizer.borrow_mut().add_tokens(&tokens)
319
+ pub fn num_special_tokens_to_add(&self, is_pair: bool) -> RbResult<usize> {
320
+ Ok(self
321
+ .read_inner()?
322
+ .get_post_processor()
323
+ .map_or(0, |p| p.added_tokens(is_pair)))
324
324
  }
325
325
 
326
- pub fn encode(
327
- &self,
328
- sequence: Value,
329
- pair: Option<Value>,
330
- is_pretokenized: bool,
331
- add_special_tokens: bool,
332
- ) -> RbResult<RbEncoding> {
333
- let sequence: tk::InputSequence = if is_pretokenized {
334
- PreTokenizedInputSequence::try_convert(sequence)?.into()
335
- } else {
336
- TextInputSequence::try_convert(sequence)?.into()
337
- };
338
- let input = match pair {
339
- Some(pair) => {
340
- let pair: tk::InputSequence = if is_pretokenized {
341
- PreTokenizedInputSequence::try_convert(pair)?.into()
342
- } else {
343
- TextInputSequence::try_convert(pair)?.into()
344
- };
345
- tk::EncodeInput::Dual(sequence, pair)
346
- }
347
- None => tk::EncodeInput::Single(sequence),
348
- };
349
-
350
- self.tokenizer
351
- .borrow()
352
- .encode_char_offsets(input, add_special_tokens)
353
- .map(|v| RbEncoding { encoding: v })
354
- .map_err(RbError::from)
326
+ pub fn get_vocab(&self, with_added_tokens: bool) -> RbResult<HashMap<String, u32>> {
327
+ Ok(self.read_inner()?.get_vocab(with_added_tokens))
355
328
  }
356
329
 
357
- pub fn encode_batch(
358
- ruby: &Ruby,
359
- rb_self: &Self,
360
- input: RArray,
361
- is_pretokenized: bool,
362
- add_special_tokens: bool,
363
- ) -> RbResult<RArray> {
364
- let input: Vec<tk::EncodeInput> = input
365
- .into_iter()
366
- .map(|o| {
367
- let input: tk::EncodeInput = if is_pretokenized {
368
- PreTokenizedEncodeInput::try_convert(o)?.into()
369
- } else {
370
- TextEncodeInput::try_convert(o)?.into()
371
- };
372
- Ok(input)
373
- })
374
- .collect::<RbResult<Vec<tk::EncodeInput>>>()?;
375
- rb_self
376
- .tokenizer
377
- .borrow()
378
- .encode_batch_char_offsets(input, add_special_tokens)
379
- .map(|encodings| {
380
- ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into))
381
- })
382
- .map_err(RbError::from)
383
- }
330
+ pub fn get_added_tokens_decoder(ruby: &Ruby, rb_self: &Self) -> RbResult<RHash> {
331
+ let sorted_map = ruby.hash_new();
384
332
 
385
- pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
386
- self.tokenizer
387
- .borrow()
388
- .decode(&ids, skip_special_tokens)
389
- .map_err(RbError::from)
390
- }
333
+ for (key, value) in rb_self.read_inner()?.get_added_tokens_decoder() {
334
+ sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
335
+ }
391
336
 
392
- pub fn decode_batch(
393
- &self,
394
- sequences: Vec<Vec<u32>>,
395
- skip_special_tokens: bool,
396
- ) -> RbResult<Vec<String>> {
397
- let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
398
- self.tokenizer
399
- .borrow()
400
- .decode_batch(&slices, skip_special_tokens)
401
- .map_err(RbError::from)
337
+ Ok(sorted_map)
402
338
  }
403
339
 
404
- pub fn get_model(&self) -> RbModel {
405
- self.tokenizer.borrow().get_model().clone()
340
+ pub fn get_vocab_size(&self, with_added_tokens: bool) -> RbResult<usize> {
341
+ Ok(self.read_inner()?.get_vocab_size(with_added_tokens))
406
342
  }
407
343
 
408
- pub fn set_model(&self, model: &RbModel) {
409
- self.tokenizer.borrow_mut().with_model(model.clone());
410
- }
344
+ pub fn enable_truncation(
345
+ ruby: &Ruby,
346
+ rb_self: &Self,
347
+ max_length: usize,
348
+ kwargs: RHash,
349
+ ) -> RbResult<()> {
350
+ let mut params = TruncationParams {
351
+ max_length,
352
+ ..Default::default()
353
+ };
411
354
 
412
- pub fn get_decoder(&self) -> Option<RbDecoder> {
413
- self.tokenizer.borrow().get_decoder().cloned()
414
- }
355
+ let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
356
+ if !value.is_nil() {
357
+ params.stride = TryConvert::try_convert(value)?;
358
+ }
415
359
 
416
- pub fn set_decoder(&self, decoder: Option<&RbDecoder>) {
417
- self.tokenizer.borrow_mut().with_decoder(decoder.cloned());
418
- }
360
+ let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
361
+ if !value.is_nil() {
362
+ let strategy_str = String::try_convert(value)?;
363
+ params.strategy = match strategy_str.as_str() {
364
+ "longest_first" => TruncationStrategy::LongestFirst,
365
+ "only_first" => TruncationStrategy::OnlyFirst,
366
+ "only_second" => TruncationStrategy::OnlySecond,
367
+ _ => return Err(Error::new(
368
+ ruby.exception_arg_error(),
369
+ "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
370
+ )),
371
+ }
372
+ }
419
373
 
420
- pub fn get_pre_tokenizer(&self) -> Option<RbPreTokenizer> {
421
- self.tokenizer.borrow().get_pre_tokenizer().cloned()
422
- }
374
+ let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
375
+ if !value.is_nil() {
376
+ let dir_str = String::try_convert(value)?;
377
+ params.direction = match dir_str.as_str() {
378
+ "left" => TruncationDirection::Left,
379
+ "right" => TruncationDirection::Right,
380
+ _ => {
381
+ return Err(Error::new(
382
+ ruby.exception_arg_error(),
383
+ "The direction value must be 'left' or 'right'",
384
+ ))
385
+ }
386
+ }
387
+ }
423
388
 
424
- pub fn set_pre_tokenizer(&self, pretok: Option<&RbPreTokenizer>) {
425
- self.tokenizer
426
- .borrow_mut()
427
- .with_pre_tokenizer(pretok.cloned());
428
- }
389
+ if !kwargs.is_empty() {
390
+ // TODO improve message
391
+ return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
392
+ }
429
393
 
430
- pub fn get_post_processor(&self) -> Option<RbPostProcessor> {
431
- self.tokenizer.borrow().get_post_processor().cloned()
432
- }
394
+ if let Err(error_message) = rb_self.write_inner()?.with_truncation(Some(params)) {
395
+ return Err(Error::new(
396
+ ruby.exception_arg_error(),
397
+ error_message.to_string(),
398
+ ));
399
+ }
433
400
 
434
- pub fn set_post_processor(&self, processor: Option<&RbPostProcessor>) {
435
- self.tokenizer
436
- .borrow_mut()
437
- .with_post_processor(processor.cloned());
401
+ Ok(())
438
402
  }
439
403
 
440
- pub fn get_normalizer(&self) -> Option<RbNormalizer> {
441
- self.tokenizer.borrow().get_normalizer().cloned()
404
+ pub fn no_truncation(&self) -> RbResult<()> {
405
+ self.write_inner()?
406
+ .with_truncation(None)
407
+ .expect("Failed to set truncation to `None`! This should never happen");
408
+ Ok(())
442
409
  }
443
410
 
444
- pub fn set_normalizer(&self, normalizer: Option<&RbNormalizer>) {
445
- self.tokenizer
446
- .borrow_mut()
447
- .with_normalizer(normalizer.cloned());
448
- }
411
+ pub fn get_truncation(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
412
+ rb_self
413
+ .read_inner()?
414
+ .get_truncation()
415
+ .map_or(Ok(None), |params| {
416
+ let ret_hash = ruby.hash_new();
449
417
 
450
- pub fn token_to_id(&self, token: String) -> Option<u32> {
451
- self.tokenizer.borrow().token_to_id(&token)
452
- }
418
+ ret_hash.aset("max_length", params.max_length)?;
419
+ ret_hash.aset("stride", params.stride)?;
420
+ ret_hash.aset("strategy", params.strategy.as_ref())?;
421
+ ret_hash.aset("direction", params.direction.as_ref())?;
453
422
 
454
- pub fn id_to_token(&self, id: u32) -> Option<String> {
455
- self.tokenizer.borrow().id_to_token(id)
423
+ Ok(Some(ret_hash))
424
+ })
456
425
  }
457
426
 
458
427
  // TODO support more kwargs
@@ -506,19 +475,19 @@ impl RbTokenizer {
506
475
  return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
507
476
  }
508
477
 
509
- rb_self.tokenizer.borrow_mut().with_padding(Some(params));
478
+ rb_self.write_inner()?.with_padding(Some(params));
510
479
 
511
480
  Ok(())
512
481
  }
513
482
 
514
- pub fn no_padding(&self) {
515
- self.tokenizer.borrow_mut().with_padding(None);
483
+ pub fn no_padding(&self) -> RbResult<()> {
484
+ self.write_inner()?.with_padding(None);
485
+ Ok(())
516
486
  }
517
487
 
518
- pub fn padding(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
488
+ pub fn get_padding(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
519
489
  rb_self
520
- .tokenizer
521
- .borrow()
490
+ .read_inner()?
522
491
  .get_padding()
523
492
  .map_or(Ok(None), |params| {
524
493
  let ret_hash = ruby.hash_new();
@@ -540,112 +509,280 @@ impl RbTokenizer {
540
509
  })
541
510
  }
542
511
 
543
- pub fn enable_truncation(
544
- ruby: &Ruby,
545
- rb_self: &Self,
546
- max_length: usize,
547
- kwargs: RHash,
548
- ) -> RbResult<()> {
549
- let mut params = TruncationParams {
550
- max_length,
551
- ..Default::default()
512
+ pub fn encode(
513
+ &self,
514
+ sequence: Value,
515
+ pair: Option<Value>,
516
+ is_pretokenized: bool,
517
+ add_special_tokens: bool,
518
+ ) -> RbResult<RbEncoding> {
519
+ let sequence: tk::InputSequence = if is_pretokenized {
520
+ PreTokenizedInputSequence::try_convert(sequence)?.into()
521
+ } else {
522
+ TextInputSequence::try_convert(sequence)?.into()
523
+ };
524
+ let input = match pair {
525
+ Some(pair) => {
526
+ let pair: tk::InputSequence = if is_pretokenized {
527
+ PreTokenizedInputSequence::try_convert(pair)?.into()
528
+ } else {
529
+ TextInputSequence::try_convert(pair)?.into()
530
+ };
531
+ tk::EncodeInput::Dual(sequence, pair)
532
+ }
533
+ None => tk::EncodeInput::Single(sequence),
552
534
  };
553
535
 
554
- let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
555
- if !value.is_nil() {
556
- params.stride = TryConvert::try_convert(value)?;
557
- }
536
+ self.read_inner()?
537
+ .encode_char_offsets(input, add_special_tokens)
538
+ .map(|v| RbEncoding { encoding: v })
539
+ .map_err(RbError::from)
540
+ }
558
541
 
559
- let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
560
- if !value.is_nil() {
561
- let strategy_str = String::try_convert(value)?;
562
- params.strategy = match strategy_str.as_str() {
563
- "longest_first" => TruncationStrategy::LongestFirst,
564
- "only_first" => TruncationStrategy::OnlyFirst,
565
- "only_second" => TruncationStrategy::OnlySecond,
566
- _ => return Err(Error::new(
567
- ruby.exception_arg_error(),
568
- "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
569
- )),
570
- }
571
- }
542
+ pub fn encode_batch(
543
+ ruby: &Ruby,
544
+ rb_self: &Self,
545
+ input: RArray,
546
+ is_pretokenized: bool,
547
+ add_special_tokens: bool,
548
+ ) -> RbResult<RArray> {
549
+ let input: Vec<tk::EncodeInput> = input
550
+ .into_iter()
551
+ .map(|o| {
552
+ let input: tk::EncodeInput = if is_pretokenized {
553
+ PreTokenizedEncodeInput::try_convert(o)?.into()
554
+ } else {
555
+ TextEncodeInput::try_convert(o)?.into()
556
+ };
557
+ Ok(input)
558
+ })
559
+ .collect::<RbResult<Vec<tk::EncodeInput>>>()?;
560
+ ruby.detach(|| {
561
+ rb_self
562
+ .tokenizer
563
+ .read()
564
+ .unwrap()
565
+ .encode_batch_char_offsets(input, add_special_tokens)
566
+ })
567
+ .map(|encodings| ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into)))
568
+ .map_err(RbError::from)
569
+ }
570
+
571
+ pub fn encode_batch_fast(
572
+ ruby: &Ruby,
573
+ rb_self: &Self,
574
+ input: RArray,
575
+ is_pretokenized: bool,
576
+ add_special_tokens: bool,
577
+ ) -> RbResult<RArray> {
578
+ let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
579
+ for item in input {
580
+ let item: tk::EncodeInput = if is_pretokenized {
581
+ PreTokenizedEncodeInput::try_convert(item)?.into()
582
+ } else {
583
+ TextEncodeInput::try_convert(item)?.into()
584
+ };
585
+ items.push(item);
586
+ }
587
+ ruby.detach(|| {
588
+ rb_self
589
+ .tokenizer
590
+ .read()
591
+ .unwrap()
592
+ .encode_batch_fast(items, add_special_tokens)
593
+ })
594
+ .map(|encodings| ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into)))
595
+ .map_err(RbError::from)
596
+ }
572
597
 
573
- let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
574
- if !value.is_nil() {
575
- let dir_str = String::try_convert(value)?;
576
- params.direction = match dir_str.as_str() {
577
- "left" => TruncationDirection::Left,
578
- "right" => TruncationDirection::Right,
579
- _ => {
580
- return Err(Error::new(
581
- ruby.exception_arg_error(),
582
- "The direction value must be 'left' or 'right'",
583
- ))
584
- }
585
- }
586
- }
598
+ pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
599
+ self.read_inner()?
600
+ .decode(&ids, skip_special_tokens)
601
+ .map_err(RbError::from)
602
+ }
587
603
 
588
- if !kwargs.is_empty() {
589
- // TODO improve message
590
- return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
591
- }
604
+ pub fn decode_batch(
605
+ ruby: &Ruby,
606
+ rb_self: &Self,
607
+ sequences: Vec<Vec<u32>>,
608
+ skip_special_tokens: bool,
609
+ ) -> RbResult<Vec<String>> {
610
+ ruby.detach(|| {
611
+ let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
612
+ rb_self
613
+ .tokenizer
614
+ .read()
615
+ .unwrap()
616
+ .decode_batch(&slices, skip_special_tokens)
617
+ })
618
+ .map_err(RbError::from)
619
+ }
592
620
 
593
- if let Err(error_message) = rb_self.tokenizer.borrow_mut().with_truncation(Some(params)) {
594
- return Err(Error::new(
595
- ruby.exception_arg_error(),
596
- error_message.to_string(),
597
- ));
598
- }
621
+ pub fn token_to_id(&self, token: String) -> RbResult<Option<u32>> {
622
+ Ok(self.read_inner()?.token_to_id(&token))
623
+ }
624
+
625
+ pub fn id_to_token(&self, id: u32) -> RbResult<Option<String>> {
626
+ Ok(self.read_inner()?.id_to_token(id))
627
+ }
599
628
 
629
+ pub fn set_encode_special_tokens(&self, value: bool) -> RbResult<()> {
630
+ self.write_inner()?.set_encode_special_tokens(value);
600
631
  Ok(())
601
632
  }
602
633
 
603
- pub fn no_truncation(&self) {
604
- self.tokenizer
605
- .borrow_mut()
606
- .with_truncation(None)
607
- .expect("Failed to set truncation to `None`! This should never happen");
634
+ pub fn get_encode_special_tokens(&self) -> RbResult<bool> {
635
+ Ok(self.read_inner()?.get_encode_special_tokens())
608
636
  }
609
637
 
610
- pub fn truncation(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
611
- rb_self
612
- .tokenizer
613
- .borrow()
614
- .get_truncation()
615
- .map_or(Ok(None), |params| {
616
- let ret_hash = ruby.hash_new();
638
+ pub fn add_tokens(&self, tokens: Vec<String>) -> RbResult<usize> {
639
+ let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
640
+ self.write_inner()?
641
+ .add_tokens(tokens)
642
+ .map_err(RbError::from)
643
+ }
617
644
 
618
- ret_hash.aset("max_length", params.max_length)?;
619
- ret_hash.aset("stride", params.stride)?;
620
- ret_hash.aset("strategy", params.strategy.as_ref())?;
621
- ret_hash.aset("direction", params.direction.as_ref())?;
645
+ pub fn add_special_tokens(&self, tokens: Vec<String>) -> RbResult<usize> {
646
+ let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
647
+ self.write_inner()?
648
+ .add_special_tokens(tokens)
649
+ .map_err(RbError::from)
650
+ }
622
651
 
623
- Ok(Some(ret_hash))
624
- })
652
+ pub fn train(&self, files: Vec<String>, trainer: Option<&RbTrainer>) -> RbResult<()> {
653
+ let mut trainer = match trainer {
654
+ Some(t) => t.clone(),
655
+ None => self.read_inner()?.get_model().get_trainer(),
656
+ };
657
+ self.write_inner()?
658
+ .train_from_files(&mut trainer, files)
659
+ .map(|_| {})
660
+ .map_err(RbError::from)
625
661
  }
626
662
 
627
- pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
628
- self.tokenizer
629
- .borrow()
630
- .get_post_processor()
631
- .map_or(0, |p| p.added_tokens(is_pair))
663
+ pub fn get_model(&self) -> RbResult<RbModel> {
664
+ Ok(self.read_inner()?.get_model().clone())
632
665
  }
633
666
 
634
- pub fn vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
635
- self.tokenizer.borrow().get_vocab(with_added_tokens)
667
+ pub fn set_model(&self, model: &RbModel) -> RbResult<()> {
668
+ self.write_inner()?.with_model(model.clone());
669
+ Ok(())
636
670
  }
637
671
 
638
- pub fn vocab_size(&self, with_added_tokens: bool) -> usize {
639
- self.tokenizer.borrow().get_vocab_size(with_added_tokens)
672
+ pub fn get_normalizer(&self) -> RbResult<Option<RbNormalizer>> {
673
+ Ok(self.read_inner()?.get_normalizer().cloned())
640
674
  }
641
675
 
642
- pub fn get_added_tokens_decoder(ruby: &Ruby, rb_self: &Self) -> RbResult<RHash> {
643
- let sorted_map = ruby.hash_new();
676
+ pub fn set_normalizer(&self, normalizer: Option<&RbNormalizer>) -> RbResult<()> {
677
+ self.write_inner()?
678
+ .with_normalizer(normalizer.cloned())
679
+ .map(|_| ())
680
+ .map_err(RbError::from)
681
+ }
644
682
 
645
- for (key, value) in rb_self.tokenizer.borrow().get_added_tokens_decoder() {
646
- sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
647
- }
683
+ pub fn get_pre_tokenizer(&self) -> RbResult<Option<RbPreTokenizer>> {
684
+ Ok(self.read_inner()?.get_pre_tokenizer().cloned())
685
+ }
648
686
 
649
- Ok(sorted_map)
687
+ pub fn set_pre_tokenizer(&self, pretok: Option<&RbPreTokenizer>) -> RbResult<()> {
688
+ self.write_inner()?.with_pre_tokenizer(pretok.cloned());
689
+ Ok(())
690
+ }
691
+
692
+ pub fn get_post_processor(&self) -> RbResult<Option<RbPostProcessor>> {
693
+ Ok(self.read_inner()?.get_post_processor().cloned())
650
694
  }
695
+
696
+ pub fn set_post_processor(&self, processor: Option<&RbPostProcessor>) -> RbResult<()> {
697
+ self.write_inner()?.with_post_processor(processor.cloned());
698
+ Ok(())
699
+ }
700
+
701
+ pub fn get_decoder(&self) -> RbResult<Option<RbDecoder>> {
702
+ Ok(self.read_inner()?.get_decoder().cloned())
703
+ }
704
+
705
+ pub fn set_decoder(&self, decoder: Option<&RbDecoder>) -> RbResult<()> {
706
+ self.write_inner()?.with_decoder(decoder.cloned());
707
+ Ok(())
708
+ }
709
+ }
710
+
711
+ pub fn init_tokenizer(ruby: &Ruby, module: &RModule) -> RbResult<()> {
712
+ let class = module.define_class("Tokenizer", ruby.class_object())?;
713
+ class.define_singleton_method("new", function!(RbTokenizer::from_model, 1))?;
714
+ class.define_singleton_method("from_str", function!(RbTokenizer::from_str, 1))?;
715
+ class.define_singleton_method("from_file", function!(RbTokenizer::from_file, 1))?;
716
+ class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;
717
+ class.define_method("_save", method!(RbTokenizer::save, 2))?;
718
+ class.define_method(
719
+ "num_special_tokens_to_add",
720
+ method!(RbTokenizer::num_special_tokens_to_add, 1),
721
+ )?;
722
+ class.define_method("_vocab", method!(RbTokenizer::get_vocab, 1))?;
723
+ class.define_method(
724
+ "added_tokens_decoder",
725
+ method!(RbTokenizer::get_added_tokens_decoder, 0),
726
+ )?;
727
+ class.define_method("_vocab_size", method!(RbTokenizer::get_vocab_size, 1))?;
728
+ class.define_method(
729
+ "_enable_truncation",
730
+ method!(RbTokenizer::enable_truncation, 2),
731
+ )?;
732
+ class.define_method("no_truncation", method!(RbTokenizer::no_truncation, 0))?;
733
+ class.define_method("truncation", method!(RbTokenizer::get_truncation, 0))?;
734
+ class.define_method("_enable_padding", method!(RbTokenizer::enable_padding, 1))?;
735
+ class.define_method("no_padding", method!(RbTokenizer::no_padding, 0))?;
736
+ class.define_method("padding", method!(RbTokenizer::get_padding, 0))?;
737
+ class.define_method("_encode", method!(RbTokenizer::encode, 4))?;
738
+ class.define_method("_encode_batch", method!(RbTokenizer::encode_batch, 3))?;
739
+ class.define_method(
740
+ "_encode_batch_fast",
741
+ method!(RbTokenizer::encode_batch_fast, 3),
742
+ )?;
743
+ class.define_method("_decode", method!(RbTokenizer::decode, 2))?;
744
+ class.define_method("_decode_batch", method!(RbTokenizer::decode_batch, 2))?;
745
+ class.define_method("token_to_id", method!(RbTokenizer::token_to_id, 1))?;
746
+ class.define_method("id_to_token", method!(RbTokenizer::id_to_token, 1))?;
747
+ class.define_method(
748
+ "encode_special_tokens=",
749
+ method!(RbTokenizer::set_encode_special_tokens, 1),
750
+ )?;
751
+ class.define_method(
752
+ "encode_special_tokens",
753
+ method!(RbTokenizer::get_encode_special_tokens, 0),
754
+ )?;
755
+ class.define_method("add_tokens", method!(RbTokenizer::add_tokens, 1))?;
756
+ class.define_method(
757
+ "add_special_tokens",
758
+ method!(RbTokenizer::add_special_tokens, 1),
759
+ )?;
760
+ class.define_method("train", method!(RbTokenizer::train, 2))?;
761
+ class.define_method("model", method!(RbTokenizer::get_model, 0))?;
762
+ class.define_method("model=", method!(RbTokenizer::set_model, 1))?;
763
+ class.define_method("normalizer", method!(RbTokenizer::get_normalizer, 0))?;
764
+ class.define_method("normalizer=", method!(RbTokenizer::set_normalizer, 1))?;
765
+ class.define_method("pre_tokenizer", method!(RbTokenizer::get_pre_tokenizer, 0))?;
766
+ class.define_method("pre_tokenizer=", method!(RbTokenizer::set_pre_tokenizer, 1))?;
767
+ class.define_method(
768
+ "post_processor",
769
+ method!(RbTokenizer::get_post_processor, 0),
770
+ )?;
771
+ class.define_method(
772
+ "post_processor=",
773
+ method!(RbTokenizer::set_post_processor, 1),
774
+ )?;
775
+ class.define_method("decoder", method!(RbTokenizer::get_decoder, 0))?;
776
+ class.define_method("decoder=", method!(RbTokenizer::set_decoder, 1))?;
777
+
778
+ let class = module.define_class("AddedToken", ruby.class_object())?;
779
+ class.define_singleton_method("_new", function!(RbAddedToken::new, 2))?;
780
+ class.define_method("content", method!(RbAddedToken::get_content, 0))?;
781
+ class.define_method("rstrip", method!(RbAddedToken::get_rstrip, 0))?;
782
+ class.define_method("lstrip", method!(RbAddedToken::get_lstrip, 0))?;
783
+ class.define_method("single_word", method!(RbAddedToken::get_single_word, 0))?;
784
+ class.define_method("normalized", method!(RbAddedToken::get_normalized, 0))?;
785
+ class.define_method("special", method!(RbAddedToken::get_special, 0))?;
786
+
787
+ Ok(())
651
788
  }