tokenizers 0.3.3 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,9 +1,8 @@
1
1
  use std::sync::Arc;
2
2
 
3
- use magnus::typed_data::DataTypeBuilder;
4
3
  use magnus::{
5
- function, memoize, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
6
- TryConvert, TypedData, Value,
4
+ data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
5
+ Ruby, TryConvert, TypedData, Value,
7
6
  };
8
7
  use serde::{Deserialize, Serialize};
9
8
  use tk::processors::bert::BertProcessing;
@@ -13,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
13
12
  use tk::processors::PostProcessorWrapper;
14
13
  use tk::{Encoding, PostProcessor};
15
14
 
16
- use super::RbResult;
15
+ use super::{PROCESSORS, RbResult};
17
16
 
18
17
  #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
19
18
  pub struct RbPostProcessor {
@@ -53,9 +52,9 @@ impl From<RbSpecialToken> for SpecialToken {
53
52
 
54
53
  impl TryConvert for RbSpecialToken {
55
54
  fn try_convert(ob: Value) -> RbResult<Self> {
56
- if let Ok(v) = ob.try_convert::<(String, u32)>() {
55
+ if let Ok(v) = <(String, u32)>::try_convert(ob) {
57
56
  Ok(Self(v.into()))
58
- } else if let Ok(v) = ob.try_convert::<(u32, String)>() {
57
+ } else if let Ok(v) = <(u32, String)>::try_convert(ob) {
59
58
  Ok(Self(v.into()))
60
59
  } else {
61
60
  todo!()
@@ -74,11 +73,11 @@ impl From<RbTemplate> for Template {
74
73
 
75
74
  impl TryConvert for RbTemplate {
76
75
  fn try_convert(ob: Value) -> RbResult<Self> {
77
- if let Ok(s) = ob.try_convert::<String>() {
76
+ if let Ok(s) = String::try_convert(ob) {
78
77
  Ok(Self(
79
78
  s.try_into().unwrap(), //.map_err(RbError::from)?,
80
79
  ))
81
- } else if let Ok(s) = ob.try_convert::<Vec<String>>() {
80
+ } else if let Ok(s) = <Vec<String>>::try_convert(ob) {
82
81
  Ok(Self(
83
82
  s.try_into().unwrap(), //.map_err(RbError::from)?,
84
83
  ))
@@ -152,47 +151,53 @@ impl RbTemplateProcessing {
152
151
  }
153
152
 
154
153
  unsafe impl TypedData for RbPostProcessor {
155
- fn class() -> RClass {
156
- *memoize!(RClass: {
157
- let class: RClass = crate::processors().const_get("PostProcessor").unwrap();
158
- class.undef_alloc_func();
159
- class
160
- })
154
+ fn class(ruby: &Ruby) -> RClass {
155
+ static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
156
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
157
+ class.undef_default_alloc_func();
158
+ class
159
+ });
160
+ ruby.get_inner(&CLASS)
161
161
  }
162
162
 
163
163
  fn data_type() -> &'static DataType {
164
- memoize!(DataType: DataTypeBuilder::<RbPostProcessor>::new("Tokenizers::Processors::PostProcessor").build())
164
+ static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
165
+ &DATA_TYPE
165
166
  }
166
167
 
167
- fn class_for(value: &Self) -> RClass {
168
+ fn class_for(ruby: &Ruby, value: &Self) -> RClass {
169
+ static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
170
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
171
+ class.undef_default_alloc_func();
172
+ class
173
+ });
174
+ static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
175
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("ByteLevel").unwrap();
176
+ class.undef_default_alloc_func();
177
+ class
178
+ });
179
+ static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
180
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
181
+ class.undef_default_alloc_func();
182
+ class
183
+ });
184
+ static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
185
+ let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
186
+ class.undef_default_alloc_func();
187
+ class
188
+ });
168
189
  match *value.processor {
169
- PostProcessorWrapper::Bert(_) => *memoize!(RClass: {
170
- let class: RClass = crate::processors().const_get("BertProcessing").unwrap();
171
- class.undef_alloc_func();
172
- class
173
- }),
174
- PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
175
- let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
176
- class.undef_alloc_func();
177
- class
178
- }),
179
- PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
180
- let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
181
- class.undef_alloc_func();
182
- class
183
- }),
184
- PostProcessorWrapper::Template(_) => *memoize!(RClass: {
185
- let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
186
- class.undef_alloc_func();
187
- class
188
- }),
190
+ PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
191
+ PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
192
+ PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
193
+ PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
189
194
  _ => todo!(),
190
195
  }
191
196
  }
192
197
  }
193
198
 
