tokenizers 0.3.3 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Cargo.lock +52 -23
- data/ext/tokenizers/Cargo.toml +4 -3
- data/ext/tokenizers/src/decoders.rs +72 -61
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +57 -51
- data/ext/tokenizers/src/normalizers.rs +90 -77
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +2 -2
- metadata +3 -3
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::Arc;
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
function,
|
6
|
-
TryConvert, TypedData, Value,
|
4
|
+
data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
|
5
|
+
Ruby, TryConvert, TypedData, Value,
|
7
6
|
};
|
8
7
|
use serde::{Deserialize, Serialize};
|
9
8
|
use tk::processors::bert::BertProcessing;
|
@@ -13,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
|
|
13
12
|
use tk::processors::PostProcessorWrapper;
|
14
13
|
use tk::{Encoding, PostProcessor};
|
15
14
|
|
16
|
-
use super::RbResult;
|
15
|
+
use super::{PROCESSORS, RbResult};
|
17
16
|
|
18
17
|
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
19
18
|
pub struct RbPostProcessor {
|
@@ -53,9 +52,9 @@ impl From<RbSpecialToken> for SpecialToken {
|
|
53
52
|
|
54
53
|
impl TryConvert for RbSpecialToken {
|
55
54
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
56
|
-
if let Ok(v) =
|
55
|
+
if let Ok(v) = <(String, u32)>::try_convert(ob) {
|
57
56
|
Ok(Self(v.into()))
|
58
|
-
} else if let Ok(v) =
|
57
|
+
} else if let Ok(v) = <(u32, String)>::try_convert(ob) {
|
59
58
|
Ok(Self(v.into()))
|
60
59
|
} else {
|
61
60
|
todo!()
|
@@ -74,11 +73,11 @@ impl From<RbTemplate> for Template {
|
|
74
73
|
|
75
74
|
impl TryConvert for RbTemplate {
|
76
75
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
77
|
-
if let Ok(s) =
|
76
|
+
if let Ok(s) = String::try_convert(ob) {
|
78
77
|
Ok(Self(
|
79
78
|
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
80
79
|
))
|
81
|
-
} else if let Ok(s) =
|
80
|
+
} else if let Ok(s) = <Vec<String>>::try_convert(ob) {
|
82
81
|
Ok(Self(
|
83
82
|
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
84
83
|
))
|
@@ -152,47 +151,53 @@ impl RbTemplateProcessing {
|
|
152
151
|
}
|
153
152
|
|
154
153
|
unsafe impl TypedData for RbPostProcessor {
|
155
|
-
fn class() -> RClass {
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
})
|
154
|
+
fn class(ruby: &Ruby) -> RClass {
|
155
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
156
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
|
157
|
+
class.undef_default_alloc_func();
|
158
|
+
class
|
159
|
+
});
|
160
|
+
ruby.get_inner(&CLASS)
|
161
161
|
}
|
162
162
|
|
163
163
|
fn data_type() -> &'static DataType {
|
164
|
-
|
164
|
+
static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
|
165
|
+
&DATA_TYPE
|
165
166
|
}
|
166
167
|
|
167
|
-
fn class_for(value: &Self) -> RClass {
|
168
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
169
|
+
static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
170
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
|
171
|
+
class.undef_default_alloc_func();
|
172
|
+
class
|
173
|
+
});
|
174
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
175
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("ByteLevel").unwrap();
|
176
|
+
class.undef_default_alloc_func();
|
177
|
+
class
|
178
|
+
});
|
179
|
+
static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
180
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
|
181
|
+
class.undef_default_alloc_func();
|
182
|
+
class
|
183
|
+
});
|
184
|
+
static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
185
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
|
186
|
+
class.undef_default_alloc_func();
|
187
|
+
class
|
188
|
+
});
|
168
189
|
match *value.processor {
|
169
|
-
PostProcessorWrapper::Bert(_) =>
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
}),
|
174
|
-
PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
|
175
|
-
let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
|
176
|
-
class.undef_alloc_func();
|
177
|
-
class
|
178
|
-
}),
|
179
|
-
PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
|
180
|
-
let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
|
181
|
-
class.undef_alloc_func();
|
182
|
-
class
|
183
|
-
}),
|
184
|
-
PostProcessorWrapper::Template(_) => *memoize!(RClass: {
|
185
|
-
let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
|
186
|
-
class.undef_alloc_func();
|
187
|
-
class
|
188
|
-
}),
|
190
|
+
PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
|
191
|
+
PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
192
|
+
PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
|
193
|
+
PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
|
189
194
|
_ => todo!(),
|
190
195
|
}
|
191
196
|
}
|
192
197
|
}
|
193
198
|
|
194
|
-
pub fn
|
195
|
-
let post_processor = module.define_class("PostProcessor",
|
199
|
+
pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
200
|
+
let post_processor = module.define_class("PostProcessor", ruby.class_object())?;
|
196
201
|
|
197
202
|
let class = module.define_class("BertProcessing", post_processor)?;
|
198
203
|
class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
|
@@ -2,6 +2,7 @@ use std::cell::RefCell;
|
|
2
2
|
use std::collections::HashMap;
|
3
3
|
use std::path::PathBuf;
|
4
4
|
|
5
|
+
use magnus::prelude::*;
|
5
6
|
use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
|
6
7
|
use tk::tokenizer::{
|
7
8
|
Model, PaddingDirection, PaddingParams, PaddingStrategy,
|
@@ -78,7 +79,7 @@ struct TextInputSequence<'s>(tk::InputSequence<'s>);
|
|
78
79
|
|
79
80
|
impl<'s> TryConvert for TextInputSequence<'s> {
|
80
81
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
81
|
-
Ok(Self(
|
82
|
+
Ok(Self(String::try_convert(ob)?.into()))
|
82
83
|
}
|
83
84
|
}
|
84
85
|
|
@@ -92,7 +93,7 @@ struct RbArrayStr(Vec<String>);
|
|
92
93
|
|
93
94
|
impl TryConvert for RbArrayStr {
|
94
95
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
95
|
-
let seq =
|
96
|
+
let seq = <Vec<String>>::try_convert(ob)?;
|
96
97
|
Ok(Self(seq))
|
97
98
|
}
|
98
99
|
}
|
@@ -107,7 +108,7 @@ struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
|
107
108
|
|
108
109
|
impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
|
109
110
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
110
|
-
if let Ok(seq) =
|
111
|
+
if let Ok(seq) = RbArrayStr::try_convert(ob) {
|
111
112
|
return Ok(Self(seq.into()));
|
112
113
|
}
|
113
114
|
todo!()
|
@@ -124,14 +125,14 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
124
125
|
|
125
126
|
impl<'s> TryConvert for TextEncodeInput<'s> {
|
126
127
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
127
|
-
if let Ok(i) =
|
128
|
+
if let Ok(i) = TextInputSequence::try_convert(ob) {
|
128
129
|
return Ok(Self(i.into()));
|
129
130
|
}
|
130
|
-
if let Ok((i1, i2)) =
|
131
|
+
if let Ok((i1, i2)) = <(TextInputSequence, TextInputSequence)>::try_convert(ob) {
|
131
132
|
return Ok(Self((i1, i2).into()));
|
132
133
|
}
|
133
134
|
// TODO check if this branch is needed
|
134
|
-
if let Ok(arr) =
|
135
|
+
if let Ok(arr) = RArray::try_convert(ob) {
|
135
136
|
if arr.len() == 2 {
|
136
137
|
let first = arr.entry::<TextInputSequence>(0).unwrap();
|
137
138
|
let second = arr.entry::<TextInputSequence>(1).unwrap();
|
@@ -155,16 +156,16 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
155
156
|
|
156
157
|
impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
|
157
158
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
158
|
-
if let Ok(i) =
|
159
|
+
if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
|
159
160
|
return Ok(Self(i.into()));
|
160
161
|
}
|
161
162
|
if let Ok((i1, i2)) =
|
162
|
-
|
163
|
+
<(PreTokenizedInputSequence, PreTokenizedInputSequence)>::try_convert(ob)
|
163
164
|
{
|
164
165
|
return Ok(Self((i1, i2).into()));
|
165
166
|
}
|
166
167
|
// TODO check if this branch is needed
|
167
|
-
if let Ok(arr) =
|
168
|
+
if let Ok(arr) = RArray::try_convert(ob) {
|
168
169
|
if arr.len() == 2 {
|
169
170
|
let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
|
170
171
|
let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
|
@@ -251,16 +252,16 @@ impl RbTokenizer {
|
|
251
252
|
add_special_tokens: bool,
|
252
253
|
) -> RbResult<RbEncoding> {
|
253
254
|
let sequence: tk::InputSequence = if is_pretokenized {
|
254
|
-
|
255
|
+
PreTokenizedInputSequence::try_convert(sequence)?.into()
|
255
256
|
} else {
|
256
|
-
|
257
|
+
TextInputSequence::try_convert(sequence)?.into()
|
257
258
|
};
|
258
259
|
let input = match pair {
|
259
260
|
Some(pair) => {
|
260
261
|
let pair: tk::InputSequence = if is_pretokenized {
|
261
|
-
|
262
|
+
PreTokenizedInputSequence::try_convert(pair)?.into()
|
262
263
|
} else {
|
263
|
-
|
264
|
+
TextInputSequence::try_convert(pair)?.into()
|
264
265
|
};
|
265
266
|
tk::EncodeInput::Dual(sequence, pair)
|
266
267
|
}
|
@@ -284,9 +285,9 @@ impl RbTokenizer {
|
|
284
285
|
.each()
|
285
286
|
.map(|o| {
|
286
287
|
let input: tk::EncodeInput = if is_pretokenized {
|
287
|
-
|
288
|
+
PreTokenizedEncodeInput::try_convert(o?)?.into()
|
288
289
|
} else {
|
289
|
-
|
290
|
+
TextEncodeInput::try_convert(o?)?.into()
|
290
291
|
};
|
291
292
|
Ok(input)
|
292
293
|
})
|
@@ -306,14 +307,15 @@ impl RbTokenizer {
|
|
306
307
|
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
|
307
308
|
self.tokenizer
|
308
309
|
.borrow()
|
309
|
-
.decode(ids, skip_special_tokens)
|
310
|
+
.decode(&ids, skip_special_tokens)
|
310
311
|
.map_err(RbError::from)
|
311
312
|
}
|
312
313
|
|
313
314
|
pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
|
315
|
+
let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
|
314
316
|
self.tokenizer
|
315
317
|
.borrow()
|
316
|
-
.decode_batch(
|
318
|
+
.decode_batch(&slices, skip_special_tokens)
|
317
319
|
.map_err(RbError::from)
|
318
320
|
}
|
319
321
|
|
@@ -353,7 +355,7 @@ impl RbTokenizer {
|
|
353
355
|
|
354
356
|
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
355
357
|
if !value.is_nil() {
|
356
|
-
let dir_str
|
358
|
+
let dir_str = String::try_convert(value)?;
|
357
359
|
params.direction = match dir_str.as_str() {
|
358
360
|
"left" => PaddingDirection::Left,
|
359
361
|
"right" => PaddingDirection::Right,
|
@@ -363,29 +365,29 @@ impl RbTokenizer {
|
|
363
365
|
|
364
366
|
let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
|
365
367
|
if !value.is_nil() {
|
366
|
-
params.pad_to_multiple_of =
|
368
|
+
params.pad_to_multiple_of = TryConvert::try_convert(value)?;
|
367
369
|
}
|
368
370
|
|
369
371
|
let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
|
370
372
|
if !value.is_nil() {
|
371
|
-
params.pad_id =
|
373
|
+
params.pad_id = TryConvert::try_convert(value)?;
|
372
374
|
}
|
373
375
|
|
374
376
|
let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
|
375
377
|
if !value.is_nil() {
|
376
|
-
params.pad_type_id =
|
378
|
+
params.pad_type_id = TryConvert::try_convert(value)?;
|
377
379
|
}
|
378
380
|
|
379
381
|
let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
|
380
382
|
if !value.is_nil() {
|
381
|
-
params.pad_token =
|
383
|
+
params.pad_token = TryConvert::try_convert(value)?;
|
382
384
|
}
|
383
385
|
|
384
386
|
let value: Value = kwargs.delete(Symbol::new("length"))?;
|
385
387
|
if value.is_nil() {
|
386
388
|
params.strategy = PaddingStrategy::BatchLongest;
|
387
389
|
} else {
|
388
|
-
params.strategy = PaddingStrategy::Fixed(
|
390
|
+
params.strategy = PaddingStrategy::Fixed(TryConvert::try_convert(value)?);
|
389
391
|
}
|
390
392
|
|
391
393
|
if !kwargs.is_empty() {
|
@@ -431,12 +433,12 @@ impl RbTokenizer {
|
|
431
433
|
|
432
434
|
let value: Value = kwargs.delete(Symbol::new("stride"))?;
|
433
435
|
if !value.is_nil() {
|
434
|
-
params.stride =
|
436
|
+
params.stride = TryConvert::try_convert(value)?;
|
435
437
|
}
|
436
438
|
|
437
439
|
let value: Value = kwargs.delete(Symbol::new("strategy"))?;
|
438
440
|
if !value.is_nil() {
|
439
|
-
let strategy_str
|
441
|
+
let strategy_str = String::try_convert(value)?;
|
440
442
|
params.strategy = match strategy_str.as_str() {
|
441
443
|
"longest_first" => TruncationStrategy::LongestFirst,
|
442
444
|
"only_first" => TruncationStrategy::OnlyFirst,
|
@@ -447,7 +449,7 @@ impl RbTokenizer {
|
|
447
449
|
|
448
450
|
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
449
451
|
if !value.is_nil() {
|
450
|
-
let dir_str
|
452
|
+
let dir_str = String::try_convert(value)?;
|
451
453
|
params.direction = match dir_str.as_str() {
|
452
454
|
"left" => TruncationDirection::Left,
|
453
455
|
"right" => TruncationDirection::Right,
|
@@ -460,13 +462,18 @@ impl RbTokenizer {
|
|
460
462
|
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
461
463
|
}
|
462
464
|
|
463
|
-
self.tokenizer.borrow_mut().with_truncation(Some(params))
|
465
|
+
if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
|
466
|
+
return Err(Error::new(exception::arg_error(), error_message.to_string()));
|
467
|
+
}
|
464
468
|
|
465
469
|
Ok(())
|
466
470
|
}
|
467
471
|
|
468
472
|
pub fn no_truncation(&self) {
|
469
|
-
self.tokenizer
|
473
|
+
self.tokenizer
|
474
|
+
.borrow_mut()
|
475
|
+
.with_truncation(None)
|
476
|
+
.expect("Failed to set truncation to `None`! This should never happen");
|
470
477
|
}
|
471
478
|
|
472
479
|
pub fn truncation(&self) -> RbResult<Option<RHash>> {
|