tokenizers 0.3.3 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,8 @@
1
1
  use std::sync::Arc;
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
- TryConvert, TypedData, Value,
4
+ data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
5
+ Ruby, TryConvert, TypedData, Value,
7
6
  };
8
7
  use serde::{Deserialize, Serialize};
9
8
  use tk::processors::bert::BertProcessing;
@@ -13,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
13
12
  use tk::processors::PostProcessorWrapper;
14
13
  use tk::{Encoding, PostProcessor};
15
14
 
16
- use super::RbResult;
15
+ use super::{PROCESSORS, RbResult};
17
16
 
18
17
  #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
19
18
  pub struct RbPostProcessor {
@@ -53,9 +52,9 @@ impl From<RbSpecialToken> for SpecialToken {
53
52
 
54
53
  impl TryConvert for RbSpecialToken {
55
54
  fn try_convert(ob: Value) -> RbResult<Self> {
56
- if let Ok(v) = ob.try_convert::<(String, u32)>() {
55
+ if let Ok(v) = <(String, u32)>::try_convert(ob) {
57
56
  Ok(Self(v.into()))
58
- } else if let Ok(v) = ob.try_convert::<(u32, String)>() {
57
+ } else if let Ok(v) = <(u32, String)>::try_convert(ob) {
59
58
  Ok(Self(v.into()))
60
59
  } else {
61
60
  todo!()
@@ -74,11 +73,11 @@ impl From<RbTemplate> for Template {
74
73
 
75
74
  impl TryConvert for RbTemplate {
76
75
  fn try_convert(ob: Value) -> RbResult<Self> {
77
- if let Ok(s) = ob.try_convert::<String>() {
76
+ if let Ok(s) = String::try_convert(ob) {
78
77
  Ok(Self(
79
78
  s.try_into().unwrap(), //.map_err(RbError::from)?,
80
79
  ))
81
- } else if let Ok(s) = ob.try_convert::<Vec<String>>() {
80
+ } else if let Ok(s) = <Vec<String>>::try_convert(ob) {
82
81
  Ok(Self(
83
82
  s.try_into().unwrap(), //.map_err(RbError::from)?,
84
83
  ))
@@ -152,47 +151,53 @@ impl RbTemplateProcessing {
152
151
  }
153
152
 
154
153
  unsafe impl TypedData for RbPostProcessor {
155
- fn class() -> RClass {
156
- *memoize!(RClass: {
157
- let class: RClass = crate::processors().const_get("PostProcessor").unwrap();
158
- class.undef_alloc_func();
159
- class
160
- })
154
+ fn class(ruby: &Ruby) -> RClass {
155
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
156
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
157
+ class.undef_default_alloc_func();
158
+ class
159
+ });
160
+ ruby.get_inner(&CLASS)
161
161
  }
162
162
 
163
163
  fn data_type() -> &'static DataType {
164
- memoize!(DataType: DataTypeBuilder::<RbPostProcessor>::new("Tokenizers::Processors::PostProcessor").build())
164
+ static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
165
+ &DATA_TYPE
165
166
  }
166
167
 
167
- fn class_for(value: &Self) -> RClass {
168
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
169
+ static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
170
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
171
+ class.undef_default_alloc_func();
172
+ class
173
+ });
174
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
175
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("ByteLevel").unwrap();
176
+ class.undef_default_alloc_func();
177
+ class
178
+ });
179
+ static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
180
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
181
+ class.undef_default_alloc_func();
182
+ class
183
+ });
184
+ static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
185
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
186
+ class.undef_default_alloc_func();
187
+ class
188
+ });
168
189
  match *value.processor {
169
- PostProcessorWrapper::Bert(_) => *memoize!(RClass: {
170
- let class: RClass = crate::processors().const_get("BertProcessing").unwrap();
171
- class.undef_alloc_func();
172
- class
173
- }),
174
- PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
175
- let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
176
- class.undef_alloc_func();
177
- class
178
- }),
179
- PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
180
- let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
181
- class.undef_alloc_func();
182
- class
183
- }),
184
- PostProcessorWrapper::Template(_) => *memoize!(RClass: {
185
- let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
186
- class.undef_alloc_func();
187
- class
188
- }),
190
+ PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
191
+ PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
192
+ PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
193
+ PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
189
194
  _ => todo!(),
190
195
  }
191
196
  }
192
197
  }
193
198
 