194
- pub fn processors(module: &RModule) -> RbResult<()> {
195
- let post_processor = module.define_class("PostProcessor", Default::default())?;
199
+ pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
200
+ let post_processor = module.define_class("PostProcessor", ruby.class_object())?;
196
201
 
197
202
  let class = module.define_class("BertProcessing", post_processor)?;
198
203
  class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
@@ -2,6 +2,7 @@ use std::cell::RefCell;
2
2
  use std::collections::HashMap;
3
3
  use std::path::PathBuf;
4
4
 
5
+ use magnus::prelude::*;
5
6
  use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
6
7
  use tk::tokenizer::{
7
8
  Model, PaddingDirection, PaddingParams, PaddingStrategy,
@@ -78,7 +79,7 @@ struct TextInputSequence<'s>(tk::InputSequence<'s>);
78
79
 
79
80
  impl<'s> TryConvert for TextInputSequence<'s> {
80
81
  fn try_convert(ob: Value) -> RbResult<Self> {
81
- Ok(Self(ob.try_convert::<String>()?.into()))
82
+ Ok(Self(String::try_convert(ob)?.into()))
82
83
  }
83
84
  }
84
85
 
@@ -92,7 +93,7 @@ struct RbArrayStr(Vec<String>);
92
93
 
93
94
  impl TryConvert for RbArrayStr {
94
95
  fn try_convert(ob: Value) -> RbResult<Self> {
95
- let seq = ob.try_convert::<Vec<String>>()?;
96
+ let seq = <Vec<String>>::try_convert(ob)?;
96
97
  Ok(Self(seq))
97
98
  }
98
99
  }
@@ -107,7 +108,7 @@ struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
107
108
 
108
109
  impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
109
110
  fn try_convert(ob: Value) -> RbResult<Self> {
110
- if let Ok(seq) = ob.try_convert::<RbArrayStr>() {
111
+ if let Ok(seq) = RbArrayStr::try_convert(ob) {
111
112
  return Ok(Self(seq.into()));
112
113
  }
113
114
  todo!()
@@ -124,14 +125,14 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
124
125
 
125
126
  impl<'s> TryConvert for TextEncodeInput<'s> {
126
127
  fn try_convert(ob: Value) -> RbResult<Self> {
127
- if let Ok(i) = ob.try_convert::<TextInputSequence>() {
128
+ if let Ok(i) = TextInputSequence::try_convert(ob) {
128
129
  return Ok(Self(i.into()));
129
130
  }
130
- if let Ok((i1, i2)) = ob.try_convert::<(TextInputSequence, TextInputSequence)>() {
131
+ if let Ok((i1, i2)) = <(TextInputSequence, TextInputSequence)>::try_convert(ob) {
131
132
  return Ok(Self((i1, i2).into()));
132
133
  }
133
134
  // TODO check if this branch is needed
134
- if let Ok(arr) = ob.try_convert::<RArray>() {
135
+ if let Ok(arr) = RArray::try_convert(ob) {
135
136
  if arr.len() == 2 {
136
137
  let first = arr.entry::<TextInputSequence>(0).unwrap();
137
138
  let second = arr.entry::<TextInputSequence>(1).unwrap();
@@ -155,16 +156,16 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
155
156
 
156
157
  impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
157
158
  fn try_convert(ob: Value) -> RbResult<Self> {
158
- if let Ok(i) = ob.try_convert::<PreTokenizedInputSequence>() {
159
+ if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
159
160
  return Ok(Self(i.into()));
160
161
  }
161
162
  if let Ok((i1, i2)) =
162
- ob.try_convert::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
163
+ <(PreTokenizedInputSequence, PreTokenizedInputSequence)>::try_convert(ob)
163
164
  {
164
165
  return Ok(Self((i1, i2).into()));
165
166
  }
166
167
  // TODO check if this branch is needed
167
- if let Ok(arr) = ob.try_convert::<RArray>() {
168
+ if let Ok(arr) = RArray::try_convert(ob) {
168
169
  if arr.len() == 2 {
169
170
  let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
170
171
  let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
@@ -251,16 +252,16 @@ impl RbTokenizer {
251
252
  add_special_tokens: bool,
252
253
  ) -> RbResult<RbEncoding> {
253
254
  let sequence: tk::InputSequence = if is_pretokenized {
254
- sequence.try_convert::<PreTokenizedInputSequence>()?.into()
255
+ PreTokenizedInputSequence::try_convert(sequence)?.into()
255
256
  } else {
256
- sequence.try_convert::<TextInputSequence>()?.into()
257
+ TextInputSequence::try_convert(sequence)?.into()
257
258
  };
258
259
  let input = match pair {
259
260
  Some(pair) => {
260
261
  let pair: tk::InputSequence = if is_pretokenized {
261
- pair.try_convert::<PreTokenizedInputSequence>()?.into()
262
+ PreTokenizedInputSequence::try_convert(pair)?.into()
262
263
  } else {
263
- pair.try_convert::<TextInputSequence>()?.into()
264
+ TextInputSequence::try_convert(pair)?.into()
264
265
  };
265
266
  tk::EncodeInput::Dual(sequence, pair)
266
267
  }
@@ -284,9 +285,9 @@ impl RbTokenizer {
284
285
  .each()
285
286
  .map(|o| {
286
287
  let input: tk::EncodeInput = if is_pretokenized {
287
- o?.try_convert::<PreTokenizedEncodeInput>()?.into()
288
+ PreTokenizedEncodeInput::try_convert(o?)?.into()
288
289
  } else {
289
- o?.try_convert::<TextEncodeInput>()?.into()
290
+ TextEncodeInput::try_convert(o?)?.into()
290
291
  };
291
292
  Ok(input)
292
293
  })
@@ -306,14 +307,15 @@ impl RbTokenizer {
306
307
  pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
307
308
  self.tokenizer
308
309
  .borrow()
309
- .decode(ids, skip_special_tokens)
310
+ .decode(&ids, skip_special_tokens)
310
311
  .map_err(RbError::from)
311
312
  }
312
313
 
313
314
  pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
315
+ let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
314
316
  self.tokenizer
315
317
  .borrow()
316
- .decode_batch(sequences, skip_special_tokens)
318
+ .decode_batch(&slices, skip_special_tokens)
317
319
  .map_err(RbError::from)
318
320
  }
319
321
 
@@ -353,7 +355,7 @@ impl RbTokenizer {
353
355
 
354
356
  let value: Value = kwargs.delete(Symbol::new("direction"))?;
355
357
  if !value.is_nil() {
356
- let dir_str: String = value.try_convert()?;
358
+ let dir_str = String::try_convert(value)?;
357
359
  params.direction = match dir_str.as_str() {
358
360
  "left" => PaddingDirection::Left,
359
361
  "right" => PaddingDirection::Right,
@@ -363,29 +365,29 @@ impl RbTokenizer {
363
365
 
364
366
  let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
365
367
  if !value.is_nil() {
366
- params.pad_to_multiple_of = value.try_convert()?;
368
+ params.pad_to_multiple_of = TryConvert::try_convert(value)?;
367
369
  }
368
370
 
369
371
  let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
370
372
  if !value.is_nil() {
371
- params.pad_id = value.try_convert()?;
373
+ params.pad_id = TryConvert::try_convert(value)?;
372
374
  }
373
375
 
374
376
  let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
375
377
  if !value.is_nil() {
376
- params.pad_type_id = value.try_convert()?;
378
+ params.pad_type_id = TryConvert::try_convert(value)?;
377
379
  }
378
380
 
379
381
  let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
380
382
  if !value.is_nil() {
381
- params.pad_token = value.try_convert()?;
383
+ params.pad_token = TryConvert::try_convert(value)?;
382
384
  }
383
385
 
384
386
  let value: Value = kwargs.delete(Symbol::new("length"))?;
385
387
  if value.is_nil() {
386
388
  params.strategy = PaddingStrategy::BatchLongest;
387
389
  } else {
388
- params.strategy = PaddingStrategy::Fixed(value.try_convert()?);
390
+ params.strategy = PaddingStrategy::Fixed(TryConvert::try_convert(value)?);
389
391
  }
390
392
 
391
393
  if !kwargs.is_empty() {
@@ -431,12 +433,12 @@ impl RbTokenizer {
431
433
 
432
434
  let value: Value = kwargs.delete(Symbol::new("stride"))?;
433
435
  if !value.is_nil() {
434
- params.stride = value.try_convert()?;
436
+ params.stride = TryConvert::try_convert(value)?;
435
437
  }
436
438
 
437
439
  let value: Value = kwargs.delete(Symbol::new("strategy"))?;
438
440
  if !value.is_nil() {
439
- let strategy_str: String = value.try_convert()?;
441
+ let strategy_str = String::try_convert(value)?;
440
442
  params.strategy = match strategy_str.as_str() {
441
443
  "longest_first" => TruncationStrategy::LongestFirst,
442
444
  "only_first" => TruncationStrategy::OnlyFirst,
@@ -447,7 +449,7 @@ impl RbTokenizer {
447
449
 
448
450
  let value: Value = kwargs.delete(Symbol::new("direction"))?;
449
451
  if !value.is_nil() {
450
- let dir_str: String = value.try_convert()?;
452
+ let dir_str = String::try_convert(value)?;
451
453
  params.direction = match dir_str.as_str() {
452
454
  "left" => TruncationDirection::Left,
453
455
  "right" => TruncationDirection::Right,
@@ -460,13 +462,18 @@ impl RbTokenizer {
460
462
  return Err(Error::new(exception::arg_error(), "unknown keyword"));
461
463
  }
462
464
 
463
- self.tokenizer.borrow_mut().with_truncation(Some(params));
465
+ if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
466
+ return Err(Error::new(exception::arg_error(), error_message.to_string()));
467
+ }
464
468
 
465
469
  Ok(())
466
470
  }
467
471
 
468
472
  pub fn no_truncation(&self) {
469
- self.tokenizer.borrow_mut().with_truncation(None);
473
+ self.tokenizer
474
+ .borrow_mut()
475
+ .with_truncation(None)
476
+ .expect("Failed to set truncation to `None`! This should never happen");
470
477
  }
471
478
 
472
479
  pub fn truncation(&self) -> RbResult<Option<RHash>> {