tokenizers 0.2.3 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/Cargo.lock +33 -74
- data/README.md +4 -0
- data/ext/tokenizers/Cargo.toml +4 -2
- data/ext/tokenizers/src/decoders.rs +275 -6
- data/ext/tokenizers/src/encoding.rs +3 -2
- data/ext/tokenizers/src/error.rs +2 -2
- data/ext/tokenizers/src/lib.rs +64 -17
- data/ext/tokenizers/src/models.rs +372 -11
- data/ext/tokenizers/src/normalizers.rs +435 -7
- data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
- data/ext/tokenizers/src/processors.rs +210 -0
- data/ext/tokenizers/src/tokenizer.rs +437 -23
- data/ext/tokenizers/src/trainers.rs +749 -0
- data/ext/tokenizers/src/utils/mod.rs +5 -0
- data/ext/tokenizers/src/utils/normalization.rs +85 -0
- data/ext/tokenizers/src/utils/regex.rs +22 -0
- data/lib/tokenizers/char_bpe_tokenizer.rb +9 -6
- data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
- data/lib/tokenizers/decoders/ctc.rb +9 -0
- data/lib/tokenizers/decoders/metaspace.rb +9 -0
- data/lib/tokenizers/decoders/word_piece.rb +9 -0
- data/lib/tokenizers/from_pretrained.rb +2 -2
- data/lib/tokenizers/models/bpe.rb +9 -0
- data/lib/tokenizers/models/unigram.rb +9 -0
- data/lib/tokenizers/models/word_level.rb +13 -0
- data/lib/tokenizers/models/word_piece.rb +9 -0
- data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
- data/lib/tokenizers/normalizers/strip.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
- data/lib/tokenizers/processors/byte_level.rb +9 -0
- data/lib/tokenizers/processors/roberta_processing.rb +9 -0
- data/lib/tokenizers/processors/template_processing.rb +9 -0
- data/lib/tokenizers/tokenizer.rb +40 -7
- data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
- data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
- data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
- data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +42 -2
- metadata +30 -3
@@ -1,27 +1,207 @@
|
|
1
1
|
use std::cell::RefCell;
|
2
|
+
use std::collections::HashMap;
|
2
3
|
use std::path::PathBuf;
|
3
|
-
|
4
|
+
|
5
|
+
use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
|
6
|
+
use tk::tokenizer::{
|
7
|
+
Model, PaddingDirection, PaddingParams, PaddingStrategy,
|
8
|
+
TruncationDirection, TruncationParams, TruncationStrategy, TokenizerImpl
|
9
|
+
};
|
4
10
|
use tk::AddedToken;
|
5
11
|
|
6
|
-
use
|
12
|
+
use crate::tk::PostProcessor;
|
13
|
+
|
14
|
+
use super::decoders::RbDecoder;
|
7
15
|
use super::encoding::RbEncoding;
|
8
|
-
use super::models::
|
9
|
-
use super::normalizers::
|
10
|
-
use super::pre_tokenizers::
|
16
|
+
use super::models::RbModel;
|
17
|
+
use super::normalizers::RbNormalizer;
|
18
|
+
use super::pre_tokenizers::RbPreTokenizer;
|
19
|
+
use super::processors::RbPostProcessor;
|
20
|
+
use super::trainers::RbTrainer;
|
11
21
|
use super::{RbError, RbResult};
|
12
22
|
|
23
|
+
/// Ruby-side description of a token to add to a tokenizer.
///
/// The optional flags stay `None` unless explicitly set; only `Some(_)` flags
/// are forwarded to the underlying `tk::AddedToken` when it is built.
pub struct RbAddedToken {
    /// Literal text of the token.
    pub content: String,
    /// Whether the token is created as a special token.
    pub is_special_token: bool,
    /// Explicit `single_word` setting, if any.
    pub single_word: Option<bool>,
    /// Explicit `lstrip` setting, if any.
    pub lstrip: Option<bool>,
    /// Explicit `rstrip` setting, if any.
    pub rstrip: Option<bool>,
    /// Explicit `normalized` setting, if any.
    pub normalized: Option<bool>,
}
|
31
|
+
|
32
|
+
impl RbAddedToken {
|
33
|
+
pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
|
34
|
+
Self {
|
35
|
+
content: content.into(),
|
36
|
+
is_special_token: is_special_token.unwrap_or(false),
|
37
|
+
single_word: None,
|
38
|
+
lstrip: None,
|
39
|
+
rstrip: None,
|
40
|
+
normalized: None,
|
41
|
+
}
|
42
|
+
}
|
43
|
+
|
44
|
+
pub fn get_token(&self) -> tk::tokenizer::AddedToken {
|
45
|
+
let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
|
46
|
+
|
47
|
+
if let Some(sw) = self.single_word {
|
48
|
+
token = token.single_word(sw);
|
49
|
+
}
|
50
|
+
if let Some(ls) = self.lstrip {
|
51
|
+
token = token.lstrip(ls);
|
52
|
+
}
|
53
|
+
if let Some(rs) = self.rstrip {
|
54
|
+
token = token.rstrip(rs);
|
55
|
+
}
|
56
|
+
if let Some(n) = self.normalized {
|
57
|
+
token = token.normalized(n);
|
58
|
+
}
|
59
|
+
|
60
|
+
token
|
61
|
+
}
|
62
|
+
}
|
63
|
+
|
64
|
+
impl From<tk::AddedToken> for RbAddedToken {
    /// Wraps a library token, recording every flag as explicitly set.
    ///
    /// The special-token bit is derived as the inverse of `normalized`.
    /// NOTE(review): this presumably mirrors `AddedToken::from(content, special)`
    /// setting `normalized = !special` — confirm against the pinned `tk` version.
    fn from(token: tk::AddedToken) -> Self {
        let is_special_token = !token.normalized;
        let normalized = Some(token.normalized);
        let single_word = Some(token.single_word);
        let lstrip = Some(token.lstrip);
        let rstrip = Some(token.rstrip);
        Self {
            content: token.content,
            is_special_token,
            single_word,
            lstrip,
            rstrip,
            normalized,
        }
    }
}
|
76
|
+
|
77
|
+
struct TextInputSequence<'s>(tk::InputSequence<'s>);
|
78
|
+
|
79
|
+
impl<'s> TryConvert for TextInputSequence<'s> {
|
80
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
81
|
+
Ok(Self(ob.try_convert::<String>()?.into()))
|
82
|
+
}
|
83
|
+
}
|
84
|
+
|
85
|
+
impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
|
86
|
+
fn from(s: TextInputSequence<'s>) -> Self {
|
87
|
+
s.0
|
88
|
+
}
|
89
|
+
}
|
90
|
+
|
91
|
+
struct RbArrayStr(Vec<String>);
|
92
|
+
|
93
|
+
impl TryConvert for RbArrayStr {
|
94
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
95
|
+
let seq = ob.try_convert::<Vec<String>>()?;
|
96
|
+
Ok(Self(seq))
|
97
|
+
}
|
98
|
+
}
|
99
|
+
|
100
|
+
impl From<RbArrayStr> for tk::InputSequence<'_> {
|
101
|
+
fn from(s: RbArrayStr) -> Self {
|
102
|
+
s.0.into()
|
103
|
+
}
|
104
|
+
}
|
105
|
+
|
106
|
+
struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
107
|
+
|
108
|
+
impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
|
109
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
110
|
+
if let Ok(seq) = ob.try_convert::<RbArrayStr>() {
|
111
|
+
return Ok(Self(seq.into()));
|
112
|
+
}
|
113
|
+
todo!()
|
114
|
+
}
|
115
|
+
}
|
116
|
+
|
117
|
+
impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {
|
118
|
+
fn from(s: PreTokenizedInputSequence<'s>) -> Self {
|
119
|
+
s.0
|
120
|
+
}
|
121
|
+
}
|
122
|
+
|
123
|
+
struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
|
124
|
+
|
125
|
+
impl<'s> TryConvert for TextEncodeInput<'s> {
|
126
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
127
|
+
if let Ok(i) = ob.try_convert::<TextInputSequence>() {
|
128
|
+
return Ok(Self(i.into()));
|
129
|
+
}
|
130
|
+
if let Ok((i1, i2)) = ob.try_convert::<(TextInputSequence, TextInputSequence)>() {
|
131
|
+
return Ok(Self((i1, i2).into()));
|
132
|
+
}
|
133
|
+
// TODO check if this branch is needed
|
134
|
+
if let Ok(arr) = ob.try_convert::<RArray>() {
|
135
|
+
if arr.len() == 2 {
|
136
|
+
let first = arr.entry::<TextInputSequence>(0).unwrap();
|
137
|
+
let second = arr.entry::<TextInputSequence>(1).unwrap();
|
138
|
+
return Ok(Self((first, second).into()));
|
139
|
+
}
|
140
|
+
}
|
141
|
+
Err(Error::new(
|
142
|
+
exception::type_error(),
|
143
|
+
"TextEncodeInput must be a string or pair of strings",
|
144
|
+
))
|
145
|
+
}
|
146
|
+
}
|
147
|
+
|
148
|
+
impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
|
149
|
+
fn from(i: TextEncodeInput<'s>) -> Self {
|
150
|
+
i.0
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
|
155
|
+
|
156
|
+
impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
|
157
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
158
|
+
if let Ok(i) = ob.try_convert::<PreTokenizedInputSequence>() {
|
159
|
+
return Ok(Self(i.into()));
|
160
|
+
}
|
161
|
+
if let Ok((i1, i2)) =
|
162
|
+
ob.try_convert::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
|
163
|
+
{
|
164
|
+
return Ok(Self((i1, i2).into()));
|
165
|
+
}
|
166
|
+
// TODO check if this branch is needed
|
167
|
+
if let Ok(arr) = ob.try_convert::<RArray>() {
|
168
|
+
if arr.len() == 2 {
|
169
|
+
let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
|
170
|
+
let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
|
171
|
+
return Ok(Self((first, second).into()));
|
172
|
+
}
|
173
|
+
}
|
174
|
+
Err(Error::new(
|
175
|
+
exception::type_error(),
|
176
|
+
"PreTokenizedEncodeInput must be an array of strings or pair of arrays",
|
177
|
+
))
|
178
|
+
}
|
179
|
+
}
|
180
|
+
|
181
|
+
impl<'s> From<PreTokenizedEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
|
182
|
+
fn from(i: PreTokenizedEncodeInput<'s>) -> Self {
|
183
|
+
i.0
|
184
|
+
}
|
185
|
+
}
|
186
|
+
|
187
|
+
type Tokenizer = TokenizerImpl<RbModel, RbNormalizer, RbPreTokenizer, RbPostProcessor, RbDecoder>;
|
188
|
+
|
13
189
|
#[magnus::wrap(class = "Tokenizers::Tokenizer")]
|
14
190
|
pub struct RbTokenizer {
|
15
191
|
tokenizer: RefCell<Tokenizer>,
|
16
192
|
}
|
17
193
|
|
18
194
|
impl RbTokenizer {
|
19
|
-
pub fn new(
|
195
|
+
pub fn new(tokenizer: Tokenizer) -> Self {
|
20
196
|
Self {
|
21
|
-
tokenizer: RefCell::new(
|
197
|
+
tokenizer: RefCell::new(tokenizer),
|
22
198
|
}
|
23
199
|
}
|
24
200
|
|
201
|
+
pub fn from_model(model: &RbModel) -> Self {
|
202
|
+
RbTokenizer::new(TokenizerImpl::new(model.clone()))
|
203
|
+
}
|
204
|
+
|
25
205
|
pub fn from_file(path: PathBuf) -> RbResult<Self> {
|
26
206
|
Tokenizer::from_file(path)
|
27
207
|
.map(|v| RbTokenizer {
|
@@ -30,49 +210,133 @@ impl RbTokenizer {
|
|
30
210
|
.map_err(RbError::from)
|
31
211
|
}
|
32
212
|
|
33
|
-
pub fn
|
213
|
+
pub fn to_str(&self, pretty: bool) -> RbResult<String> {
|
214
|
+
self.tokenizer.borrow().to_string(pretty).map_err(RbError::from)
|
215
|
+
}
|
216
|
+
|
217
|
+
pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
|
34
218
|
let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
|
35
|
-
self.tokenizer.borrow_mut().add_special_tokens(&tokens)
|
36
|
-
|
219
|
+
self.tokenizer.borrow_mut().add_special_tokens(&tokens)
|
220
|
+
}
|
221
|
+
|
222
|
+
pub fn train(&self, files: Vec<String>, trainer: Option<&RbTrainer>) -> RbResult<()> {
|
223
|
+
let mut trainer = trainer.map_or_else(
|
224
|
+
|| self.tokenizer.borrow().get_model().get_trainer(),
|
225
|
+
|t| t.clone(),
|
226
|
+
);
|
227
|
+
self.tokenizer
|
228
|
+
.borrow_mut()
|
229
|
+
.train_from_files(&mut trainer, files)
|
230
|
+
.map(|_| {})
|
231
|
+
.map_err(RbError::from)
|
37
232
|
}
|
38
233
|
|
39
|
-
pub fn
|
234
|
+
pub fn save(&self, path: String, pretty: bool) -> RbResult<()> {
|
235
|
+
self.tokenizer
|
236
|
+
.borrow()
|
237
|
+
.save(&path, pretty)
|
238
|
+
.map_err(RbError::from)
|
239
|
+
}
|
240
|
+
|
241
|
+
pub fn add_tokens(&self, tokens: Vec<String>) -> usize {
|
40
242
|
let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
|
41
|
-
self.tokenizer.borrow_mut().add_tokens(&tokens)
|
42
|
-
// TODO return self
|
243
|
+
self.tokenizer.borrow_mut().add_tokens(&tokens)
|
43
244
|
}
|
44
245
|
|
45
|
-
pub fn encode(
|
246
|
+
pub fn encode(
|
247
|
+
&self,
|
248
|
+
sequence: Value,
|
249
|
+
pair: Option<Value>,
|
250
|
+
is_pretokenized: bool,
|
251
|
+
add_special_tokens: bool,
|
252
|
+
) -> RbResult<RbEncoding> {
|
253
|
+
let sequence: tk::InputSequence = if is_pretokenized {
|
254
|
+
sequence.try_convert::<PreTokenizedInputSequence>()?.into()
|
255
|
+
} else {
|
256
|
+
sequence.try_convert::<TextInputSequence>()?.into()
|
257
|
+
};
|
258
|
+
let input = match pair {
|
259
|
+
Some(pair) => {
|
260
|
+
let pair: tk::InputSequence = if is_pretokenized {
|
261
|
+
pair.try_convert::<PreTokenizedInputSequence>()?.into()
|
262
|
+
} else {
|
263
|
+
pair.try_convert::<TextInputSequence>()?.into()
|
264
|
+
};
|
265
|
+
tk::EncodeInput::Dual(sequence, pair)
|
266
|
+
}
|
267
|
+
None => tk::EncodeInput::Single(sequence),
|
268
|
+
};
|
269
|
+
|
46
270
|
self.tokenizer
|
47
271
|
.borrow()
|
48
|
-
.
|
272
|
+
.encode_char_offsets(input, add_special_tokens)
|
49
273
|
.map(|v| RbEncoding { encoding: v })
|
50
274
|
.map_err(RbError::from)
|
51
275
|
}
|
52
276
|
|
53
|
-
pub fn
|
277
|
+
pub fn encode_batch(
|
278
|
+
&self,
|
279
|
+
input: RArray,
|
280
|
+
is_pretokenized: bool,
|
281
|
+
add_special_tokens: bool,
|
282
|
+
) -> RbResult<RArray> {
|
283
|
+
let input: Vec<tk::EncodeInput> = input
|
284
|
+
.each()
|
285
|
+
.map(|o| {
|
286
|
+
let input: tk::EncodeInput = if is_pretokenized {
|
287
|
+
o?.try_convert::<PreTokenizedEncodeInput>()?.into()
|
288
|
+
} else {
|
289
|
+
o?.try_convert::<TextEncodeInput>()?.into()
|
290
|
+
};
|
291
|
+
Ok(input)
|
292
|
+
})
|
293
|
+
.collect::<RbResult<Vec<tk::EncodeInput>>>()?;
|
294
|
+
self.tokenizer
|
295
|
+
.borrow()
|
296
|
+
.encode_batch_char_offsets(input, add_special_tokens)
|
297
|
+
.map(|encodings| {
|
298
|
+
encodings
|
299
|
+
.into_iter()
|
300
|
+
.map(Into::<RbEncoding>::into)
|
301
|
+
.collect()
|
302
|
+
})
|
303
|
+
.map_err(RbError::from)
|
304
|
+
}
|
305
|
+
|
306
|
+
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
|
307
|
+
self.tokenizer
|
308
|
+
.borrow()
|
309
|
+
.decode(ids, skip_special_tokens)
|
310
|
+
.map_err(RbError::from)
|
311
|
+
}
|
312
|
+
|
313
|
+
pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
|
54
314
|
self.tokenizer
|
55
315
|
.borrow()
|
56
|
-
.
|
316
|
+
.decode_batch(sequences, skip_special_tokens)
|
57
317
|
.map_err(RbError::from)
|
58
318
|
}
|
59
319
|
|
60
|
-
pub fn set_decoder(&self, decoder: &
|
320
|
+
/// Replaces the tokenizer's decoder with a clone of `decoder`.
pub fn set_decoder(&self, decoder: &RbDecoder) {
    let replacement = decoder.clone();
    let mut tokenizer = self.tokenizer.borrow_mut();
    tokenizer.with_decoder(replacement);
}
|
323
|
+
|
324
|
+
/// Replaces the tokenizer's pre-tokenizer with a clone of `pretok`.
pub fn set_pre_tokenizer(&self, pretok: &RbPreTokenizer) {
    let replacement = pretok.clone();
    let mut tokenizer = self.tokenizer.borrow_mut();
    tokenizer.with_pre_tokenizer(replacement);
}
|
65
329
|
|
66
|
-
pub fn
|
330
|
+
/// Replaces the tokenizer's post-processor with a clone of `processor`.
pub fn set_post_processor(&self, processor: &RbPostProcessor) {
    let replacement = processor.clone();
    let mut tokenizer = self.tokenizer.borrow_mut();
    tokenizer.with_post_processor(replacement);
}
|
71
335
|
|
72
|
-
pub fn set_normalizer(&self, normalizer: &
|
336
|
+
pub fn set_normalizer(&self, normalizer: &RbNormalizer) {
|
73
337
|
self.tokenizer
|
74
338
|
.borrow_mut()
|
75
|
-
.with_normalizer(normalizer.
|
339
|
+
.with_normalizer(normalizer.clone());
|
76
340
|
}
|
77
341
|
|
78
342
|
pub fn token_to_id(&self, token: String) -> Option<u32> {
|
@@ -82,4 +346,154 @@ impl RbTokenizer {
|
|
82
346
|
pub fn id_to_token(&self, id: u32) -> Option<String> {
|
83
347
|
self.tokenizer.borrow().id_to_token(id)
|
84
348
|
}
|
349
|
+
|
350
|
+
// TODO support more kwargs
|
351
|
+
pub fn enable_padding(&self, kwargs: RHash) -> RbResult<()> {
|
352
|
+
let mut params = PaddingParams::default();
|
353
|
+
|
354
|
+
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
355
|
+
if !value.is_nil() {
|
356
|
+
let dir_str: String = value.try_convert()?;
|
357
|
+
params.direction = match dir_str.as_str() {
|
358
|
+
"left" => PaddingDirection::Left,
|
359
|
+
"right" => PaddingDirection::Right,
|
360
|
+
_ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
|
361
|
+
}
|
362
|
+
}
|
363
|
+
|
364
|
+
let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
|
365
|
+
if !value.is_nil() {
|
366
|
+
params.pad_to_multiple_of = value.try_convert()?;
|
367
|
+
}
|
368
|
+
|
369
|
+
let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
|
370
|
+
if !value.is_nil() {
|
371
|
+
params.pad_id = value.try_convert()?;
|
372
|
+
}
|
373
|
+
|
374
|
+
let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
|
375
|
+
if !value.is_nil() {
|
376
|
+
params.pad_type_id = value.try_convert()?;
|
377
|
+
}
|
378
|
+
|
379
|
+
let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
|
380
|
+
if !value.is_nil() {
|
381
|
+
params.pad_token = value.try_convert()?;
|
382
|
+
}
|
383
|
+
|
384
|
+
let value: Value = kwargs.delete(Symbol::new("length"))?;
|
385
|
+
if value.is_nil() {
|
386
|
+
params.strategy = PaddingStrategy::BatchLongest;
|
387
|
+
} else {
|
388
|
+
params.strategy = PaddingStrategy::Fixed(value.try_convert()?);
|
389
|
+
}
|
390
|
+
|
391
|
+
if !kwargs.is_empty() {
|
392
|
+
// TODO improve message
|
393
|
+
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
394
|
+
}
|
395
|
+
|
396
|
+
self.tokenizer.borrow_mut().with_padding(Some(params));
|
397
|
+
|
398
|
+
Ok(())
|
399
|
+
}
|
400
|
+
|
401
|
+
pub fn no_padding(&self) {
|
402
|
+
self.tokenizer.borrow_mut().with_padding(None);
|
403
|
+
}
|
404
|
+
|
405
|
+
pub fn padding(&self) -> RbResult<Option<RHash>> {
|
406
|
+
self.tokenizer.borrow().get_padding().map_or(Ok(None), |params| {
|
407
|
+
let ret_hash = RHash::new();
|
408
|
+
|
409
|
+
ret_hash.aset(
|
410
|
+
"length",
|
411
|
+
match params.strategy {
|
412
|
+
tk::PaddingStrategy::BatchLongest => None,
|
413
|
+
tk::PaddingStrategy::Fixed(size) => Some(size),
|
414
|
+
},
|
415
|
+
)?;
|
416
|
+
ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
|
417
|
+
ret_hash.aset("pad_id", params.pad_id)?;
|
418
|
+
ret_hash.aset("pad_token", &*params.pad_token)?;
|
419
|
+
ret_hash.aset("pad_type_id", params.pad_type_id)?;
|
420
|
+
ret_hash.aset("direction", params.direction.as_ref())?;
|
421
|
+
|
422
|
+
Ok(Some(ret_hash))
|
423
|
+
})
|
424
|
+
}
|
425
|
+
|
426
|
+
pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
|
427
|
+
let mut params = TruncationParams {
|
428
|
+
max_length,
|
429
|
+
..Default::default()
|
430
|
+
};
|
431
|
+
|
432
|
+
let value: Value = kwargs.delete(Symbol::new("stride"))?;
|
433
|
+
if !value.is_nil() {
|
434
|
+
params.stride = value.try_convert()?;
|
435
|
+
}
|
436
|
+
|
437
|
+
let value: Value = kwargs.delete(Symbol::new("strategy"))?;
|
438
|
+
if !value.is_nil() {
|
439
|
+
let strategy_str: String = value.try_convert()?;
|
440
|
+
params.strategy = match strategy_str.as_str() {
|
441
|
+
"longest_first" => TruncationStrategy::LongestFirst,
|
442
|
+
"only_first" => TruncationStrategy::OnlyFirst,
|
443
|
+
"only_second" => TruncationStrategy::OnlySecond,
|
444
|
+
_ => return Err(Error::new(exception::arg_error(), "The strategy value must be 'longest_first', 'only_first', or 'only_second'")),
|
445
|
+
}
|
446
|
+
}
|
447
|
+
|
448
|
+
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
449
|
+
if !value.is_nil() {
|
450
|
+
let dir_str: String = value.try_convert()?;
|
451
|
+
params.direction = match dir_str.as_str() {
|
452
|
+
"left" => TruncationDirection::Left,
|
453
|
+
"right" => TruncationDirection::Right,
|
454
|
+
_ => return Err(Error::new(exception::arg_error(), "The direction value must be 'left' or 'right'")),
|
455
|
+
}
|
456
|
+
}
|
457
|
+
|
458
|
+
if !kwargs.is_empty() {
|
459
|
+
// TODO improve message
|
460
|
+
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
461
|
+
}
|
462
|
+
|
463
|
+
self.tokenizer.borrow_mut().with_truncation(Some(params));
|
464
|
+
|
465
|
+
Ok(())
|
466
|
+
}
|
467
|
+
|
468
|
+
pub fn no_truncation(&self) {
|
469
|
+
self.tokenizer.borrow_mut().with_truncation(None);
|
470
|
+
}
|
471
|
+
|
472
|
+
pub fn truncation(&self) -> RbResult<Option<RHash>> {
|
473
|
+
self.tokenizer.borrow().get_truncation().map_or(Ok(None), |params| {
|
474
|
+
let ret_hash = RHash::new();
|
475
|
+
|
476
|
+
ret_hash.aset("max_length", params.max_length)?;
|
477
|
+
ret_hash.aset("stride", params.stride)?;
|
478
|
+
ret_hash.aset("strategy", params.strategy.as_ref())?;
|
479
|
+
ret_hash.aset("direction", params.direction.as_ref())?;
|
480
|
+
|
481
|
+
Ok(Some(ret_hash))
|
482
|
+
})
|
483
|
+
}
|
484
|
+
|
485
|
+
pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
|
486
|
+
self.tokenizer
|
487
|
+
.borrow()
|
488
|
+
.get_post_processor()
|
489
|
+
.map_or(0, |p| p.added_tokens(is_pair))
|
490
|
+
}
|
491
|
+
|
492
|
+
pub fn vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
|
493
|
+
self.tokenizer.borrow().get_vocab(with_added_tokens)
|
494
|
+
}
|
495
|
+
|
496
|
+
pub fn vocab_size(&self, with_added_tokens: bool) -> usize {
|
497
|
+
self.tokenizer.borrow().get_vocab_size(with_added_tokens)
|
498
|
+
}
|
85
499
|
}
|