tokenizers 0.6.4 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/Cargo.lock +21 -22
- data/ext/tokenizers/Cargo.toml +3 -2
- data/ext/tokenizers/src/decoders.rs +31 -28
- data/ext/tokenizers/src/encoding.rs +42 -11
- data/ext/tokenizers/src/error.rs +10 -5
- data/ext/tokenizers/src/lib.rs +4 -91
- data/ext/tokenizers/src/models.rs +21 -21
- data/ext/tokenizers/src/normalizers.rs +15 -15
- data/ext/tokenizers/src/pre_tokenizers.rs +15 -15
- data/ext/tokenizers/src/processors.rs +145 -15
- data/ext/tokenizers/src/ruby.rs +51 -0
- data/ext/tokenizers/src/tokenizer.rs +381 -244
- data/ext/tokenizers/src/trainers.rs +55 -49
- data/ext/tokenizers/src/utils/normalization.rs +2 -1
- data/ext/tokenizers/src/utils/regex.rs +2 -2
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/processors/sequence.rb +9 -0
- data/lib/tokenizers/tokenizer.rb +4 -0
- data/lib/tokenizers/version.rb +1 -1
- metadata +4 -2
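The bulk of the release is the rewrite of data/ext/tokenizers/src/tokenizer.rs shown below: the wrapped tokenizer moves from a single-threaded RefCell<Tokenizer> to Arc<RwLock<Tokenizer>>, so the batch encode/decode paths can release Ruby's GVL while the Rust side works. A minimal sketch of that ownership pattern (illustrative types only, not the gem's code):

    use std::sync::{Arc, RwLock};

    // Shared, lockable state of the kind RbTokenizer adopts in 0.7.0.
    // Cloning the handle shares the same underlying value.
    struct Handle<T> {
        inner: Arc<RwLock<T>>,
    }

    impl<T> Handle<T> {
        fn new(value: T) -> Self {
            Self {
                inner: Arc::new(RwLock::new(value)),
            }
        }
    }

    impl<T> Clone for Handle<T> {
        fn clone(&self) -> Self {
            Self {
                inner: Arc::clone(&self.inner),
            }
        }
    }

Clone is written by hand because #[derive(Clone)] would require T: Clone, while cloning the handle only bumps the Arc reference count.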
data/ext/tokenizers/src/tokenizer.rs

@@ -1,10 +1,10 @@
-use std::cell::RefCell;
 use std::collections::HashMap;
 use std::path::PathBuf;
 use std::str::FromStr;
+use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
 
 use magnus::prelude::*;
-use magnus::{Error, RArray, RHash, RString, Ruby, TryConvert, Value};
+use magnus::{function, method, Error, RArray, RHash, RModule, RString, Ruby, TryConvert, Value};
 use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
     TruncationParams, TruncationStrategy,
@@ -19,6 +19,7 @@ use super::models::RbModel;
 use super::normalizers::RbNormalizer;
 use super::pre_tokenizers::RbPreTokenizer;
 use super::processors::RbPostProcessor;
+use super::ruby::GvlExt;
 use super::trainers::RbTrainer;
 use super::{RbError, RbResult};
 
@@ -72,7 +73,7 @@ impl From<tk::AddedToken> for RbAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-            special:
+            special: token.special,
         }
     }
 }
@@ -200,8 +201,8 @@ impl TryConvert for TextEncodeInput<'_> {
         // TODO check if this branch is needed
         if let Ok(arr) = RArray::try_convert(ob) {
             if arr.len() == 2 {
-                let first = arr.entry::<TextInputSequence>(0).unwrap();
-                let second = arr.entry::<TextInputSequence>(1).unwrap();
+                let first = arr.entry::<TextInputSequence>(0)?;
+                let second = arr.entry::<TextInputSequence>(1)?;
                 return Ok(Self((first, second).into()));
             }
         }
@@ -235,8 +236,8 @@ impl TryConvert for PreTokenizedEncodeInput<'_> {
         // TODO check if this branch is needed
         if let Ok(arr) = RArray::try_convert(ob) {
             if arr.len() == 2 {
-                let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
-                let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
+                let first = arr.entry::<PreTokenizedInputSequence>(0)?;
+                let second = arr.entry::<PreTokenizedInputSequence>(1)?;
                 return Ok(Self((first, second).into()));
             }
         }
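Both TryConvert impls swap a panicking accessor for `?` propagation: a failed element conversion now reaches the caller as a Ruby-level error instead of aborting the process. The same shape in isolation (hypothetical helper, magnus API as used above):

    use magnus::{Error, RArray};

    // Convert a 2-element Ruby array, propagating conversion failures
    // as magnus::Error rather than panicking mid-conversion.
    fn string_pair(arr: RArray) -> Result<(String, String), Error> {
        let first = arr.entry::<String>(0)?;
        let second = arr.entry::<String>(1)?;
        Ok((first, second))
    }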
@@ -257,202 +258,170 @@ type Tokenizer = TokenizerImpl<RbModel, RbNormalizer, RbPreTokenizer, RbPostProcessor, RbDecoder>;
 
 #[magnus::wrap(class = "Tokenizers::Tokenizer")]
 pub struct RbTokenizer {
-    tokenizer: RefCell<Tokenizer>,
+    tokenizer: Arc<RwLock<Tokenizer>>,
+}
+
+impl Clone for RbTokenizer {
+    fn clone(&self) -> Self {
+        RbTokenizer {
+            tokenizer: Arc::clone(&self.tokenizer),
+        }
+    }
 }
 
 impl RbTokenizer {
     pub fn new(tokenizer: Tokenizer) -> Self {
         Self {
-            tokenizer: RefCell::new(tokenizer),
+            tokenizer: Arc::new(RwLock::new(tokenizer)),
         }
     }
 
+    /// Acquire the inner tokenizer for reading; surfaces lock poisoning as a
+    /// Ruby exception instead of panicking.
+    pub(crate) fn read_inner(&self) -> RbResult<RwLockReadGuard<'_, Tokenizer>> {
+        self.tokenizer
+            .read()
+            .map_err(|_| RbError::new_err("Tokenizer RwLock is poisoned"))
+    }
+
+    /// Acquire the inner tokenizer for writing; surfaces lock poisoning as a
+    /// Ruby exception instead of panicking.
+    pub(crate) fn write_inner(&self) -> RbResult<RwLockWriteGuard<'_, Tokenizer>> {
+        self.tokenizer
+            .write()
+            .map_err(|_| RbError::new_err("Tokenizer RwLock is poisoned"))
+    }
+
     pub fn from_model(model: &RbModel) -> Self {
         RbTokenizer::new(TokenizerImpl::new(model.clone()))
     }
 
     pub fn from_str(json: RString) -> RbResult<Self> {
-        Tokenizer::from_str(unsafe { json.as_str()? })
-            .map(|v| Self {
-                tokenizer: RefCell::new(v),
-            })
-            .map_err(RbError::from)
+        let tokenizer = Tokenizer::from_str(unsafe { json.as_str()? }).map_err(RbError::from);
+        Ok(Self::new(tokenizer?))
     }
 
     pub fn from_file(path: PathBuf) -> RbResult<Self> {
-        Tokenizer::from_file(path)
-            .map(|v| Self {
-                tokenizer: RefCell::new(v),
-            })
-            .map_err(RbError::from)
+        let tokenizer = Tokenizer::from_file(path).map_err(RbError::from);
+        Ok(Self::new(tokenizer?))
     }
 
     pub fn to_str(&self, pretty: bool) -> RbResult<String> {
-        self.tokenizer
-            .borrow()
-            .to_string(pretty)
-            .map_err(RbError::from)
-    }
-
-    pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
-        let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
-        self.tokenizer.borrow_mut().add_special_tokens(&tokens)
-    }
-
-    pub fn train(&self, files: Vec<String>, trainer: Option<&RbTrainer>) -> RbResult<()> {
-        let mut trainer = trainer.map_or_else(
-            || self.tokenizer.borrow().get_model().get_trainer(),
-            |t| t.clone(),
-        );
-        self.tokenizer
-            .borrow_mut()
-            .train_from_files(&mut trainer, files)
-            .map(|_| {})
-            .map_err(RbError::from)
+        self.read_inner()?.to_string(pretty).map_err(RbError::from)
     }
 
     pub fn save(&self, path: String, pretty: bool) -> RbResult<()> {
-        self.tokenizer
-            .borrow()
+        self.read_inner()?
             .save(&path, pretty)
             .map_err(RbError::from)
     }
 
-    pub fn
-
-
+    pub fn num_special_tokens_to_add(&self, is_pair: bool) -> RbResult<usize> {
+        Ok(self
+            .read_inner()?
+            .get_post_processor()
+            .map_or(0, |p| p.added_tokens(is_pair)))
     }
 
-    pub fn encode(
-        &self,
-        sequence: Value,
-        pair: Option<Value>,
-        is_pretokenized: bool,
-        add_special_tokens: bool,
-    ) -> RbResult<RbEncoding> {
-        let sequence: tk::InputSequence = if is_pretokenized {
-            PreTokenizedInputSequence::try_convert(sequence)?.into()
-        } else {
-            TextInputSequence::try_convert(sequence)?.into()
-        };
-        let input = match pair {
-            Some(pair) => {
-                let pair: tk::InputSequence = if is_pretokenized {
-                    PreTokenizedInputSequence::try_convert(pair)?.into()
-                } else {
-                    TextInputSequence::try_convert(pair)?.into()
-                };
-                tk::EncodeInput::Dual(sequence, pair)
-            }
-            None => tk::EncodeInput::Single(sequence),
-        };
-
-        self.tokenizer
-            .borrow()
-            .encode_char_offsets(input, add_special_tokens)
-            .map(|v| RbEncoding { encoding: v })
-            .map_err(RbError::from)
+    pub fn get_vocab(&self, with_added_tokens: bool) -> RbResult<HashMap<String, u32>> {
+        Ok(self.read_inner()?.get_vocab(with_added_tokens))
     }
 
-    pub fn encode_batch(
-        ruby: &Ruby,
-        rb_self: &Self,
-        input: RArray,
-        is_pretokenized: bool,
-        add_special_tokens: bool,
-    ) -> RbResult<RArray> {
-        let input: Vec<tk::EncodeInput> = input
-            .into_iter()
-            .map(|o| {
-                let input: tk::EncodeInput = if is_pretokenized {
-                    PreTokenizedEncodeInput::try_convert(o)?.into()
-                } else {
-                    TextEncodeInput::try_convert(o)?.into()
-                };
-                Ok(input)
-            })
-            .collect::<RbResult<Vec<tk::EncodeInput>>>()?;
-        rb_self
-            .tokenizer
-            .borrow()
-            .encode_batch_char_offsets(input, add_special_tokens)
-            .map(|encodings| {
-                ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into))
-            })
-            .map_err(RbError::from)
-    }
+    pub fn get_added_tokens_decoder(ruby: &Ruby, rb_self: &Self) -> RbResult<RHash> {
+        let sorted_map = ruby.hash_new();
 
-    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
-        self.tokenizer
-            .borrow()
-            .decode(&ids, skip_special_tokens)
-            .map_err(RbError::from)
-    }
+        for (key, value) in rb_self.read_inner()?.get_added_tokens_decoder() {
+            sorted_map.aset::<u32, RbAddedToken>(key, value.into())?;
+        }
 
-    pub fn decode_batch(
-        &self,
-        sequences: Vec<Vec<u32>>,
-        skip_special_tokens: bool,
-    ) -> RbResult<Vec<String>> {
-        let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
-        self.tokenizer
-            .borrow()
-            .decode_batch(&slices, skip_special_tokens)
-            .map_err(RbError::from)
+        Ok(sorted_map)
     }
 
-    pub fn
-        self.
+    pub fn get_vocab_size(&self, with_added_tokens: bool) -> RbResult<usize> {
+        Ok(self.read_inner()?.get_vocab_size(with_added_tokens))
     }
 
-    pub fn
-
-
+    pub fn enable_truncation(
+        ruby: &Ruby,
+        rb_self: &Self,
+        max_length: usize,
+        kwargs: RHash,
+    ) -> RbResult<()> {
+        let mut params = TruncationParams {
+            max_length,
+            ..Default::default()
+        };
 
-
-
-
+        let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
+        if !value.is_nil() {
+            params.stride = TryConvert::try_convert(value)?;
+        }
 
-
-
-
+        let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
+        if !value.is_nil() {
+            let strategy_str = String::try_convert(value)?;
+            params.strategy = match strategy_str.as_str() {
+                "longest_first" => TruncationStrategy::LongestFirst,
+                "only_first" => TruncationStrategy::OnlyFirst,
+                "only_second" => TruncationStrategy::OnlySecond,
+                _ => return Err(Error::new(
+                    ruby.exception_arg_error(),
+                    "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
+                )),
+            }
+        }
 
-
-
-
+        let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
+        if !value.is_nil() {
+            let dir_str = String::try_convert(value)?;
+            params.direction = match dir_str.as_str() {
+                "left" => TruncationDirection::Left,
+                "right" => TruncationDirection::Right,
+                _ => {
+                    return Err(Error::new(
+                        ruby.exception_arg_error(),
+                        "The direction value must be 'left' or 'right'",
+                    ))
+                }
+            }
+        }
 
-
-
-            .
-
-    }
+        if !kwargs.is_empty() {
+            // TODO improve message
+            return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
+        }
 
-
-
-
+        if let Err(error_message) = rb_self.write_inner()?.with_truncation(Some(params)) {
+            return Err(Error::new(
+                ruby.exception_arg_error(),
+                error_message.to_string(),
+            ));
+        }
 
-
-        self.tokenizer
-            .borrow_mut()
-            .with_post_processor(processor.cloned());
+        Ok(())
     }
 
-    pub fn
-        self.
+    pub fn no_truncation(&self) -> RbResult<()> {
+        self.write_inner()?
+            .with_truncation(None)
+            .expect("Failed to set truncation to `None`! This should never happen");
+        Ok(())
    }
 
-    pub fn
-
-            .
-            .
-
+    pub fn get_truncation(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
+        rb_self
+            .read_inner()?
+            .get_truncation()
+            .map_or(Ok(None), |params| {
+                let ret_hash = ruby.hash_new();
 
-
-
-
+                ret_hash.aset("max_length", params.max_length)?;
+                ret_hash.aset("stride", params.stride)?;
+                ret_hash.aset("strategy", params.strategy.as_ref())?;
+                ret_hash.aset("direction", params.direction.as_ref())?;
 
-
-
+                Ok(Some(ret_hash))
+            })
     }
 
     // TODO support more kwargs
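The new enable_truncation consumes keyword arguments by deleting known keys from the RHash and then rejecting whatever is left, so typos in keyword names fail loudly. A reduced sketch of that pattern (hypothetical function; it uses only magnus calls that appear in this diff):

    use magnus::{Error, RHash, Ruby, TryConvert, Value};

    // Pull an optional :stride keyword, then fail on any unknown keys.
    fn parse_stride(ruby: &Ruby, kwargs: RHash) -> Result<usize, Error> {
        let mut stride = 0;
        let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
        if !value.is_nil() {
            stride = TryConvert::try_convert(value)?;
        }
        if !kwargs.is_empty() {
            return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
        }
        Ok(stride)
    }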
@@ -506,19 +475,19 @@ impl RbTokenizer {
             return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
         }
 
-        rb_self.tokenizer.borrow_mut().with_padding(Some(params));
+        rb_self.write_inner()?.with_padding(Some(params));
 
         Ok(())
     }
 
-    pub fn no_padding(&self) {
-        self.tokenizer.borrow_mut().with_padding(None);
+    pub fn no_padding(&self) -> RbResult<()> {
+        self.write_inner()?.with_padding(None);
+        Ok(())
     }
 
-    pub fn
+    pub fn get_padding(ruby: &Ruby, rb_self: &Self) -> RbResult<Option<RHash>> {
         rb_self
-            .tokenizer
-            .borrow()
+            .read_inner()?
             .get_padding()
             .map_or(Ok(None), |params| {
                 let ret_hash = ruby.hash_new();
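The padding and truncation getters share one shape: return nil when nothing is configured, otherwise build a fresh Hash of the parameters. Sketched with an assumed two-field parameter tuple:

    use magnus::{Error, RHash, Ruby};

    // Turn optional params into Ruby nil or a Hash of named fields.
    fn params_to_hash(ruby: &Ruby, params: Option<(usize, usize)>) -> Result<Option<RHash>, Error> {
        params.map_or(Ok(None), |(max_length, stride)| {
            let ret_hash = ruby.hash_new();
            ret_hash.aset("max_length", max_length)?;
            ret_hash.aset("stride", stride)?;
            Ok(Some(ret_hash))
        })
    }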
@@ -540,112 +509,280 @@ impl RbTokenizer {
             })
     }
 
-    pub fn enable_truncation(
-        ruby: &Ruby,
-        rb_self: &Self,
-        max_length: usize,
-        kwargs: RHash,
-    ) -> RbResult<()> {
-        let mut params = TruncationParams {
-            max_length,
-            ..Default::default()
+    pub fn encode(
+        &self,
+        sequence: Value,
+        pair: Option<Value>,
+        is_pretokenized: bool,
+        add_special_tokens: bool,
+    ) -> RbResult<RbEncoding> {
+        let sequence: tk::InputSequence = if is_pretokenized {
+            PreTokenizedInputSequence::try_convert(sequence)?.into()
+        } else {
+            TextInputSequence::try_convert(sequence)?.into()
+        };
+        let input = match pair {
+            Some(pair) => {
+                let pair: tk::InputSequence = if is_pretokenized {
+                    PreTokenizedInputSequence::try_convert(pair)?.into()
+                } else {
+                    TextInputSequence::try_convert(pair)?.into()
+                };
+                tk::EncodeInput::Dual(sequence, pair)
+            }
+            None => tk::EncodeInput::Single(sequence),
         };
 
-        let value: Value = kwargs.delete(ruby.to_symbol("stride"))?;
-        if !value.is_nil() {
-            params.stride = TryConvert::try_convert(value)?;
-        }
+        self.read_inner()?
+            .encode_char_offsets(input, add_special_tokens)
+            .map(|v| RbEncoding { encoding: v })
+            .map_err(RbError::from)
+    }
 
-        let value: Value = kwargs.delete(ruby.to_symbol("strategy"))?;
-        if !value.is_nil() {
-            let strategy_str = String::try_convert(value)?;
-            params.strategy = match strategy_str.as_str() {
-                "longest_first" => TruncationStrategy::LongestFirst,
-                "only_first" => TruncationStrategy::OnlyFirst,
-                "only_second" => TruncationStrategy::OnlySecond,
-                _ => return Err(Error::new(
-                    ruby.exception_arg_error(),
-                    "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
-                )),
-            }
-        }
+    pub fn encode_batch(
+        ruby: &Ruby,
+        rb_self: &Self,
+        input: RArray,
+        is_pretokenized: bool,
+        add_special_tokens: bool,
+    ) -> RbResult<RArray> {
+        let input: Vec<tk::EncodeInput> = input
+            .into_iter()
+            .map(|o| {
+                let input: tk::EncodeInput = if is_pretokenized {
+                    PreTokenizedEncodeInput::try_convert(o)?.into()
+                } else {
+                    TextEncodeInput::try_convert(o)?.into()
+                };
+                Ok(input)
+            })
+            .collect::<RbResult<Vec<tk::EncodeInput>>>()?;
+        ruby.detach(|| {
+            rb_self
+                .tokenizer
+                .read()
+                .unwrap()
+                .encode_batch_char_offsets(input, add_special_tokens)
+        })
+        .map(|encodings| ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into)))
+        .map_err(RbError::from)
+    }
+
+    pub fn encode_batch_fast(
+        ruby: &Ruby,
+        rb_self: &Self,
+        input: RArray,
+        is_pretokenized: bool,
+        add_special_tokens: bool,
+    ) -> RbResult<RArray> {
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for item in input {
+            let item: tk::EncodeInput = if is_pretokenized {
+                PreTokenizedEncodeInput::try_convert(item)?.into()
+            } else {
+                TextEncodeInput::try_convert(item)?.into()
+            };
+            items.push(item);
+        }
+        ruby.detach(|| {
+            rb_self
+                .tokenizer
+                .read()
+                .unwrap()
+                .encode_batch_fast(items, add_special_tokens)
+        })
+        .map(|encodings| ruby.ary_from_iter(encodings.into_iter().map(Into::<RbEncoding>::into)))
+        .map_err(RbError::from)
+    }
 
-        let value: Value = kwargs.delete(ruby.to_symbol("direction"))?;
-        if !value.is_nil() {
-            let dir_str = String::try_convert(value)?;
-            params.direction = match dir_str.as_str() {
-                "left" => TruncationDirection::Left,
-                "right" => TruncationDirection::Right,
-                _ => {
-                    return Err(Error::new(
-                        ruby.exception_arg_error(),
-                        "The direction value must be 'left' or 'right'",
-                    ))
-                }
-            }
-        }
+    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
+        self.read_inner()?
+            .decode(&ids, skip_special_tokens)
+            .map_err(RbError::from)
+    }
 
-        if !kwargs.is_empty() {
-            // TODO improve message
-            return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
-        }
+    pub fn decode_batch(
+        ruby: &Ruby,
+        rb_self: &Self,
+        sequences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> RbResult<Vec<String>> {
+        ruby.detach(|| {
+            let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
+            rb_self
+                .tokenizer
+                .read()
+                .unwrap()
+                .decode_batch(&slices, skip_special_tokens)
+        })
+        .map_err(RbError::from)
+    }
 
-        if let Err(error_message) = rb_self.tokenizer.borrow_mut().with_truncation(Some(params)) {
-            return Err(Error::new(
-                ruby.exception_arg_error(),
-                error_message.to_string(),
-            ));
-        }
+    pub fn token_to_id(&self, token: String) -> RbResult<Option<u32>> {
+        Ok(self.read_inner()?.token_to_id(&token))
+    }
+
+    pub fn id_to_token(&self, id: u32) -> RbResult<Option<String>> {
+        Ok(self.read_inner()?.id_to_token(id))
+    }
 
+    pub fn set_encode_special_tokens(&self, value: bool) -> RbResult<()> {
+        self.write_inner()?.set_encode_special_tokens(value);
         Ok(())
     }
 
-    pub fn no_truncation(&self) {
-        self.tokenizer
-            .borrow_mut()
-            .with_truncation(None)
-            .expect("Failed to set truncation to `None`! This should never happen");
+    pub fn get_encode_special_tokens(&self) -> RbResult<bool> {
+        Ok(self.read_inner()?.get_encode_special_tokens())
    }
 
-    pub fn
-        rb_self
-            .tokenizer
-            .borrow()
-            .get_truncation()
-            .map_or(Ok(None), |params| {
-                let ret_hash = ruby.hash_new();
+    pub fn add_tokens(&self, tokens: Vec<String>) -> RbResult<usize> {
+        let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
+        self.write_inner()?
+            .add_tokens(tokens)
+            .map_err(RbError::from)
+    }
 
-                ret_hash.aset("max_length", params.max_length)?;
-                ret_hash.aset("stride", params.stride)?;
-                ret_hash.aset("strategy", params.strategy.as_ref())?;
-                ret_hash.aset("direction", params.direction.as_ref())?;
+    pub fn add_special_tokens(&self, tokens: Vec<String>) -> RbResult<usize> {
+        let tokens: Vec<AddedToken> = tokens.iter().map(|t| AddedToken::from(t, true)).collect();
+        self.write_inner()?
+            .add_special_tokens(tokens)
+            .map_err(RbError::from)
+    }
 
-                Ok(Some(ret_hash))
-            })
+    pub fn train(&self, files: Vec<String>, trainer: Option<&RbTrainer>) -> RbResult<()> {
+        let mut trainer = match trainer {
+            Some(t) => t.clone(),
+            None => self.read_inner()?.get_model().get_trainer(),
+        };
+        self.write_inner()?
+            .train_from_files(&mut trainer, files)
+            .map(|_| {})
+            .map_err(RbError::from)
     }
 
-    pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
-        self.tokenizer
-            .borrow()
-            .get_post_processor()
-            .map_or(0, |p| p.added_tokens(is_pair))
+    pub fn get_model(&self) -> RbResult<RbModel> {
+        Ok(self.read_inner()?.get_model().clone())
     }
 
-    pub fn
-        self.
+    pub fn set_model(&self, model: &RbModel) -> RbResult<()> {
+        self.write_inner()?.with_model(model.clone());
+        Ok(())
     }
 
-    pub fn
-        self.
+    pub fn get_normalizer(&self) -> RbResult<Option<RbNormalizer>> {
+        Ok(self.read_inner()?.get_normalizer().cloned())
     }
 
-    pub fn
-
+    pub fn set_normalizer(&self, normalizer: Option<&RbNormalizer>) -> RbResult<()> {
+        self.write_inner()?
+            .with_normalizer(normalizer.cloned())
+            .map(|_| ())
+            .map_err(RbError::from)
+    }
 
-
-
-
+    pub fn get_pre_tokenizer(&self) -> RbResult<Option<RbPreTokenizer>> {
+        Ok(self.read_inner()?.get_pre_tokenizer().cloned())
+    }
 
-
+    pub fn set_pre_tokenizer(&self, pretok: Option<&RbPreTokenizer>) -> RbResult<()> {
+        self.write_inner()?.with_pre_tokenizer(pretok.cloned());
+        Ok(())
+    }
+
+    pub fn get_post_processor(&self) -> RbResult<Option<RbPostProcessor>> {
+        Ok(self.read_inner()?.get_post_processor().cloned())
     }
+
+    pub fn set_post_processor(&self, processor: Option<&RbPostProcessor>) -> RbResult<()> {
+        self.write_inner()?.with_post_processor(processor.cloned());
+        Ok(())
+    }
+
+    pub fn get_decoder(&self) -> RbResult<Option<RbDecoder>> {
+        Ok(self.read_inner()?.get_decoder().cloned())
+    }
+
+    pub fn set_decoder(&self, decoder: Option<&RbDecoder>) -> RbResult<()> {
+        self.write_inner()?.with_decoder(decoder.cloned());
+        Ok(())
+    }
+}
+
+pub fn init_tokenizer(ruby: &Ruby, module: &RModule) -> RbResult<()> {
+    let class = module.define_class("Tokenizer", ruby.class_object())?;
+    class.define_singleton_method("new", function!(RbTokenizer::from_model, 1))?;
+    class.define_singleton_method("from_str", function!(RbTokenizer::from_str, 1))?;
+    class.define_singleton_method("from_file", function!(RbTokenizer::from_file, 1))?;
+    class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;
+    class.define_method("_save", method!(RbTokenizer::save, 2))?;
+    class.define_method(
+        "num_special_tokens_to_add",
+        method!(RbTokenizer::num_special_tokens_to_add, 1),
+    )?;
+    class.define_method("_vocab", method!(RbTokenizer::get_vocab, 1))?;
+    class.define_method(
+        "added_tokens_decoder",
+        method!(RbTokenizer::get_added_tokens_decoder, 0),
+    )?;
+    class.define_method("_vocab_size", method!(RbTokenizer::get_vocab_size, 1))?;
+    class.define_method(
+        "_enable_truncation",
+        method!(RbTokenizer::enable_truncation, 2),
+    )?;
+    class.define_method("no_truncation", method!(RbTokenizer::no_truncation, 0))?;
+    class.define_method("truncation", method!(RbTokenizer::get_truncation, 0))?;
+    class.define_method("_enable_padding", method!(RbTokenizer::enable_padding, 1))?;
+    class.define_method("no_padding", method!(RbTokenizer::no_padding, 0))?;
+    class.define_method("padding", method!(RbTokenizer::get_padding, 0))?;
+    class.define_method("_encode", method!(RbTokenizer::encode, 4))?;
+    class.define_method("_encode_batch", method!(RbTokenizer::encode_batch, 3))?;
+    class.define_method(
+        "_encode_batch_fast",
+        method!(RbTokenizer::encode_batch_fast, 3),
+    )?;
+    class.define_method("_decode", method!(RbTokenizer::decode, 2))?;
+    class.define_method("_decode_batch", method!(RbTokenizer::decode_batch, 2))?;
+    class.define_method("token_to_id", method!(RbTokenizer::token_to_id, 1))?;
+    class.define_method("id_to_token", method!(RbTokenizer::id_to_token, 1))?;
+    class.define_method(
+        "encode_special_tokens=",
+        method!(RbTokenizer::set_encode_special_tokens, 1),
+    )?;
+    class.define_method(
+        "encode_special_tokens",
+        method!(RbTokenizer::get_encode_special_tokens, 0),
+    )?;
+    class.define_method("add_tokens", method!(RbTokenizer::add_tokens, 1))?;
+    class.define_method(
+        "add_special_tokens",
+        method!(RbTokenizer::add_special_tokens, 1),
+    )?;
+    class.define_method("train", method!(RbTokenizer::train, 2))?;
+    class.define_method("model", method!(RbTokenizer::get_model, 0))?;
+    class.define_method("model=", method!(RbTokenizer::set_model, 1))?;
+    class.define_method("normalizer", method!(RbTokenizer::get_normalizer, 0))?;
+    class.define_method("normalizer=", method!(RbTokenizer::set_normalizer, 1))?;
+    class.define_method("pre_tokenizer", method!(RbTokenizer::get_pre_tokenizer, 0))?;
+    class.define_method("pre_tokenizer=", method!(RbTokenizer::set_pre_tokenizer, 1))?;
+    class.define_method(
+        "post_processor",
+        method!(RbTokenizer::get_post_processor, 0),
+    )?;
+    class.define_method(
+        "post_processor=",
+        method!(RbTokenizer::set_post_processor, 1),
+    )?;
+    class.define_method("decoder", method!(RbTokenizer::get_decoder, 0))?;
+    class.define_method("decoder=", method!(RbTokenizer::set_decoder, 1))?;
+
+    let class = module.define_class("AddedToken", ruby.class_object())?;
+    class.define_singleton_method("_new", function!(RbAddedToken::new, 2))?;
+    class.define_method("content", method!(RbAddedToken::get_content, 0))?;
+    class.define_method("rstrip", method!(RbAddedToken::get_rstrip, 0))?;
+    class.define_method("lstrip", method!(RbAddedToken::get_lstrip, 0))?;
+    class.define_method("single_word", method!(RbAddedToken::get_single_word, 0))?;
+    class.define_method("normalized", method!(RbAddedToken::get_normalized, 0))?;
+    class.define_method("special", method!(RbAddedToken::get_special, 0))?;
+
+    Ok(())
 }