194
- pub fn processors(module: &RModule) -> RbResult<()> {
195
- let post_processor = module.define_class("PostProcessor", Default::default())?;
199
+ pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
200
+ let post_processor = module.define_class("PostProcessor", ruby.class_object())?;
196
201
 
197
202
  let class = module.define_class("BertProcessing", post_processor)?;
198
203
  class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
@@ -2,6 +2,7 @@ use std::cell::RefCell;
2
2
  use std::collections::HashMap;
3
3
  use std::path::PathBuf;
4
4
 
5
+ use magnus::prelude::*;
5
6
  use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
6
7
  use tk::tokenizer::{
7
8
  Model, PaddingDirection, PaddingParams, PaddingStrategy,
@@ -78,7 +79,7 @@ struct TextInputSequence<'s>(tk::InputSequence<'s>);
78
79
 
79
80
  impl<'s> TryConvert for TextInputSequence<'s> {
80
81
  fn try_convert(ob: Value) -> RbResult<Self> {
81
- Ok(Self(ob.try_convert::<String>()?.into()))
82
+ Ok(Self(String::try_convert(ob)?.into()))
82
83
  }
83
84
  }
84
85
 
@@ -92,7 +93,7 @@ struct RbArrayStr(Vec<String>);
92
93
 
93
94
  impl TryConvert for RbArrayStr {
94
95
  fn try_convert(ob: Value) -> RbResult<Self> {
95
- let seq = ob.try_convert::<Vec<String>>()?;
96
+ let seq = <Vec<String>>::try_convert(ob)?;
96
97
  Ok(Self(seq))
97
98
  }
98
99
  }
@@ -107,7 +108,7 @@ struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
107
108
 
108
109
  impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
109
110
  fn try_convert(ob: Value) -> RbResult<Self> {
110
- if let Ok(seq) = ob.try_convert::<RbArrayStr>() {
111
+ if let Ok(seq) = RbArrayStr::try_convert(ob) {
111
112
  return Ok(Self(seq.into()));
112
113
  }
113
114
  todo!()
@@ -124,14 +125,14 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
124
125
 
125
126
  impl<'s> TryConvert for TextEncodeInput<'s> {
126
127
  fn try_convert(ob: Value) -> RbResult<Self> {
127
- if let Ok(i) = ob.try_convert::<TextInputSequence>() {
128
+ if let Ok(i) = TextInputSequence::try_convert(ob) {
128
129
  return Ok(Self(i.into()));
129
130
  }
130
- if let Ok((i1, i2)) = ob.try_convert::<(TextInputSequence, TextInputSequence)>() {
131
+ if let Ok((i1, i2)) = <(TextInputSequence, TextInputSequence)>::try_convert(ob) {
131
132
  return Ok(Self((i1, i2).into()));
132
133
  }
133
134
  // TODO check if this branch is needed
134
- if let Ok(arr) = ob.try_convert::<RArray>() {
135
+ if let Ok(arr) = RArray::try_convert(ob) {
135
136
  if arr.len() == 2 {
136
137
  let first = arr.entry::<TextInputSequence>(0).unwrap();
137
138
  let second = arr.entry::<TextInputSequence>(1).unwrap();
@@ -155,16 +156,16 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
155
156
 
156
157
  impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
157
158
  fn try_convert(ob: Value) -> RbResult<Self> {
158
- if let Ok(i) = ob.try_convert::<PreTokenizedInputSequence>() {
159
+ if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
159
160
  return Ok(Self(i.into()));
160
161
  }
161
162
  if let Ok((i1, i2)) =
162
- ob.try_convert::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
163
+ <(PreTokenizedInputSequence, PreTokenizedInputSequence)>::try_convert(ob)
163
164
  {
164
165
  return Ok(Self((i1, i2).into()));
165
166
  }
166
167
  // TODO check if this branch is needed
167
- if let Ok(arr) = ob.try_convert::<RArray>() {
168
+ if let Ok(arr) = RArray::try_convert(ob) {
168
169
  if arr.len() == 2 {
169
170
  let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
170
171
  let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
@@ -251,16 +252,16 @@ impl RbTokenizer {
251
252
  add_special_tokens: bool,
252
253
  ) -> RbResult<RbEncoding> {
253
254
  let sequence: tk::InputSequence = if is_pretokenized {
254
- sequence.try_convert::<PreTokenizedInputSequence>()?.into()
255
+ PreTokenizedInputSequence::try_convert(sequence)?.into()
255
256
  } else {
256
- sequence.try_convert::<TextInputSequence>()?.into()
257
+ TextInputSequence::try_convert(sequence)?.into()
257
258
  };
258
259
  let input = match pair {
259
260
  Some(pair) => {
260
261
  let pair: tk::InputSequence = if is_pretokenized {
261
- pair.try_convert::<PreTokenizedInputSequence>()?.into()
262
+ PreTokenizedInputSequence::try_convert(pair)?.into()
262
263
  } else {
263
- pair.try_convert::<TextInputSequence>()?.into()
264
+ TextInputSequence::try_convert(pair)?.into()
264
265
  };
265
266
  tk::EncodeInput::Dual(sequence, pair)
266
267
  }
@@ -284,9 +285,9 @@ impl RbTokenizer {
284
285
  .each()
285
286
  .map(|o| {
286
287
  let input: tk::EncodeInput = if is_pretokenized {
287
- o?.try_convert::<PreTokenizedEncodeInput>()?.into()
288
+ PreTokenizedEncodeInput::try_convert(o?)?.into()
288
289
  } else {
289
- o?.try_convert::<TextEncodeInput>()?.into()
290
+ TextEncodeInput::try_convert(o?)?.into()
290
291
  };
291
292
  Ok(input)
292
293
  })
@@ -306,14 +307,15 @@ impl RbTokenizer {
306
307
  pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
307
308
  self.tokenizer
308
309
  .borrow()
309
- .decode(ids, skip_special_tokens)
310
+ .decode(&ids, skip_special_tokens)
310
311
  .map_err(RbError::from)
311
312
  }
312
313
 
313
314
  pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
315
+ let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
314
316
  self.tokenizer
315
317
  .borrow()
316
- .decode_batch(sequences, skip_special_tokens)
318
+ .decode_batch(&slices, skip_special_tokens)
317
319
  .map_err(RbError::from)
318
320
  }
319
321
 
@@ -353,7 +355,7 @@ impl RbTokenizer {
353
355
 
354
356
  let value: Value = kwargs.delete(Symbol::new("direction"))?;
355
357
  if !value.is_nil() {
356
- let dir_str: String = value.try_convert()?;
358
+ let dir_str = String::try_convert(value)?;
357
359
  params.direction = match dir_str.as_str() {
358
360
  "left" => PaddingDirection::Left,
359
361
  "right" => PaddingDirection::Right,
@@ -363,29 +365,29 @@ impl RbTokenizer {
363
365
 
364
366
  let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
365
367
  if !value.is_nil() {
366
- params.pad_to_multiple_of = value.try_convert()?;
368
+ params.pad_to_multiple_of = TryConvert::try_convert(value)?;
367
369
  }
368
370
 
369
371
  let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
370
372
  if !value.is_nil() {
371
- params.pad_id = value.try_convert()?;
373
+ params.pad_id = TryConvert::try_convert(value)?;
372
374
  }
373
375
 
374
376
  let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
375
377
  if !value.is_nil() {
376
- params.pad_type_id = value.try_convert()?;
378
+ params.pad_type_id = TryConvert::try_convert(value)?;
377
379
  }
378
380
 
379
381
  let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
380
382
  if !value.is_nil() {
381
- params.pad_token = value.try_convert()?;
383
+ params.pad_token = TryConvert::try_convert(value)?;
382
384
  }
383
385
 
384
386
  let value: Value = kwargs.delete(Symbol::new("length"))?;
385
387
  if value.is_nil() {
386
388
  params.strategy = PaddingStrategy::BatchLongest;
387
389
  } else {
388
- params.strategy = PaddingStrategy::Fixed(value.try_convert()?);
390
+ params.strategy = PaddingStrategy::Fixed(TryConvert::try_convert(value)?);
389
391
  }
390
392
 
391
393
  if !kwargs.is_empty() {
@@ -431,12 +433,12 @@ impl RbTokenizer {
431
433
 
432
434
  let value: Value = kwargs.delete(Symbol::new("stride"))?;
433
435
  if !value.is_nil() {
434
- params.stride = value.try_convert()?;
436
+ params.stride = TryConvert::try_convert(value)?;
435
437
  }
436
438
 
437
439
  let value: Value = kwargs.delete(Symbol::new("strategy"))?;
438
440
  if !value.is_nil() {
439
- let strategy_str: String = value.try_convert()?;
441
+ let strategy_str = String::try_convert(value)?;
440
442
  params.strategy = match strategy_str.as_str() {
441
443
  "longest_first" => TruncationStrategy::LongestFirst,
442
444
  "only_first" => TruncationStrategy::OnlyFirst,
@@ -447,7 +449,7 @@ impl RbTokenizer {
447
449
 
448
450
  let value: Value = kwargs.delete(Symbol::new("direction"))?;
449
451
  if !value.is_nil() {
450
- let dir_str: String = value.try_convert()?;
452
+ let dir_str = String::try_convert(value)?;
451
453
  params.direction = match dir_str.as_str() {
452
454
  "left" => TruncationDirection::Left,
453
455
  "right" => TruncationDirection::Right,
@@ -460,13 +462,18 @@ impl RbTokenizer {
460
462
  return Err(Error::new(exception::arg_error(), "unknown keyword"));
461
463
  }
462
464
 
463
- self.tokenizer.borrow_mut().with_truncation(Some(params));
465
+ if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
466
+ return Err(Error::new(exception::arg_error(), error_message.to_string()));
467
+ }
464
468
 
465
469
  Ok(())
466
470
  }
467
471
 
468
472
  pub fn no_truncation(&self) {
469
- self.tokenizer.borrow_mut().with_truncation(None);
473
+ self.tokenizer
474
+ .borrow_mut()
475
+ .with_truncation(None)
476
+ .expect("Failed to set truncation to `None`! This should never happen");
470
477
  }
471
478
 
472
479
  pub fn truncation(&self) -> RbResult<Option<RHash>> {