tokenizers 0.3.3 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Cargo.lock +52 -23
- data/ext/tokenizers/Cargo.toml +4 -3
- data/ext/tokenizers/src/decoders.rs +72 -61
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +57 -51
- data/ext/tokenizers/src/normalizers.rs +90 -77
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +2 -2
- metadata +3 -3
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::Arc;
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
function,
|
6
|
-
TryConvert, TypedData, Value,
|
4
|
+
data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
|
5
|
+
Ruby, TryConvert, TypedData, Value,
|
7
6
|
};
|
8
7
|
use serde::{Deserialize, Serialize};
|
9
8
|
use tk::processors::bert::BertProcessing;
|
@@ -13,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
|
|
13
12
|
use tk::processors::PostProcessorWrapper;
|
14
13
|
use tk::{Encoding, PostProcessor};
|
15
14
|
|
16
|
-
use super::RbResult;
|
15
|
+
use super::{PROCESSORS, RbResult};
|
17
16
|
|
18
17
|
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
19
18
|
pub struct RbPostProcessor {
|
@@ -53,9 +52,9 @@ impl From<RbSpecialToken> for SpecialToken {
|
|
53
52
|
|
54
53
|
impl TryConvert for RbSpecialToken {
|
55
54
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
56
|
-
if let Ok(v) =
|
55
|
+
if let Ok(v) = <(String, u32)>::try_convert(ob) {
|
57
56
|
Ok(Self(v.into()))
|
58
|
-
} else if let Ok(v) =
|
57
|
+
} else if let Ok(v) = <(u32, String)>::try_convert(ob) {
|
59
58
|
Ok(Self(v.into()))
|
60
59
|
} else {
|
61
60
|
todo!()
|
@@ -74,11 +73,11 @@ impl From<RbTemplate> for Template {
|
|
74
73
|
|
75
74
|
impl TryConvert for RbTemplate {
|
76
75
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
77
|
-
if let Ok(s) =
|
76
|
+
if let Ok(s) = String::try_convert(ob) {
|
78
77
|
Ok(Self(
|
79
78
|
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
80
79
|
))
|
81
|
-
} else if let Ok(s) =
|
80
|
+
} else if let Ok(s) = <Vec<String>>::try_convert(ob) {
|
82
81
|
Ok(Self(
|
83
82
|
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
84
83
|
))
|
@@ -152,47 +151,53 @@ impl RbTemplateProcessing {
|
|
152
151
|
}
|
153
152
|
|
154
153
|
unsafe impl TypedData for RbPostProcessor {
|
155
|
-
fn class() -> RClass {
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
})
|
154
|
+
fn class(ruby: &Ruby) -> RClass {
|
155
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
156
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
|
157
|
+
class.undef_default_alloc_func();
|
158
|
+
class
|
159
|
+
});
|
160
|
+
ruby.get_inner(&CLASS)
|
161
161
|
}
|
162
162
|
|
163
163
|
fn data_type() -> &'static DataType {
|
164
|
-
|
164
|
+
static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
|
165
|
+
&DATA_TYPE
|
165
166
|
}
|
166
167
|
|
167
|
-
fn class_for(value: &Self) -> RClass {
|
168
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
169
|
+
static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
170
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
|
171
|
+
class.undef_default_alloc_func();
|
172
|
+
class
|
173
|
+
});
|
174
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
175
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("ByteLevel").unwrap();
|
176
|
+
class.undef_default_alloc_func();
|
177
|
+
class
|
178
|
+
});
|
179
|
+
static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
180
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
|
181
|
+
class.undef_default_alloc_func();
|
182
|
+
class
|
183
|
+
});
|
184
|
+
static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
185
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
|
186
|
+
class.undef_default_alloc_func();
|
187
|
+
class
|
188
|
+
});
|
168
189
|
match *value.processor {
|
169
|
-
PostProcessorWrapper::Bert(_) =>
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
}),
|
174
|
-
PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
|
175
|
-
let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
|
176
|
-
class.undef_alloc_func();
|
177
|
-
class
|
178
|
-
}),
|
179
|
-
PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
|
180
|
-
let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
|
181
|
-
class.undef_alloc_func();
|
182
|
-
class
|
183
|
-
}),
|
184
|
-
PostProcessorWrapper::Template(_) => *memoize!(RClass: {
|
185
|
-
let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
|
186
|
-
class.undef_alloc_func();
|
187
|
-
class
|
188
|
-
}),
|
190
|
+
PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
|
191
|
+
PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
192
|
+
PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
|
193
|
+
PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
|
189
194
|
_ => todo!(),
|
190
195
|
}
|
191
196
|
}
|
192
197
|
}
|
193
198
|
|
194
|
-
pub fn
|
195
|
-
let post_processor = module.define_class("PostProcessor",
|
199
|
+
pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
200
|
+
let post_processor = module.define_class("PostProcessor", ruby.class_object())?;
|
196
201
|
|
197
202
|
let class = module.define_class("BertProcessing", post_processor)?;
|
198
203
|
class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
|
@@ -2,6 +2,7 @@ use std::cell::RefCell;
|
|
2
2
|
use std::collections::HashMap;
|
3
3
|
use std::path::PathBuf;
|
4
4
|
|
5
|
+
use magnus::prelude::*;
|
5
6
|
use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
|
6
7
|
use tk::tokenizer::{
|
7
8
|
Model, PaddingDirection, PaddingParams, PaddingStrategy,
|
@@ -78,7 +79,7 @@ struct TextInputSequence<'s>(tk::InputSequence<'s>);
|
|
78
79
|
|
79
80
|
impl<'s> TryConvert for TextInputSequence<'s> {
|
80
81
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
81
|
-
Ok(Self(
|
82
|
+
Ok(Self(String::try_convert(ob)?.into()))
|
82
83
|
}
|
83
84
|
}
|
84
85
|
|
@@ -92,7 +93,7 @@ struct RbArrayStr(Vec<String>);
|
|
92
93
|
|
93
94
|
impl TryConvert for RbArrayStr {
|
94
95
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
95
|
-
let seq =
|
96
|
+
let seq = <Vec<String>>::try_convert(ob)?;
|
96
97
|
Ok(Self(seq))
|
97
98
|
}
|
98
99
|
}
|
@@ -107,7 +108,7 @@ struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
|
107
108
|
|
108
109
|
impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
|
109
110
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
110
|
-
if let Ok(seq) =
|
111
|
+
if let Ok(seq) = RbArrayStr::try_convert(ob) {
|
111
112
|
return Ok(Self(seq.into()));
|
112
113
|
}
|
113
114
|
todo!()
|
@@ -124,14 +125,14 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
124
125
|
|
125
126
|
impl<'s> TryConvert for TextEncodeInput<'s> {
|
126
127
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
127
|
-
if let Ok(i) =
|
128
|
+
if let Ok(i) = TextInputSequence::try_convert(ob) {
|
128
129
|
return Ok(Self(i.into()));
|
129
130
|
}
|
130
|
-
if let Ok((i1, i2)) =
|
131
|
+
if let Ok((i1, i2)) = <(TextInputSequence, TextInputSequence)>::try_convert(ob) {
|
131
132
|
return Ok(Self((i1, i2).into()));
|
132
133
|
}
|
133
134
|
// TODO check if this branch is needed
|
134
|
-
if let Ok(arr) =
|
135
|
+
if let Ok(arr) = RArray::try_convert(ob) {
|
135
136
|
if arr.len() == 2 {
|
136
137
|
let first = arr.entry::<TextInputSequence>(0).unwrap();
|
137
138
|
let second = arr.entry::<TextInputSequence>(1).unwrap();
|
@@ -155,16 +156,16 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
155
156
|
|
156
157
|
impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
|
157
158
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
158
|
-
if let Ok(i) =
|
159
|
+
if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
|
159
160
|
return Ok(Self(i.into()));
|
160
161
|
}
|
161
162
|
if let Ok((i1, i2)) =
|
162
|
-
|
163
|
+
<(PreTokenizedInputSequence, PreTokenizedInputSequence)>::try_convert(ob)
|
163
164
|
{
|
164
165
|
return Ok(Self((i1, i2).into()));
|
165
166
|
}
|
166
167
|
// TODO check if this branch is needed
|
167
|
-
if let Ok(arr) =
|
168
|
+
if let Ok(arr) = RArray::try_convert(ob) {
|
168
169
|
if arr.len() == 2 {
|
169
170
|
let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
|
170
171
|
let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
|
@@ -251,16 +252,16 @@ impl RbTokenizer {
|
|
251
252
|
add_special_tokens: bool,
|
252
253
|
) -> RbResult<RbEncoding> {
|
253
254
|
let sequence: tk::InputSequence = if is_pretokenized {
|
254
|
-
|
255
|
+
PreTokenizedInputSequence::try_convert(sequence)?.into()
|
255
256
|
} else {
|
256
|
-
|
257
|
+
TextInputSequence::try_convert(sequence)?.into()
|
257
258
|
};
|
258
259
|
let input = match pair {
|
259
260
|
Some(pair) => {
|
260
261
|
let pair: tk::InputSequence = if is_pretokenized {
|
261
|
-
|
262
|
+
PreTokenizedInputSequence::try_convert(pair)?.into()
|
262
263
|
} else {
|
263
|
-
|
264
|
+
TextInputSequence::try_convert(pair)?.into()
|
264
265
|
};
|
265
266
|
tk::EncodeInput::Dual(sequence, pair)
|
266
267
|
}
|
@@ -284,9 +285,9 @@ impl RbTokenizer {
|
|
284
285
|
.each()
|
285
286
|
.map(|o| {
|
286
287
|
let input: tk::EncodeInput = if is_pretokenized {
|
287
|
-
|
288
|
+
PreTokenizedEncodeInput::try_convert(o?)?.into()
|
288
289
|
} else {
|
289
|
-
|
290
|
+
TextEncodeInput::try_convert(o?)?.into()
|
290
291
|
};
|
291
292
|
Ok(input)
|
292
293
|
})
|
@@ -306,14 +307,15 @@ impl RbTokenizer {
|
|
306
307
|
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
|
307
308
|
self.tokenizer
|
308
309
|
.borrow()
|
309
|
-
.decode(ids, skip_special_tokens)
|
310
|
+
.decode(&ids, skip_special_tokens)
|
310
311
|
.map_err(RbError::from)
|
311
312
|
}
|
312
313
|
|
313
314
|
pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
|
315
|
+
let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
|
314
316
|
self.tokenizer
|
315
317
|
.borrow()
|
316
|
-
.decode_batch(
|
318
|
+
.decode_batch(&slices, skip_special_tokens)
|
317
319
|
.map_err(RbError::from)
|
318
320
|
}
|
319
321
|
|
@@ -353,7 +355,7 @@ impl RbTokenizer {
|
|
353
355
|
|
354
356
|
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
355
357
|
if !value.is_nil() {
|
356
|
-
let dir_str
|
358
|
+
let dir_str = String::try_convert(value)?;
|
357
359
|
params.direction = match dir_str.as_str() {
|
358
360
|
"left" => PaddingDirection::Left,
|
359
361
|
"right" => PaddingDirection::Right,
|
@@ -363,29 +365,29 @@ impl RbTokenizer {
|
|
363
365
|
|
364
366
|
let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
|
365
367
|
if !value.is_nil() {
|
366
|
-
params.pad_to_multiple_of =
|
368
|
+
params.pad_to_multiple_of = TryConvert::try_convert(value)?;
|
367
369
|
}
|
368
370
|
|
369
371
|
let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
|
370
372
|
if !value.is_nil() {
|
371
|
-
params.pad_id =
|
373
|
+
params.pad_id = TryConvert::try_convert(value)?;
|
372
374
|
}
|
373
375
|
|
374
376
|
let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
|
375
377
|
if !value.is_nil() {
|
376
|
-
params.pad_type_id =
|
378
|
+
params.pad_type_id = TryConvert::try_convert(value)?;
|
377
379
|
}
|
378
380
|
|
379
381
|
let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
|
380
382
|
if !value.is_nil() {
|
381
|
-
params.pad_token =
|
383
|
+
params.pad_token = TryConvert::try_convert(value)?;
|
382
384
|
}
|
383
385
|
|
384
386
|
let value: Value = kwargs.delete(Symbol::new("length"))?;
|
385
387
|
if value.is_nil() {
|
386
388
|
params.strategy = PaddingStrategy::BatchLongest;
|
387
389
|
} else {
|
388
|
-
params.strategy = PaddingStrategy::Fixed(
|
390
|
+
params.strategy = PaddingStrategy::Fixed(TryConvert::try_convert(value)?);
|
389
391
|
}
|
390
392
|
|
391
393
|
if !kwargs.is_empty() {
|
@@ -431,12 +433,12 @@ impl RbTokenizer {
|
|
431
433
|
|
432
434
|
let value: Value = kwargs.delete(Symbol::new("stride"))?;
|
433
435
|
if !value.is_nil() {
|
434
|
-
params.stride =
|
436
|
+
params.stride = TryConvert::try_convert(value)?;
|
435
437
|
}
|
436
438
|
|
437
439
|
let value: Value = kwargs.delete(Symbol::new("strategy"))?;
|
438
440
|
if !value.is_nil() {
|
439
|
-
let strategy_str
|
441
|
+
let strategy_str = String::try_convert(value)?;
|
440
442
|
params.strategy = match strategy_str.as_str() {
|
441
443
|
"longest_first" => TruncationStrategy::LongestFirst,
|
442
444
|
"only_first" => TruncationStrategy::OnlyFirst,
|
@@ -447,7 +449,7 @@ impl RbTokenizer {
|
|
447
449
|
|
448
450
|
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
449
451
|
if !value.is_nil() {
|
450
|
-
let dir_str
|
452
|
+
let dir_str = String::try_convert(value)?;
|
451
453
|
params.direction = match dir_str.as_str() {
|
452
454
|
"left" => TruncationDirection::Left,
|
453
455
|
"right" => TruncationDirection::Right,
|
@@ -460,13 +462,18 @@ impl RbTokenizer {
|
|
460
462
|
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
461
463
|
}
|
462
464
|
|
463
|
-
self.tokenizer.borrow_mut().with_truncation(Some(params))
|
465
|
+
if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
|
466
|
+
return Err(Error::new(exception::arg_error(), error_message.to_string()));
|
467
|
+
}
|
464
468
|
|
465
469
|
Ok(())
|
466
470
|
}
|
467
471
|
|
468
472
|
pub fn no_truncation(&self) {
|
469
|
-
self.tokenizer
|
473
|
+
self.tokenizer
|
474
|
+
.borrow_mut()
|
475
|
+
.with_truncation(None)
|
476
|
+
.expect("Failed to set truncation to `None`! This should never happen");
|
470
477
|
}
|
471
478
|
|
472
479
|
pub fn truncation(&self) -> RbResult<Option<RHash>> {
|