tokenizers 0.5.4 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/Cargo.lock +202 -126
- data/ext/tokenizers/Cargo.toml +4 -3
- data/ext/tokenizers/src/encoding.rs +10 -8
- data/ext/tokenizers/src/models.rs +37 -24
- data/ext/tokenizers/src/normalizers.rs +1 -2
- data/ext/tokenizers/src/pre_tokenizers.rs +5 -5
- data/ext/tokenizers/src/tokenizer.rs +65 -53
- data/ext/tokenizers/src/trainers.rs +60 -50
- data/ext/tokenizers/src/utils/normalization.rs +3 -2
- data/ext/tokenizers/src/utils/regex.rs +5 -4
- data/lib/tokenizers/from_pretrained.rb +2 -2
- data/lib/tokenizers/trainers/unigram_trainer.rb +10 -9
- data/lib/tokenizers/trainers/word_piece_trainer.rb +10 -9
- data/lib/tokenizers/version.rb +1 -1
- metadata +4 -4
data/ext/tokenizers/Cargo.toml
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
[package]
|
2
2
|
name = "tokenizers"
|
3
|
-
version = "0.
|
3
|
+
version = "0.6.0"
|
4
4
|
license = "Apache-2.0"
|
5
5
|
authors = ["Andrew Kane <andrew@ankane.org>"]
|
6
6
|
edition = "2021"
|
@@ -11,11 +11,12 @@ publish = false
|
|
11
11
|
crate-type = ["cdylib"]
|
12
12
|
|
13
13
|
[dependencies]
|
14
|
-
|
14
|
+
ahash = { version = "0.8.11", features = ["serde"] }
|
15
|
+
magnus = "0.8"
|
15
16
|
onig = { version = "6", default-features = false }
|
16
17
|
serde = { version = "1", features = ["rc", "derive"] }
|
17
18
|
|
18
19
|
[dependencies.tokenizers]
|
19
|
-
version = "=0.
|
20
|
+
version = "=0.22.0" # also update in from_pretrained.rb
|
20
21
|
default-features = false
|
21
22
|
features = ["progressbar", "onig", "esaxx_fast"]
|
@@ -1,4 +1,4 @@
|
|
1
|
-
use magnus::RArray;
|
1
|
+
use magnus::{RArray, Ruby};
|
2
2
|
use tk::{Encoding, Offsets};
|
3
3
|
|
4
4
|
#[magnus::wrap(class = "Tokenizers::Encoding")]
|
@@ -50,13 +50,15 @@ impl RbEncoding {
|
|
50
50
|
self.encoding.get_attention_mask().to_vec()
|
51
51
|
}
|
52
52
|
|
53
|
-
pub fn overflowing(&
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
53
|
+
pub fn overflowing(ruby: &Ruby, rb_self: &Self) -> RArray {
|
54
|
+
ruby.ary_from_iter(
|
55
|
+
rb_self
|
56
|
+
.encoding
|
57
|
+
.get_overflowing()
|
58
|
+
.clone()
|
59
|
+
.into_iter()
|
60
|
+
.map(Into::<RbEncoding>::into),
|
61
|
+
)
|
60
62
|
}
|
61
63
|
|
62
64
|
pub fn word_to_tokens(&self, word_index: u32, sequence_index: usize) -> Option<(usize, usize)> {
|
@@ -3,14 +3,14 @@ use std::path::{Path, PathBuf};
|
|
3
3
|
use std::sync::{Arc, RwLock};
|
4
4
|
|
5
5
|
use crate::trainers::RbTrainer;
|
6
|
+
use ahash::AHashMap;
|
6
7
|
use magnus::prelude::*;
|
7
8
|
use magnus::{
|
8
|
-
data_type_builder,
|
9
|
-
|
10
|
-
TypedData, Value,
|
9
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
|
10
|
+
Module, Object, RClass, RHash, RModule, Ruby, TryConvert, TypedData, Value,
|
11
11
|
};
|
12
12
|
use serde::{Deserialize, Serialize};
|
13
|
-
use tk::models::bpe::{BpeBuilder, Merges,
|
13
|
+
use tk::models::bpe::{BpeBuilder, Merges, BPE};
|
14
14
|
use tk::models::unigram::Unigram;
|
15
15
|
use tk::models::wordlevel::WordLevel;
|
16
16
|
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
|
@@ -72,52 +72,59 @@ pub struct RbBPE {}
|
|
72
72
|
|
73
73
|
impl RbBPE {
|
74
74
|
fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
75
|
-
let
|
75
|
+
let ruby = Ruby::get().unwrap();
|
76
|
+
|
77
|
+
let value: Value = kwargs.delete(ruby.to_symbol("cache_capacity"))?;
|
76
78
|
if !value.is_nil() {
|
77
79
|
builder = builder.cache_capacity(TryConvert::try_convert(value)?);
|
78
80
|
}
|
79
81
|
|
80
|
-
let value: Value = kwargs.delete(
|
82
|
+
let value: Value = kwargs.delete(ruby.to_symbol("dropout"))?;
|
81
83
|
if !value.is_nil() {
|
82
84
|
builder = builder.dropout(TryConvert::try_convert(value)?);
|
83
85
|
}
|
84
86
|
|
85
|
-
let value: Value = kwargs.delete(
|
87
|
+
let value: Value = kwargs.delete(ruby.to_symbol("unk_token"))?;
|
86
88
|
if !value.is_nil() {
|
87
89
|
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
88
90
|
}
|
89
91
|
|
90
|
-
let value: Value = kwargs.delete(
|
92
|
+
let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
|
91
93
|
if !value.is_nil() {
|
92
94
|
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
93
95
|
}
|
94
96
|
|
95
|
-
let value: Value = kwargs.delete(
|
97
|
+
let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
|
96
98
|
if !value.is_nil() {
|
97
99
|
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
98
100
|
}
|
99
101
|
|
100
|
-
let value: Value = kwargs.delete(
|
102
|
+
let value: Value = kwargs.delete(ruby.to_symbol("fuse_unk"))?;
|
101
103
|
if !value.is_nil() {
|
102
104
|
builder = builder.fuse_unk(TryConvert::try_convert(value)?);
|
103
105
|
}
|
104
106
|
|
105
|
-
let value: Value = kwargs.delete(
|
107
|
+
let value: Value = kwargs.delete(ruby.to_symbol("byte_fallback"))?;
|
106
108
|
if !value.is_nil() {
|
107
109
|
builder = builder.byte_fallback(TryConvert::try_convert(value)?);
|
108
110
|
}
|
109
111
|
|
110
112
|
if !kwargs.is_empty() {
|
111
113
|
// TODO improve message
|
112
|
-
return Err(Error::new(
|
114
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
113
115
|
}
|
114
116
|
|
115
117
|
builder.build().map(|v| v.into()).map_err(RbError::from)
|
116
118
|
}
|
117
119
|
|
118
|
-
pub fn new(
|
120
|
+
pub fn new(
|
121
|
+
vocab: Option<HashMap<String, u32>>,
|
122
|
+
merges: Option<Merges>,
|
123
|
+
kwargs: RHash,
|
124
|
+
) -> RbResult<RbModel> {
|
119
125
|
let mut builder = BPE::builder();
|
120
126
|
if let (Some(vocab), Some(merges)) = (vocab, merges) {
|
127
|
+
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
|
121
128
|
builder = builder.vocab_and_merges(vocab, merges);
|
122
129
|
}
|
123
130
|
RbBPE::with_builder(builder, kwargs)
|
@@ -125,7 +132,7 @@ impl RbBPE {
|
|
125
132
|
|
126
133
|
pub fn from_file(vocab: String, merges: String, kwargs: RHash) -> RbResult<RbModel> {
|
127
134
|
let (vocab, merges) = BPE::read_file(&vocab, &merges).map_err(RbError::from)?;
|
128
|
-
|
135
|
+
let vocab = vocab.into_iter().collect();
|
129
136
|
RbBPE::new(Some(vocab), Some(merges), kwargs)
|
130
137
|
}
|
131
138
|
}
|
@@ -251,6 +258,7 @@ pub struct RbUnigram {}
|
|
251
258
|
|
252
259
|
impl RbUnigram {
|
253
260
|
fn new(
|
261
|
+
ruby: &Ruby,
|
254
262
|
vocab: Option<Vec<(String, f64)>>,
|
255
263
|
unk_id: Option<usize>,
|
256
264
|
byte_fallback: Option<bool>,
|
@@ -263,7 +271,7 @@ impl RbUnigram {
|
|
263
271
|
}
|
264
272
|
(None, None, _) => Ok(Unigram::default().into()),
|
265
273
|
_ => Err(Error::new(
|
266
|
-
|
274
|
+
ruby.exception_arg_error(),
|
267
275
|
"`vocab` and `unk_id` must be both specified",
|
268
276
|
)),
|
269
277
|
}
|
@@ -279,7 +287,7 @@ impl RbWordLevel {
|
|
279
287
|
) -> RbResult<RbModel> {
|
280
288
|
let mut builder = WordLevel::builder();
|
281
289
|
if let Some(vocab) = vocab {
|
282
|
-
builder = builder.vocab(vocab);
|
290
|
+
builder = builder.vocab(vocab.into_iter().collect());
|
283
291
|
}
|
284
292
|
if let Some(unk_token) = unk_token {
|
285
293
|
builder = builder.unk_token(unk_token);
|
@@ -287,13 +295,15 @@ impl RbWordLevel {
|
|
287
295
|
builder.build().map(|v| v.into()).map_err(RbError::from)
|
288
296
|
}
|
289
297
|
|
290
|
-
pub fn read_file(vocab: String) -> RbResult<
|
291
|
-
WordLevel::read_file(&vocab).map_err(RbError::from)
|
298
|
+
pub fn read_file(vocab: String) -> RbResult<HashMap<String, u32>> {
|
299
|
+
let vocab = WordLevel::read_file(&vocab).map_err(RbError::from)?;
|
300
|
+
let vocab: HashMap<_, _> = vocab.into_iter().collect();
|
301
|
+
Ok(vocab)
|
292
302
|
}
|
293
303
|
|
294
304
|
pub fn from_file(vocab: String, unk_token: Option<String>) -> RbResult<RbModel> {
|
295
305
|
let vocab = WordLevel::read_file(&vocab).map_err(RbError::from)?;
|
296
|
-
|
306
|
+
let vocab = vocab.into_iter().collect();
|
297
307
|
RbWordLevel::new(Some(vocab), unk_token)
|
298
308
|
}
|
299
309
|
}
|
@@ -302,24 +312,26 @@ pub struct RbWordPiece {}
|
|
302
312
|
|
303
313
|
impl RbWordPiece {
|
304
314
|
fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
305
|
-
let
|
315
|
+
let ruby = Ruby::get().unwrap();
|
316
|
+
|
317
|
+
let value: Value = kwargs.delete(ruby.to_symbol("unk_token"))?;
|
306
318
|
if !value.is_nil() {
|
307
319
|
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
308
320
|
}
|
309
321
|
|
310
|
-
let value: Value = kwargs.delete(
|
322
|
+
let value: Value = kwargs.delete(ruby.to_symbol("max_input_chars_per_word"))?;
|
311
323
|
if !value.is_nil() {
|
312
324
|
builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
|
313
325
|
}
|
314
326
|
|
315
|
-
let value: Value = kwargs.delete(
|
327
|
+
let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
|
316
328
|
if !value.is_nil() {
|
317
329
|
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
318
330
|
}
|
319
331
|
|
320
332
|
if !kwargs.is_empty() {
|
321
333
|
// TODO improve message
|
322
|
-
return Err(Error::new(
|
334
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
323
335
|
}
|
324
336
|
|
325
337
|
builder.build().map(|v| v.into()).map_err(RbError::from)
|
@@ -328,6 +340,7 @@ impl RbWordPiece {
|
|
328
340
|
pub fn new(vocab: Option<HashMap<String, u32>>, kwargs: RHash) -> RbResult<RbModel> {
|
329
341
|
let mut builder = WordPiece::builder();
|
330
342
|
if let Some(vocab) = vocab {
|
343
|
+
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
|
331
344
|
builder = builder.vocab(vocab);
|
332
345
|
}
|
333
346
|
RbWordPiece::with_builder(builder, kwargs)
|
@@ -336,7 +349,7 @@ impl RbWordPiece {
|
|
336
349
|
pub fn from_file(vocab: String, kwargs: RHash) -> RbResult<RbModel> {
|
337
350
|
let vocab = WordPiece::read_file(&vocab).map_err(RbError::from)?;
|
338
351
|
|
339
|
-
RbWordPiece::new(Some(vocab), kwargs)
|
352
|
+
RbWordPiece::new(Some(vocab.into_iter().collect()), kwargs)
|
340
353
|
}
|
341
354
|
}
|
342
355
|
|
@@ -199,8 +199,7 @@ impl RbPrecompiled {
|
|
199
199
|
Precompiled::from(&precompiled_charsmap)
|
200
200
|
.map_err(|e| {
|
201
201
|
RbError::new_err(format!(
|
202
|
-
"Error while attempting to build Precompiled normalizer: {}"
|
203
|
-
e
|
202
|
+
"Error while attempting to build Precompiled normalizer: {e}"
|
204
203
|
))
|
205
204
|
})
|
206
205
|
.map(|v| v.into())
|
@@ -1,8 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
3
|
use magnus::{
|
4
|
-
data_type_builder,
|
5
|
-
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
|
5
|
+
Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
|
6
6
|
};
|
7
7
|
|
8
8
|
use serde::ser::SerializeStruct;
|
@@ -278,16 +278,16 @@ impl RbSequence {
|
|
278
278
|
}
|
279
279
|
|
280
280
|
pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> {
|
281
|
+
let ruby = Ruby::get().unwrap();
|
281
282
|
let scheme = match string.as_str() {
|
282
283
|
"first" => PrependScheme::First,
|
283
284
|
"never" => PrependScheme::Never,
|
284
285
|
"always" => PrependScheme::Always,
|
285
286
|
_ => {
|
286
287
|
return Err(Error::new(
|
287
|
-
|
288
|
+
ruby.exception_arg_error(),
|
288
289
|
format!(
|
289
|
-
"{} is an unknown variant, should be one of ['first', 'never', 'always']"
|
290
|
-
string
|
290
|
+
"{string} is an unknown variant, should be one of ['first', 'never', 'always']"
|
291
291
|
),
|
292
292
|
));
|
293
293
|
}
|
@@ -4,7 +4,7 @@ use std::path::PathBuf;
|
|
4
4
|
use std::str::FromStr;
|
5
5
|
|
6
6
|
use magnus::prelude::*;
|
7
|
-
use magnus::{
|
7
|
+
use magnus::{Error, RArray, RHash, RString, Ruby, TryConvert, Value};
|
8
8
|
use tk::tokenizer::{
|
9
9
|
Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
|
10
10
|
TruncationParams, TruncationStrategy,
|
@@ -78,37 +78,37 @@ impl From<tk::AddedToken> for RbAddedToken {
|
|
78
78
|
}
|
79
79
|
|
80
80
|
impl RbAddedToken {
|
81
|
-
pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
|
81
|
+
pub fn new(ruby: &Ruby, content: Option<String>, kwargs: RHash) -> RbResult<Self> {
|
82
82
|
let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
|
83
83
|
|
84
|
-
let value: Value = kwargs.delete(
|
84
|
+
let value: Value = kwargs.delete(ruby.to_symbol("single_word"))?;
|
85
85
|
if !value.is_nil() {
|
86
86
|
token.single_word = TryConvert::try_convert(value)?;
|
87
87
|
}
|
88
88
|
|
89
|
-
let value: Value = kwargs.delete(
|
89
|
+
let value: Value = kwargs.delete(ruby.to_symbol("lstrip"))?;
|
90
90
|
if !value.is_nil() {
|
91
91
|
token.lstrip = TryConvert::try_convert(value)?;
|
92
92
|
}
|
93
93
|
|
94
|
-
let value: Value = kwargs.delete(
|
94
|
+
let value: Value = kwargs.delete(ruby.to_symbol("rstrip"))?;
|
95
95
|
if !value.is_nil() {
|
96
96
|
token.rstrip = TryConvert::try_convert(value)?;
|
97
97
|
}
|
98
98
|
|
99
|
-
let value: Value = kwargs.delete(
|
99
|
+
let value: Value = kwargs.delete(ruby.to_symbol("normalized"))?;
|
100
100
|
if !value.is_nil() {
|
101
101
|
token.normalized = TryConvert::try_convert(value)?;
|
102
102
|
}
|
103
103
|
|
104
|
-
let value: Value = kwargs.delete(
|
104
|
+
let value: Value = kwargs.delete(ruby.to_symbol("special"))?;
|
105
105
|
if !value.is_nil() {
|
106
106
|
token.special = TryConvert::try_convert(value)?;
|
107
107
|
}
|
108
108
|
|
109
109
|
if !kwargs.is_empty() {
|
110
110
|
// TODO improve message
|
111
|
-
return Err(Error::new(
|
111
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
112
112
|
}
|
113
113
|
|
114
114
|
Ok(token)
|
@@ -141,7 +141,7 @@ impl RbAddedToken {
|
|
141
141
|
|
142
142
|
struct TextInputSequence<'s>(tk::InputSequence<'s>);
|
143
143
|
|
144
|
-
impl
|
144
|
+
impl TryConvert for TextInputSequence<'_> {
|
145
145
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
146
146
|
Ok(Self(String::try_convert(ob)?.into()))
|
147
147
|
}
|
@@ -170,7 +170,7 @@ impl From<RbArrayStr> for tk::InputSequence<'_> {
|
|
170
170
|
|
171
171
|
struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
172
172
|
|
173
|
-
impl
|
173
|
+
impl TryConvert for PreTokenizedInputSequence<'_> {
|
174
174
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
175
175
|
if let Ok(seq) = RbArrayStr::try_convert(ob) {
|
176
176
|
return Ok(Self(seq.into()));
|
@@ -187,8 +187,10 @@ impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {
|
|
187
187
|
|
188
188
|
struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
|
189
189
|
|
190
|
-
impl
|
190
|
+
impl TryConvert for TextEncodeInput<'_> {
|
191
191
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
192
|
+
let ruby = Ruby::get_with(ob);
|
193
|
+
|
192
194
|
if let Ok(i) = TextInputSequence::try_convert(ob) {
|
193
195
|
return Ok(Self(i.into()));
|
194
196
|
}
|
@@ -204,7 +206,7 @@ impl<'s> TryConvert for TextEncodeInput<'s> {
|
|
204
206
|
}
|
205
207
|
}
|
206
208
|
Err(Error::new(
|
207
|
-
|
209
|
+
ruby.exception_type_error(),
|
208
210
|
"TextEncodeInput must be a string or pair of strings",
|
209
211
|
))
|
210
212
|
}
|
@@ -218,8 +220,10 @@ impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
|
|
218
220
|
|
219
221
|
struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
|
220
222
|
|
221
|
-
impl
|
223
|
+
impl TryConvert for PreTokenizedEncodeInput<'_> {
|
222
224
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
225
|
+
let ruby = Ruby::get_with(ob);
|
226
|
+
|
223
227
|
if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
|
224
228
|
return Ok(Self(i.into()));
|
225
229
|
}
|
@@ -237,7 +241,7 @@ impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
|
|
237
241
|
}
|
238
242
|
}
|
239
243
|
Err(Error::new(
|
240
|
-
|
244
|
+
ruby.exception_type_error(),
|
241
245
|
"PreTokenizedEncodeInput must be an array of strings or pair of arrays",
|
242
246
|
))
|
243
247
|
}
|
@@ -351,7 +355,8 @@ impl RbTokenizer {
|
|
351
355
|
}
|
352
356
|
|
353
357
|
pub fn encode_batch(
|
354
|
-
&
|
358
|
+
ruby: &Ruby,
|
359
|
+
rb_self: &Self,
|
355
360
|
input: RArray,
|
356
361
|
is_pretokenized: bool,
|
357
362
|
add_special_tokens: bool,
|
@@ -367,16 +372,16 @@ impl RbTokenizer {
|
|
367
372
|
Ok(input)
|
368
373
|
})
|
369
374
|
.collect::<RbResult<Vec<tk::EncodeInput>>>()?;
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
.map(Into::<RbEncoding>::into)
|
377
|
-
|
378
|
-
|
379
|
-
|
375
|
+
Ok(ruby.ary_from_iter(
|
376
|
+
rb_self
|
377
|
+
.tokenizer
|
378
|
+
.borrow()
|
379
|
+
.encode_batch_char_offsets(input, add_special_tokens)
|
380
|
+
.map(|encodings| {
|
381
|
+
ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into))
|
382
|
+
})
|
383
|
+
.map_err(RbError::from),
|
384
|
+
))
|
380
385
|
}
|
381
386
|
|
382
387
|
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
|
@@ -453,10 +458,10 @@ impl RbTokenizer {
|
|
453
458
|
}
|
454
459
|
|
455
460
|
// TODO support more kwargs
|
456
|
-
pub fn enable_padding(&
|
461
|
+
pub fn enable_padding(ruby: &Ruby, rb_self: &Self, kwargs: RHash) -> RbResult<()> {
|
457
462
|
let mut params = PaddingParams::default();
|
458
463
|
|
459
|
-
let value: Value = kwargs.delete(
|
464
|
+
let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
|
460
465
|
if !value.is_nil() {
|
461
466
|
let dir_str = String::try_convert(value)?;
|
462
467
|
params.direction = match dir_str.as_str() {
|
@@ -464,34 +469,34 @@ impl RbTokenizer {
|
|
464
469
|
"right" => PaddingDirection::Right,
|
465
470
|
_ => {
|
466
471
|
return Err(Error::new(
|
467
|
-
|
472
|
+
ruby.exception_arg_error(),
|
468
473
|
"The direction value must be 'left' or 'right'",
|
469
474
|
))
|
470
475
|
}
|
471
476
|
}
|
472
477
|
}
|
473
478
|
|
474
|
-
let value: Value = kwargs.delete(
|
479
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_to_multiple_of"))?;
|
475
480
|
if !value.is_nil() {
|
476
481
|
params.pad_to_multiple_of = TryConvert::try_convert(value)?;
|
477
482
|
}
|
478
483
|
|
479
|
-
let value: Value = kwargs.delete(
|
484
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_id"))?;
|
480
485
|
if !value.is_nil() {
|
481
486
|
params.pad_id = TryConvert::try_convert(value)?;
|
482
487
|
}
|
483
488
|
|
484
|
-
let value: Value = kwargs.delete(
|
489
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_type_id"))?;
|
485
490
|
if !value.is_nil() {
|
486
491
|
params.pad_type_id = TryConvert::try_convert(value)?;
|
487
492
|
}
|
488
493
|
|
489
|
-
let value: Value = kwargs.delete(
|
494
|
+
let value: Value = kwargs.delete(ruby.to_symbol("pad_token"))?;
|
490
495
|
if !value.is_nil() {
|
491
496
|
params.pad_token = TryConvert::try_convert(value)?;
|
492
497
|
}
|
493
498
|
|
494
|
-
let value: Value = kwargs.delete(
|
499
|
+
let value: Value = kwargs.delete(ruby.to_symbol("length"))?;
|
495
500
|
if value.is_nil() {
|
496
501
|
params.strategy = PaddingStrategy::BatchLongest;
|
497
502
|
} else {
|
@@ -500,10 +505,10 @@ impl RbTokenizer {
|
|
500
505
|
|
501
506
|
if !kwargs.is_empty() {
|
502
507
|
// TODO improve message
|
503
|
-
return Err(Error::new(
|
508
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
504
509
|
}
|
505
510
|
|
506
|
-
|
511
|
+
rb_self.tokenizer.borrow_mut().with_padding(Some(params));
|
507
512
|
|
508
513
|
Ok(())
|
509
514
|
}
|
@@ -512,12 +517,13 @@ impl RbTokenizer {
|
|
512
517
|
self.tokenizer.borrow_mut().with_padding(None);
|
513
518
|
}
|
514
519
|
|
515
|
-
pub fn padding(&
|
516
|
-
|
520
|
+
pub fn padding(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
|
521
|
+
rb_self
|
522
|
+
.tokenizer
|
517
523
|
.borrow()
|
518
524
|
.get_padding()
|
519
525
|
.map_or(Ok(None), |params| {
|
520
|
-
let ret_hash =
|
526
|
+
let ret_hash = ruby.hash_new();
|
521
527
|
|
522
528
|
ret_hash.aset(
|
523
529
|
"length",
|
@@ -536,18 +542,23 @@ impl RbTokenizer {
|
|
536
542
|
})
|
537
543
|
}
|
538
544
|
|
539
|
-
pub fn enable_truncation(
|
545
|
+
pub fn enable_truncation(
|
546
|
+
ruby: &Ruby,
|
547
|
+
rb_self: &Self,
|
548
|
+
max_length: usize,
|
549
|
+
kwargs: RHash,
|
550
|
+
) -> RbResult<()> {
|
540
551
|
let mut params = TruncationParams {
|
541
552
|
max_length,
|
542
553
|
..Default::default()
|
543
554
|
};
|
544
555
|
|
545
|
-
let value: Value = kwargs.delete(
|
556
|
+
let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
|
546
557
|
if !value.is_nil() {
|
547
558
|
params.stride = TryConvert::try_convert(value)?;
|
548
559
|
}
|
549
560
|
|
550
|
-
let value: Value = kwargs.delete(
|
561
|
+
let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
|
551
562
|
if !value.is_nil() {
|
552
563
|
let strategy_str = String::try_convert(value)?;
|
553
564
|
params.strategy = match strategy_str.as_str() {
|
@@ -555,13 +566,13 @@ impl RbTokenizer {
|
|
555
566
|
"only_first" => TruncationStrategy::OnlyFirst,
|
556
567
|
"only_second" => TruncationStrategy::OnlySecond,
|
557
568
|
_ => return Err(Error::new(
|
558
|
-
|
569
|
+
ruby.exception_arg_error(),
|
559
570
|
"The strategy value must be 'longest_first', 'only_first', or 'only_second'",
|
560
571
|
)),
|
561
572
|
}
|
562
573
|
}
|
563
574
|
|
564
|
-
let value: Value = kwargs.delete(
|
575
|
+
let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
|
565
576
|
if !value.is_nil() {
|
566
577
|
let dir_str = String::try_convert(value)?;
|
567
578
|
params.direction = match dir_str.as_str() {
|
@@ -569,7 +580,7 @@ impl RbTokenizer {
|
|
569
580
|
"right" => TruncationDirection::Right,
|
570
581
|
_ => {
|
571
582
|
return Err(Error::new(
|
572
|
-
|
583
|
+
ruby.exception_arg_error(),
|
573
584
|
"The direction value must be 'left' or 'right'",
|
574
585
|
))
|
575
586
|
}
|
@@ -578,12 +589,12 @@ impl RbTokenizer {
|
|
578
589
|
|
579
590
|
if !kwargs.is_empty() {
|
580
591
|
// TODO improve message
|
581
|
-
return Err(Error::new(
|
592
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
582
593
|
}
|
583
594
|
|
584
|
-
if let Err(error_message) =
|
595
|
+
if let Err(error_message) = rb_self.tokenizer.borrow_mut().with_truncation(Some(params)) {
|
585
596
|
return Err(Error::new(
|
586
|
-
|
597
|
+
ruby.exception_arg_error(),
|
587
598
|
error_message.to_string(),
|
588
599
|
));
|
589
600
|
}
|
@@ -598,12 +609,13 @@ impl RbTokenizer {
|
|
598
609
|
.expect("Failed to set truncation to `None`! This should never happen");
|
599
610
|
}
|
600
611
|
|
601
|
-
pub fn truncation(&
|
602
|
-
|
612
|
+
pub fn truncation(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
|
613
|
+
rb_self
|
614
|
+
.tokenizer
|
603
615
|
.borrow()
|
604
616
|
.get_truncation()
|
605
617
|
.map_or(Ok(None), |params| {
|
606
|
-
let ret_hash =
|
618
|
+
let ret_hash = ruby.hash_new();
|
607
619
|
|
608
620
|
ret_hash.aset("max_length", params.max_length)?;
|
609
621
|
ret_hash.aset("stride", params.stride)?;
|
@@ -629,10 +641,10 @@ impl RbTokenizer {
|
|
629
641
|
self.tokenizer.borrow().get_vocab_size(with_added_tokens)
|
630
642
|
}
|
631
643
|
|
632
|
-
pub fn get_added_tokens_decoder(&
|
633
|
-
let sorted_map =
|
644
|
+
pub fn get_added_tokens_decoder(ruby: &Ruby, rb_self: &Self) -> RbResult<RHash> {
|
645
|
+
let sorted_map = ruby.hash_new();
|
634
646
|
|
635
|
-
for (key, value) in
|
647
|
+
for (key, value) in rb_self.tokenizer.borrow().get_added_tokens_decoder() {
|
636
648
|
sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
|
637
649
|
}
|
638
650
|
|