tokenizers 0.2.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/Cargo.lock +33 -74
- data/README.md +4 -0
- data/ext/tokenizers/Cargo.toml +4 -2
- data/ext/tokenizers/src/decoders.rs +275 -6
- data/ext/tokenizers/src/encoding.rs +78 -3
- data/ext/tokenizers/src/error.rs +2 -2
- data/ext/tokenizers/src/lib.rs +88 -17
- data/ext/tokenizers/src/models.rs +372 -11
- data/ext/tokenizers/src/normalizers.rs +435 -7
- data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
- data/ext/tokenizers/src/processors.rs +210 -0
- data/ext/tokenizers/src/tokenizer.rs +448 -20
- data/ext/tokenizers/src/trainers.rs +749 -0
- data/ext/tokenizers/src/utils/mod.rs +5 -0
- data/ext/tokenizers/src/utils/normalization.rs +85 -0
- data/ext/tokenizers/src/utils/regex.rs +22 -0
- data/lib/tokenizers/char_bpe_tokenizer.rb +11 -8
- data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
- data/lib/tokenizers/decoders/ctc.rb +9 -0
- data/lib/tokenizers/decoders/metaspace.rb +9 -0
- data/lib/tokenizers/decoders/word_piece.rb +9 -0
- data/lib/tokenizers/encoding.rb +19 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/bpe.rb +9 -0
- data/lib/tokenizers/models/unigram.rb +9 -0
- data/lib/tokenizers/models/word_level.rb +13 -0
- data/lib/tokenizers/models/word_piece.rb +9 -0
- data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
- data/lib/tokenizers/normalizers/strip.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
- data/lib/tokenizers/processors/byte_level.rb +9 -0
- data/lib/tokenizers/processors/roberta_processing.rb +9 -0
- data/lib/tokenizers/processors/template_processing.rb +9 -0
- data/lib/tokenizers/tokenizer.rb +45 -0
- data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
- data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
- data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
- data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +49 -7
- metadata +32 -3
@@ -1,27 +1,207 @@
|
|
1
1
|
use std::cell::RefCell;
|
2
|
+
use std::collections::HashMap;
|
2
3
|
use std::path::PathBuf;
|
3
|
-
|
4
|
+
|
5
|
+
use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
|
6
|
+
use tk::tokenizer::{
|
7
|
+
Model, PaddingDirection, PaddingParams, PaddingStrategy,
|
8
|
+
TruncationDirection, TruncationParams, TruncationStrategy, TokenizerImpl
|
9
|
+
};
|
4
10
|
use tk::AddedToken;
|
5
11
|
|
6
|
-
use
|
12
|
+
use crate::tk::PostProcessor;
|
13
|
+
|
14
|
+
use super::decoders::RbDecoder;
|
7
15
|
use super::encoding::RbEncoding;
|
8
|
-
use super::models::
|
9
|
-
use super::normalizers::
|
10
|
-
use super::pre_tokenizers::
|
16
|
+
use super::models::RbModel;
|
17
|
+
use super::normalizers::RbNormalizer;
|
18
|
+
use super::pre_tokenizers::RbPreTokenizer;
|
19
|
+
use super::processors::RbPostProcessor;
|
20
|
+
use super::trainers::RbTrainer;
|
11
21
|
use super::{RbError, RbResult};
|
12
22
|
|
23
|
+
/// Ruby-side description of a token to add to a tokenizer's vocabulary.
///
/// Each `Option<bool>` flag is an optional override: `None` means "leave the
/// value chosen when the underlying `tk::AddedToken` is constructed".
pub struct RbAddedToken {
    /// Literal text of the token.
    pub content: String,
    /// Whether the token is registered as a special token.
    pub is_special_token: bool,
    /// Optional override for the token's `single_word` setting.
    pub single_word: Option<bool>,
    /// Optional override for the token's `lstrip` setting.
    pub lstrip: Option<bool>,
    /// Optional override for the token's `rstrip` setting.
    pub rstrip: Option<bool>,
    /// Optional override for the token's `normalized` setting.
    pub normalized: Option<bool>,
}
|
31
|
+
|
32
|
+
impl RbAddedToken {
|
33
|
+
pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
|
34
|
+
Self {
|
35
|
+
content: content.into(),
|
36
|
+
is_special_token: is_special_token.unwrap_or(false),
|
37
|
+
single_word: None,
|
38
|
+
lstrip: None,
|
39
|
+
rstrip: None,
|
40
|
+
normalized: None,
|
41
|
+
}
|
42
|
+
}
|
43
|
+
|
44
|
+
pub fn get_token(&self) -> tk::tokenizer::AddedToken {
|
45
|
+
let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
|
46
|
+
|
47
|
+
if let Some(sw) = self.single_word {
|
48
|
+
token = token.single_word(sw);
|
49
|
+
}
|
50
|
+
if let Some(ls) = self.lstrip {
|
51
|
+
token = token.lstrip(ls);
|
52
|
+
}
|
53
|
+
if let Some(rs) = self.rstrip {
|
54
|
+
token = token.rstrip(rs);
|
55
|
+
}
|
56
|
+
if let Some(n) = self.normalized {
|
57
|
+
token = token.normalized(n);
|
58
|
+
}
|
59
|
+
|
60
|
+
token
|
61
|
+
}
|
62
|
+
}
|
63
|
+
|
64
|
+
impl From<tk::AddedToken> for RbAddedToken {
    /// Mirrors every flag of `token` as an explicit `Some(..)` override.
    ///
    /// NOTE(review): specialness is inferred as `!token.normalized`,
    /// presumably because this version of `tk::AddedToken` exposes no
    /// `special` field; confirm against the pinned `tokenizers` crate.
    fn from(token: tk::AddedToken) -> Self {
        let is_special_token = !token.normalized;
        let normalized = Some(token.normalized);
        Self {
            single_word: Some(token.single_word),
            lstrip: Some(token.lstrip),
            rstrip: Some(token.rstrip),
            normalized,
            is_special_token,
            content: token.content,
        }
    }
}
|
76
|
+
|
77
|
+
struct TextInputSequence<'s>(tk::InputSequence<'s>);
|
78
|
+
|
79
|
+
impl<'s> TryConvert for TextInputSequence<'s> {
|
80
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
81
|
+
Ok(Self(ob.try_convert::<String>()?.into()))
|
82
|
+
}
|
83
|
+
}
|
84
|
+
|
85
|
+
impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
|
86
|
+
fn from(s: TextInputSequence<'s>) -> Self {
|
87
|
+
s.0
|
88
|
+
}
|
89
|
+
}
|
90
|
+
|
91
|
+
struct RbArrayStr(Vec<String>);
|
92
|
+
|
93
|
+
impl TryConvert for RbArrayStr {
|
94
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
95
|
+
let seq = ob.try_convert::<Vec<String>>()?;
|
96
|
+
Ok(Self(seq))
|
97
|
+
}
|
98
|
+
}
|
99
|
+
|
100
|
+
impl From<RbArrayStr> for tk::InputSequence<'_> {
|
101
|
+
fn from(s: RbArrayStr) -> Self {
|
102
|
+
s.0.into()
|
103
|
+
}
|
104
|
+
}
|
105
|
+
|
106
|
+
struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
107
|
+
|
108
|
+
impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
|
109
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
110
|
+
if let Ok(seq) = ob.try_convert::<RbArrayStr>() {
|
111
|
+
return Ok(Self(seq.into()));
|
112
|
+
}
|
113
|
+
todo!()
|
114
|
+
}
|
115
|
+
}
|
116
|
+
|
117
|
+
impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {
|
118
|
+
fn from(s: PreTokenizedInputSequence<'s>) -> Self {
|
119
|
+
s.0
|
120
|
+
}
|
121
|
+
}
|
122
|
+
|
123
|
+
struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
|
124
|
+
|
125
|
+
impl<'s> TryConvert for TextEncodeInput<'s> {
|
126
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
127
|
+
if let Ok(i) = ob.try_convert::<TextInputSequence>() {
|
128
|
+
return Ok(Self(i.into()));
|
129
|
+
}
|
130
|
+
if let Ok((i1, i2)) = ob.try_convert::<(TextInputSequence, TextInputSequence)>() {
|
131
|
+
return Ok(Self((i1, i2).into()));
|
132
|
+
}
|
133
|
+
// TODO check if this branch is needed
|
134
|
+
if let Ok(arr) = ob.try_convert::<RArray>() {
|
135
|
+
if arr.len() == 2 {
|
136
|
+
let first = arr.entry::<TextInputSequence>(0).unwrap();
|
137
|
+
let second = arr.entry::<TextInputSequence>(1).unwrap();
|
138
|
+
return Ok(Self((first, second).into()));
|
139
|
+
}
|
140
|
+
}
|
141
|
+
Err(Error::new(
|
142
|
+
exception::type_error(),
|
143
|
+
"TextEncodeInput must be a string or pair of strings",
|
144
|
+
))
|
145
|
+
}
|
146
|
+
}
|
147
|
+
|
148
|
+
impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
|
149
|
+
fn from(i: TextEncodeInput<'s>) -> Self {
|
150
|
+
i.0
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
|
155
|
+
|
156
|
+
impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
|
157
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
158
|
+
if let Ok(i) = ob.try_convert::<PreTokenizedInputSequence>() {
|
159
|
+
return Ok(Self(i.into()));
|
160
|
+
}
|
161
|
+
if let Ok((i1, i2)) =
|
162
|
+
ob.try_convert::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
|
163
|
+
{
|
164
|
+
return Ok(Self((i1, i2).into()));
|
165
|
+
}
|
166
|
+
// TODO check if this branch is needed
|
167
|
+
if let Ok(arr) = ob.try_convert::<RArray>() {
|
168
|
+
if arr.len() == 2 {
|
169
|
+
let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
|
170
|
+
let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
|
171
|
+
return Ok(Self((first, second).into()));
|
172
|
+
}
|
173
|
+
}
|
174
|
+
Err(Error::new(
|
175
|
+
exception::type_error(),
|
176
|
+
"PreTokenizedEncodeInput must be an array of strings or pair of arrays",
|
177
|
+
))
|
178
|
+
}
|
179
|
+
}
|
180
|
+
|
181
|
+
impl<'s> From<PreTokenizedEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
|
182
|
+
fn from(i: PreTokenizedEncodeInput<'s>) -> Self {
|
183
|
+
i.0
|
184
|
+
}
|
185
|
+
}
|
186
|
+
|
187
|
+
type Tokenizer = TokenizerImpl<RbModel, RbNormalizer, RbPreTokenizer, RbPostProcessor, RbDecoder>;
|
188
|
+
|
13
189
|
#[magnus::wrap(class = "Tokenizers::Tokenizer")]
|
14
190
|
pub struct RbTokenizer {
|
15
191
|
tokenizer: RefCell<Tokenizer>,
|
16
192
|
}
|
17
193
|
|
18
194
|
impl RbTokenizer {
|
19
|
-
pub fn new(
|
195
|
+
pub fn new(tokenizer: Tokenizer) -> Self {
|
20
196
|
Self {
|
21
|
-
tokenizer: RefCell::new(
|
197
|
+
tokenizer: RefCell::new(tokenizer),
|
22
198
|
}
|
23
199
|
}
|
24
200
|
|
201
|
+
pub fn from_model(model: &RbModel) -> Self {
|
202
|
+
RbTokenizer::new(TokenizerImpl::new(model.clone()))
|
203
|
+
}
|
204
|
+
|
25
205
|
pub fn from_file(path: PathBuf) -> RbResult<Self> {
|
26
206
|
Tokenizer::from_file(path)
|
27
207
|
.map(|v| RbTokenizer {
|
@@ -30,42 +210,290 @@ impl RbTokenizer {
|
|
30
210
|
.map_err(RbError::from)
|
31
211
|
}
|
32
212
|
|
33
|
-
pub fn
|
213
|
+
pub fn to_str(&self, pretty: bool) -> RbResult<String> {
|
214
|
+
self.tokenizer.borrow().to_string(pretty).map_err(RbError::from)
|
215
|
+
}
|
216
|
+
|
217
|
+
pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
|
34
218
|
let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
|
35
|
-
self.tokenizer.borrow_mut().add_special_tokens(&tokens)
|
36
|
-
// TODO return self
|
219
|
+
self.tokenizer.borrow_mut().add_special_tokens(&tokens)
|
37
220
|
}
|
38
221
|
|
39
|
-
pub fn
|
222
|
+
pub fn train(&self, files: Vec<String>, trainer: Option<&RbTrainer>) -> RbResult<()> {
|
223
|
+
let mut trainer = trainer.map_or_else(
|
224
|
+
|| self.tokenizer.borrow().get_model().get_trainer(),
|
225
|
+
|t| t.clone(),
|
226
|
+
);
|
227
|
+
self.tokenizer
|
228
|
+
.borrow_mut()
|
229
|
+
.train_from_files(&mut trainer, files)
|
230
|
+
.map(|_| {})
|
231
|
+
.map_err(RbError::from)
|
232
|
+
}
|
233
|
+
|
234
|
+
pub fn save(&self, path: String, pretty: bool) -> RbResult<()> {
|
40
235
|
self.tokenizer
|
41
236
|
.borrow()
|
42
|
-
.
|
237
|
+
.save(&path, pretty)
|
238
|
+
.map_err(RbError::from)
|
239
|
+
}
|
240
|
+
|
241
|
+
pub fn add_tokens(&self, tokens: Vec<String>) -> usize {
|
242
|
+
let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
|
243
|
+
self.tokenizer.borrow_mut().add_tokens(&tokens)
|
244
|
+
}
|
245
|
+
|
246
|
+
pub fn encode(
|
247
|
+
&self,
|
248
|
+
sequence: Value,
|
249
|
+
pair: Option<Value>,
|
250
|
+
is_pretokenized: bool,
|
251
|
+
add_special_tokens: bool,
|
252
|
+
) -> RbResult<RbEncoding> {
|
253
|
+
let sequence: tk::InputSequence = if is_pretokenized {
|
254
|
+
sequence.try_convert::<PreTokenizedInputSequence>()?.into()
|
255
|
+
} else {
|
256
|
+
sequence.try_convert::<TextInputSequence>()?.into()
|
257
|
+
};
|
258
|
+
let input = match pair {
|
259
|
+
Some(pair) => {
|
260
|
+
let pair: tk::InputSequence = if is_pretokenized {
|
261
|
+
pair.try_convert::<PreTokenizedInputSequence>()?.into()
|
262
|
+
} else {
|
263
|
+
pair.try_convert::<TextInputSequence>()?.into()
|
264
|
+
};
|
265
|
+
tk::EncodeInput::Dual(sequence, pair)
|
266
|
+
}
|
267
|
+
None => tk::EncodeInput::Single(sequence),
|
268
|
+
};
|
269
|
+
|
270
|
+
self.tokenizer
|
271
|
+
.borrow()
|
272
|
+
.encode_char_offsets(input, add_special_tokens)
|
43
273
|
.map(|v| RbEncoding { encoding: v })
|
44
274
|
.map_err(RbError::from)
|
45
275
|
}
|
46
276
|
|
47
|
-
pub fn
|
277
|
+
pub fn encode_batch(
|
278
|
+
&self,
|
279
|
+
input: RArray,
|
280
|
+
is_pretokenized: bool,
|
281
|
+
add_special_tokens: bool,
|
282
|
+
) -> RbResult<RArray> {
|
283
|
+
let input: Vec<tk::EncodeInput> = input
|
284
|
+
.each()
|
285
|
+
.map(|o| {
|
286
|
+
let input: tk::EncodeInput = if is_pretokenized {
|
287
|
+
o?.try_convert::<PreTokenizedEncodeInput>()?.into()
|
288
|
+
} else {
|
289
|
+
o?.try_convert::<TextEncodeInput>()?.into()
|
290
|
+
};
|
291
|
+
Ok(input)
|
292
|
+
})
|
293
|
+
.collect::<RbResult<Vec<tk::EncodeInput>>>()?;
|
294
|
+
self.tokenizer
|
295
|
+
.borrow()
|
296
|
+
.encode_batch_char_offsets(input, add_special_tokens)
|
297
|
+
.map(|encodings| {
|
298
|
+
encodings
|
299
|
+
.into_iter()
|
300
|
+
.map(Into::<RbEncoding>::into)
|
301
|
+
.collect()
|
302
|
+
})
|
303
|
+
.map_err(RbError::from)
|
304
|
+
}
|
305
|
+
|
306
|
+
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
|
307
|
+
self.tokenizer
|
308
|
+
.borrow()
|
309
|
+
.decode(ids, skip_special_tokens)
|
310
|
+
.map_err(RbError::from)
|
311
|
+
}
|
312
|
+
|
313
|
+
pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
|
48
314
|
self.tokenizer
|
49
315
|
.borrow()
|
50
|
-
.
|
316
|
+
.decode_batch(sequences, skip_special_tokens)
|
51
317
|
.map_err(RbError::from)
|
52
318
|
}
|
53
319
|
|
54
|
-
pub fn set_decoder(&self, decoder: &
|
320
|
+
/// Replaces the tokenizer's decoder with a clone of `decoder`.
pub fn set_decoder(&self, decoder: &RbDecoder) {
    let replacement = decoder.clone();
    let mut tokenizer = self.tokenizer.borrow_mut();
    tokenizer.with_decoder(replacement);
}
|
323
|
+
|
324
|
+
/// Replaces the tokenizer's pre-tokenizer with a clone of `pretok`.
pub fn set_pre_tokenizer(&self, pretok: &RbPreTokenizer) {
    let replacement = pretok.clone();
    let mut tokenizer = self.tokenizer.borrow_mut();
    tokenizer.with_pre_tokenizer(replacement);
}
|
59
329
|
|
60
|
-
pub fn
|
330
|
+
/// Replaces the tokenizer's post-processor with a clone of `processor`.
pub fn set_post_processor(&self, processor: &RbPostProcessor) {
    let replacement = processor.clone();
    let mut tokenizer = self.tokenizer.borrow_mut();
    tokenizer.with_post_processor(replacement);
}
|
65
335
|
|
66
|
-
pub fn set_normalizer(&self, normalizer: &
|
336
|
+
pub fn set_normalizer(&self, normalizer: &RbNormalizer) {
|
67
337
|
self.tokenizer
|
68
338
|
.borrow_mut()
|
69
|
-
.with_normalizer(normalizer.
|
339
|
+
.with_normalizer(normalizer.clone());
|
340
|
+
}
|
341
|
+
|
342
|
+
pub fn token_to_id(&self, token: String) -> Option<u32> {
|
343
|
+
self.tokenizer.borrow().token_to_id(&token)
|
344
|
+
}
|
345
|
+
|
346
|
+
pub fn id_to_token(&self, id: u32) -> Option<String> {
|
347
|
+
self.tokenizer.borrow().id_to_token(id)
|
348
|
+
}
|
349
|
+
|
350
|
+
// TODO support more kwargs
|
351
|
+
pub fn enable_padding(&self, kwargs: RHash) -> RbResult<()> {
|
352
|
+
let mut params = PaddingParams::default();
|
353
|
+
|
354
|
+
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
355
|
+
if !value.is_nil() {
|
356
|
+
let dir_str: String = value.try_convert()?;
|
357
|
+
params.direction = match dir_str.as_str() {
|
358
|
+
"left" => PaddingDirection::Left,
|
359
|
+
"right" => PaddingDirection::Right,
|
360
|
+
_ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
|
361
|
+
}
|
362
|
+
}
|
363
|
+
|
364
|
+
let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
|
365
|
+
if !value.is_nil() {
|
366
|
+
params.pad_to_multiple_of = value.try_convert()?;
|
367
|
+
}
|
368
|
+
|
369
|
+
let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
|
370
|
+
if !value.is_nil() {
|
371
|
+
params.pad_id = value.try_convert()?;
|
372
|
+
}
|
373
|
+
|
374
|
+
let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
|
375
|
+
if !value.is_nil() {
|
376
|
+
params.pad_type_id = value.try_convert()?;
|
377
|
+
}
|
378
|
+
|
379
|
+
let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
|
380
|
+
if !value.is_nil() {
|
381
|
+
params.pad_token = value.try_convert()?;
|
382
|
+
}
|
383
|
+
|
384
|
+
let value: Value = kwargs.delete(Symbol::new("length"))?;
|
385
|
+
if value.is_nil() {
|
386
|
+
params.strategy = PaddingStrategy::BatchLongest;
|
387
|
+
} else {
|
388
|
+
params.strategy = PaddingStrategy::Fixed(value.try_convert()?);
|
389
|
+
}
|
390
|
+
|
391
|
+
if !kwargs.is_empty() {
|
392
|
+
// TODO improve message
|
393
|
+
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
394
|
+
}
|
395
|
+
|
396
|
+
self.tokenizer.borrow_mut().with_padding(Some(params));
|
397
|
+
|
398
|
+
Ok(())
|
399
|
+
}
|
400
|
+
|
401
|
+
pub fn no_padding(&self) {
|
402
|
+
self.tokenizer.borrow_mut().with_padding(None);
|
403
|
+
}
|
404
|
+
|
405
|
+
pub fn padding(&self) -> RbResult<Option<RHash>> {
|
406
|
+
self.tokenizer.borrow().get_padding().map_or(Ok(None), |params| {
|
407
|
+
let ret_hash = RHash::new();
|
408
|
+
|
409
|
+
ret_hash.aset(
|
410
|
+
"length",
|
411
|
+
match params.strategy {
|
412
|
+
tk::PaddingStrategy::BatchLongest => None,
|
413
|
+
tk::PaddingStrategy::Fixed(size) => Some(size),
|
414
|
+
},
|
415
|
+
)?;
|
416
|
+
ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
|
417
|
+
ret_hash.aset("pad_id", params.pad_id)?;
|
418
|
+
ret_hash.aset("pad_token", &*params.pad_token)?;
|
419
|
+
ret_hash.aset("pad_type_id", params.pad_type_id)?;
|
420
|
+
ret_hash.aset("direction", params.direction.as_ref())?;
|
421
|
+
|
422
|
+
Ok(Some(ret_hash))
|
423
|
+
})
|
424
|
+
}
|
425
|
+
|
426
|
+
pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
|
427
|
+
let mut params = TruncationParams {
|
428
|
+
max_length,
|
429
|
+
..Default::default()
|
430
|
+
};
|
431
|
+
|
432
|
+
let value: Value = kwargs.delete(Symbol::new("stride"))?;
|
433
|
+
if !value.is_nil() {
|
434
|
+
params.stride = value.try_convert()?;
|
435
|
+
}
|
436
|
+
|
437
|
+
let value: Value = kwargs.delete(Symbol::new("strategy"))?;
|
438
|
+
if !value.is_nil() {
|
439
|
+
let strategy_str: String = value.try_convert()?;
|
440
|
+
params.strategy = match strategy_str.as_str() {
|
441
|
+
"longest_first" => TruncationStrategy::LongestFirst,
|
442
|
+
"only_first" => TruncationStrategy::OnlyFirst,
|
443
|
+
"only_second" => TruncationStrategy::OnlySecond,
|
444
|
+
_ => return Err(Error::new(exception::arg_error(), "The strategy value must be 'longest_first', 'only_first', or 'only_second'")),
|
445
|
+
}
|
446
|
+
}
|
447
|
+
|
448
|
+
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
449
|
+
if !value.is_nil() {
|
450
|
+
let dir_str: String = value.try_convert()?;
|
451
|
+
params.direction = match dir_str.as_str() {
|
452
|
+
"left" => TruncationDirection::Left,
|
453
|
+
"right" => TruncationDirection::Right,
|
454
|
+
_ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
|
455
|
+
}
|
456
|
+
}
|
457
|
+
|
458
|
+
if !kwargs.is_empty() {
|
459
|
+
// TODO improve message
|
460
|
+
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
461
|
+
}
|
462
|
+
|
463
|
+
self.tokenizer.borrow_mut().with_truncation(Some(params));
|
464
|
+
|
465
|
+
Ok(())
|
466
|
+
}
|
467
|
+
|
468
|
+
pub fn no_truncation(&self) {
|
469
|
+
self.tokenizer.borrow_mut().with_truncation(None);
|
470
|
+
}
|
471
|
+
|
472
|
+
pub fn truncation(&self) -> RbResult<Option<RHash>> {
|
473
|
+
self.tokenizer.borrow().get_truncation().map_or(Ok(None), |params| {
|
474
|
+
let ret_hash = RHash::new();
|
475
|
+
|
476
|
+
ret_hash.aset("max_length", params.max_length)?;
|
477
|
+
ret_hash.aset("stride", params.stride)?;
|
478
|
+
ret_hash.aset("strategy", params.strategy.as_ref())?;
|
479
|
+
ret_hash.aset("direction", params.direction.as_ref())?;
|
480
|
+
|
481
|
+
Ok(Some(ret_hash))
|
482
|
+
})
|
483
|
+
}
|
484
|
+
|
485
|
+
pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
|
486
|
+
self.tokenizer
|
487
|
+
.borrow()
|
488
|
+
.get_post_processor()
|
489
|
+
.map_or(0, |p| p.added_tokens(is_pair))
|
490
|
+
}
|
491
|
+
|
492
|
+
pub fn vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
|
493
|
+
self.tokenizer.borrow().get_vocab(with_added_tokens)
|
494
|
+
}
|
495
|
+
|
496
|
+
pub fn vocab_size(&self, with_added_tokens: bool) -> usize {
|
497
|
+
self.tokenizer.borrow().get_vocab_size(with_added_tokens)
|
70
498
|
}
|
71
499
|
}
|