tokenizers 0.5.3 → 0.5.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Cargo.lock +154 -83
- data/ext/tokenizers/Cargo.toml +2 -2
- data/ext/tokenizers/src/decoders.rs +32 -14
- data/ext/tokenizers/src/error.rs +6 -1
- data/ext/tokenizers/src/lib.rs +37 -12
- data/ext/tokenizers/src/models.rs +75 -23
- data/ext/tokenizers/src/normalizers.rs +84 -24
- data/ext/tokenizers/src/pre_tokenizers.rs +121 -42
- data/ext/tokenizers/src/processors.rs +22 -10
- data/ext/tokenizers/src/tokenizer.rs +63 -34
- data/ext/tokenizers/src/trainers.rs +215 -56
- data/ext/tokenizers/src/utils/regex.rs +6 -4
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/version.rb +1 -1
- metadata +3 -7
| @@ -1,8 +1,8 @@ | |
| 1 1 | 
             
            use std::sync::{Arc, RwLock};
         | 
| 2 2 |  | 
| 3 3 | 
             
            use magnus::{
         | 
| 4 | 
            -
                data_type_builder, exception, function, method, value::Lazy, Class, DataType, | 
| 5 | 
            -
                RArray, RClass, RModule, Ruby, TryConvert, TypedData,
         | 
| 4 | 
            +
                data_type_builder, exception, function, method, value::Lazy, Class, DataType,
         | 
| 5 | 
            +
                DataTypeFunctions, Error, Module, Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
         | 
| 6 6 | 
             
            };
         | 
| 7 7 |  | 
| 8 8 | 
             
            use serde::ser::SerializeStruct;
         | 
| @@ -22,7 +22,7 @@ use tk::tokenizer::Offsets; | |
| 22 22 | 
             
            use tk::{PreTokenizedString, PreTokenizer};
         | 
| 23 23 |  | 
| 24 24 | 
             
            use super::utils::*;
         | 
| 25 | 
            -
            use super::{ | 
| 25 | 
            +
            use super::{RbError, RbResult, PRE_TOKENIZERS};
         | 
| 26 26 |  | 
| 27 27 | 
             
            #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
         | 
| 28 28 | 
             
            pub struct RbPreTokenizer {
         | 
| @@ -34,7 +34,9 @@ impl RbPreTokenizer { | |
| 34 34 | 
             
                fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
         | 
| 35 35 | 
             
                    let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
         | 
| 36 36 |  | 
| 37 | 
            -
                    self.pretok | 
| 37 | 
            +
                    self.pretok
         | 
| 38 | 
            +
                        .pre_tokenize(&mut pretokenized)
         | 
| 39 | 
            +
                        .map_err(RbError::from)?;
         | 
| 38 40 |  | 
| 39 41 | 
             
                    Ok(pretokenized
         | 
| 40 42 | 
             
                        .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
         | 
| @@ -195,11 +197,7 @@ impl RbDigits { | |
| 195 197 | 
             
            pub struct RbMetaspace {}
         | 
| 196 198 |  | 
| 197 199 | 
             
            impl RbMetaspace {
         | 
| 198 | 
            -
                fn new(
         | 
| 199 | 
            -
                    replacement: char,
         | 
| 200 | 
            -
                    prepend_scheme: String,
         | 
| 201 | 
            -
                    split: bool,
         | 
| 202 | 
            -
                ) -> RbResult<RbPreTokenizer> {
         | 
| 200 | 
            +
                fn new(replacement: char, prepend_scheme: String, split: bool) -> RbResult<RbPreTokenizer> {
         | 
| 203 201 | 
             
                    let prepend_scheme = from_string(prepend_scheme)?;
         | 
| 204 202 | 
             
                    Ok(Metaspace::new(replacement, prepend_scheme, split).into())
         | 
| 205 203 | 
             
                }
         | 
| @@ -216,8 +214,14 @@ impl RbPunctuation { | |
| 216 214 | 
             
            pub struct RbSplit {}
         | 
| 217 215 |  | 
| 218 216 | 
             
            impl RbSplit {
         | 
| 219 | 
            -
                pub fn new( | 
| 220 | 
            -
                     | 
| 217 | 
            +
                pub fn new(
         | 
| 218 | 
            +
                    pattern: RbPattern,
         | 
| 219 | 
            +
                    behavior: RbSplitDelimiterBehavior,
         | 
| 220 | 
            +
                    invert: bool,
         | 
| 221 | 
            +
                ) -> RbResult<RbPreTokenizer> {
         | 
| 222 | 
            +
                    Split::new(pattern, behavior.into(), invert)
         | 
| 223 | 
            +
                        .map(|v| v.into())
         | 
| 224 | 
            +
                        .map_err(RbError::from)
         | 
| 221 225 | 
             
                }
         | 
| 222 226 | 
             
            }
         | 
| 223 227 |  | 
| @@ -258,16 +262,18 @@ pub struct RbSequence {} | |
| 258 262 | 
             
            impl RbSequence {
         | 
| 259 263 | 
             
                fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
         | 
| 260 264 | 
             
                    let mut sequence = Vec::with_capacity(pre_tokenizers.len());
         | 
| 261 | 
            -
                    for n in pre_tokenizers | 
| 265 | 
            +
                    for n in pre_tokenizers {
         | 
| 262 266 | 
             
                        let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n)?;
         | 
| 263 267 | 
             
                        match &pretokenizer.pretok {
         | 
| 264 268 | 
             
                            RbPreTokenizerTypeWrapper::Sequence(inner) => {
         | 
| 265 | 
            -
                                sequence.extend(inner.iter().cloned())
         | 
| 269 | 
            +
                                sequence.extend(inner.iter().cloned());
         | 
| 266 270 | 
             
                            }
         | 
| 267 271 | 
             
                            RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
         | 
| 268 272 | 
             
                        }
         | 
| 269 273 | 
             
                    }
         | 
| 270 | 
            -
                    Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence( | 
| 274 | 
            +
                    Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(
         | 
| 275 | 
            +
                        sequence,
         | 
| 276 | 
            +
                    )))
         | 
| 271 277 | 
             
                }
         | 
| 272 278 | 
             
            }
         | 
| 273 279 |  | 
| @@ -277,10 +283,13 @@ pub(crate) fn from_string(string: String) -> RbResult<PrependScheme> { | |
| 277 283 | 
             
                    "never" => PrependScheme::Never,
         | 
| 278 284 | 
             
                    "always" => PrependScheme::Always,
         | 
| 279 285 | 
             
                    _ => {
         | 
| 280 | 
            -
                        return Err(Error::new( | 
| 281 | 
            -
                             | 
| 282 | 
            -
                             | 
| 283 | 
            -
             | 
| 286 | 
            +
                        return Err(Error::new(
         | 
| 287 | 
            +
                            exception::arg_error(),
         | 
| 288 | 
            +
                            format!(
         | 
| 289 | 
            +
                                "{} is an unknown variant, should be one of ['first', 'never', 'always']",
         | 
| 290 | 
            +
                                string
         | 
| 291 | 
            +
                            ),
         | 
| 292 | 
            +
                        ));
         | 
| 284 293 | 
             
                    }
         | 
| 285 294 | 
             
                };
         | 
| 286 295 | 
             
                Ok(scheme)
         | 
| @@ -381,7 +390,10 @@ impl PreTokenizer for RbPreTokenizerWrapper { | |
| 381 390 | 
             
            unsafe impl TypedData for RbPreTokenizer {
         | 
| 382 391 | 
             
                fn class(ruby: &Ruby) -> RClass {
         | 
| 383 392 | 
             
                    static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 384 | 
            -
                        let class: RClass = ruby | 
| 393 | 
            +
                        let class: RClass = ruby
         | 
| 394 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 395 | 
            +
                            .const_get("PreTokenizer")
         | 
| 396 | 
            +
                            .unwrap();
         | 
| 385 397 | 
             
                        class.undef_default_alloc_func();
         | 
| 386 398 | 
             
                        class
         | 
| 387 399 | 
             
                    });
         | 
| @@ -389,28 +401,41 @@ unsafe impl TypedData for RbPreTokenizer { | |
| 389 401 | 
             
                }
         | 
| 390 402 |  | 
| 391 403 | 
             
                fn data_type() -> &'static DataType {
         | 
| 392 | 
            -
                    static DATA_TYPE: DataType = | 
| 404 | 
            +
                    static DATA_TYPE: DataType =
         | 
| 405 | 
            +
                        data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
         | 
| 393 406 | 
             
                    &DATA_TYPE
         | 
| 394 407 | 
             
                }
         | 
| 395 408 |  | 
| 396 409 | 
             
                fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         | 
| 397 410 | 
             
                    static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 398 | 
            -
                        let class: RClass = ruby | 
| 411 | 
            +
                        let class: RClass = ruby
         | 
| 412 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 413 | 
            +
                            .const_get("Sequence")
         | 
| 414 | 
            +
                            .unwrap();
         | 
| 399 415 | 
             
                        class.undef_default_alloc_func();
         | 
| 400 416 | 
             
                        class
         | 
| 401 417 | 
             
                    });
         | 
| 402 418 | 
             
                    static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 403 | 
            -
                        let class: RClass = ruby | 
| 419 | 
            +
                        let class: RClass = ruby
         | 
| 420 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 421 | 
            +
                            .const_get("BertPreTokenizer")
         | 
| 422 | 
            +
                            .unwrap();
         | 
| 404 423 | 
             
                        class.undef_default_alloc_func();
         | 
| 405 424 | 
             
                        class
         | 
| 406 425 | 
             
                    });
         | 
| 407 426 | 
             
                    static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 408 | 
            -
                        let class: RClass = ruby | 
| 427 | 
            +
                        let class: RClass = ruby
         | 
| 428 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 429 | 
            +
                            .const_get("ByteLevel")
         | 
| 430 | 
            +
                            .unwrap();
         | 
| 409 431 | 
             
                        class.undef_default_alloc_func();
         | 
| 410 432 | 
             
                        class
         | 
| 411 433 | 
             
                    });
         | 
| 412 434 | 
             
                    static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 413 | 
            -
                        let class: RClass = ruby | 
| 435 | 
            +
                        let class: RClass = ruby
         | 
| 436 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 437 | 
            +
                            .const_get("CharDelimiterSplit")
         | 
| 438 | 
            +
                            .unwrap();
         | 
| 414 439 | 
             
                        class.undef_default_alloc_func();
         | 
| 415 440 | 
             
                        class
         | 
| 416 441 | 
             
                    });
         | 
| @@ -420,12 +445,18 @@ unsafe impl TypedData for RbPreTokenizer { | |
| 420 445 | 
             
                        class
         | 
| 421 446 | 
             
                    });
         | 
| 422 447 | 
             
                    static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 423 | 
            -
                        let class: RClass = ruby | 
| 448 | 
            +
                        let class: RClass = ruby
         | 
| 449 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 450 | 
            +
                            .const_get("Metaspace")
         | 
| 451 | 
            +
                            .unwrap();
         | 
| 424 452 | 
             
                        class.undef_default_alloc_func();
         | 
| 425 453 | 
             
                        class
         | 
| 426 454 | 
             
                    });
         | 
| 427 455 | 
             
                    static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 428 | 
            -
                        let class: RClass = ruby | 
| 456 | 
            +
                        let class: RClass = ruby
         | 
| 457 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 458 | 
            +
                            .const_get("Punctuation")
         | 
| 459 | 
            +
                            .unwrap();
         | 
| 429 460 | 
             
                        class.undef_default_alloc_func();
         | 
| 430 461 | 
             
                        class
         | 
| 431 462 | 
             
                    });
         | 
| @@ -435,17 +466,26 @@ unsafe impl TypedData for RbPreTokenizer { | |
| 435 466 | 
             
                        class
         | 
| 436 467 | 
             
                    });
         | 
| 437 468 | 
             
                    static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 438 | 
            -
                        let class: RClass = ruby | 
| 469 | 
            +
                        let class: RClass = ruby
         | 
| 470 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 471 | 
            +
                            .const_get("UnicodeScripts")
         | 
| 472 | 
            +
                            .unwrap();
         | 
| 439 473 | 
             
                        class.undef_default_alloc_func();
         | 
| 440 474 | 
             
                        class
         | 
| 441 475 | 
             
                    });
         | 
| 442 476 | 
             
                    static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 443 | 
            -
                        let class: RClass = ruby | 
| 477 | 
            +
                        let class: RClass = ruby
         | 
| 478 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 479 | 
            +
                            .const_get("Whitespace")
         | 
| 480 | 
            +
                            .unwrap();
         | 
| 444 481 | 
             
                        class.undef_default_alloc_func();
         | 
| 445 482 | 
             
                        class
         | 
| 446 483 | 
             
                    });
         | 
| 447 484 | 
             
                    static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 448 | 
            -
                        let class: RClass = ruby | 
| 485 | 
            +
                        let class: RClass = ruby
         | 
| 486 | 
            +
                            .get_inner(&PRE_TOKENIZERS)
         | 
| 487 | 
            +
                            .const_get("WhitespaceSplit")
         | 
| 488 | 
            +
                            .unwrap();
         | 
| 449 489 | 
             
                        class.undef_default_alloc_func();
         | 
| 450 490 | 
             
                        class
         | 
| 451 491 | 
             
                    });
         | 
| @@ -472,7 +512,10 @@ unsafe impl TypedData for RbPreTokenizer { | |
| 472 512 |  | 
| 473 513 | 
             
            pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
         | 
| 474 514 | 
             
                let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
         | 
| 475 | 
            -
                pre_tokenizer.define_method( | 
| 515 | 
            +
                pre_tokenizer.define_method(
         | 
| 516 | 
            +
                    "pre_tokenize_str",
         | 
| 517 | 
            +
                    method!(RbPreTokenizer::pre_tokenize_str, 1),
         | 
| 518 | 
            +
                )?;
         | 
| 476 519 |  | 
| 477 520 | 
             
                let class = module.define_class("Sequence", pre_tokenizer)?;
         | 
| 478 521 | 
             
                class.define_singleton_method("new", function!(RbSequence::new, 1))?;
         | 
| @@ -483,27 +526,63 @@ pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> { | |
| 483 526 | 
             
                let class = module.define_class("ByteLevel", pre_tokenizer)?;
         | 
| 484 527 | 
             
                class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
         | 
| 485 528 | 
             
                class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
         | 
| 486 | 
            -
                class.define_method( | 
| 487 | 
            -
             | 
| 488 | 
            -
             | 
| 489 | 
            -
                 | 
| 529 | 
            +
                class.define_method(
         | 
| 530 | 
            +
                    "add_prefix_space",
         | 
| 531 | 
            +
                    method!(RbPreTokenizer::byte_level_add_prefix_space, 0),
         | 
| 532 | 
            +
                )?;
         | 
| 533 | 
            +
                class.define_method(
         | 
| 534 | 
            +
                    "add_prefix_space=",
         | 
| 535 | 
            +
                    method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1),
         | 
| 536 | 
            +
                )?;
         | 
| 537 | 
            +
                class.define_method(
         | 
| 538 | 
            +
                    "use_regex",
         | 
| 539 | 
            +
                    method!(RbPreTokenizer::byte_level_use_regex, 0),
         | 
| 540 | 
            +
                )?;
         | 
| 541 | 
            +
                class.define_method(
         | 
| 542 | 
            +
                    "use_regex=",
         | 
| 543 | 
            +
                    method!(RbPreTokenizer::byte_level_set_use_regex, 1),
         | 
| 544 | 
            +
                )?;
         | 
| 490 545 |  | 
| 491 546 | 
             
                let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
         | 
| 492 547 | 
             
                class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
         | 
| 493 | 
            -
                class.define_method( | 
| 494 | 
            -
             | 
| 548 | 
            +
                class.define_method(
         | 
| 549 | 
            +
                    "delimiter",
         | 
| 550 | 
            +
                    method!(RbPreTokenizer::char_delimiter_split_delimiter, 0),
         | 
| 551 | 
            +
                )?;
         | 
| 552 | 
            +
                class.define_method(
         | 
| 553 | 
            +
                    "delimiter=",
         | 
| 554 | 
            +
                    method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1),
         | 
| 555 | 
            +
                )?;
         | 
| 495 556 |  | 
| 496 557 | 
             
                let class = module.define_class("Digits", pre_tokenizer)?;
         | 
| 497 558 | 
             
                class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
         | 
| 498 | 
            -
                class.define_method( | 
| 499 | 
            -
             | 
| 559 | 
            +
                class.define_method(
         | 
| 560 | 
            +
                    "individual_digits",
         | 
| 561 | 
            +
                    method!(RbPreTokenizer::digits_individual_digits, 0),
         | 
| 562 | 
            +
                )?;
         | 
| 563 | 
            +
                class.define_method(
         | 
| 564 | 
            +
                    "individual_digits=",
         | 
| 565 | 
            +
                    method!(RbPreTokenizer::digits_set_individual_digits, 1),
         | 
| 566 | 
            +
                )?;
         | 
| 500 567 |  | 
| 501 568 | 
             
                let class = module.define_class("Metaspace", pre_tokenizer)?;
         | 
| 502 569 | 
             
                class.define_singleton_method("_new", function!(RbMetaspace::new, 3))?;
         | 
| 503 | 
            -
                class.define_method( | 
| 504 | 
            -
             | 
| 505 | 
            -
             | 
| 506 | 
            -
                 | 
| 570 | 
            +
                class.define_method(
         | 
| 571 | 
            +
                    "prepend_scheme",
         | 
| 572 | 
            +
                    method!(RbPreTokenizer::metaspace_prepend_scheme, 0),
         | 
| 573 | 
            +
                )?;
         | 
| 574 | 
            +
                class.define_method(
         | 
| 575 | 
            +
                    "prepend_scheme=",
         | 
| 576 | 
            +
                    method!(RbPreTokenizer::metaspace_set_prepend_scheme, 1),
         | 
| 577 | 
            +
                )?;
         | 
| 578 | 
            +
                class.define_method(
         | 
| 579 | 
            +
                    "replacement",
         | 
| 580 | 
            +
                    method!(RbPreTokenizer::metaspace_replacement, 0),
         | 
| 581 | 
            +
                )?;
         | 
| 582 | 
            +
                class.define_method(
         | 
| 583 | 
            +
                    "replacement=",
         | 
| 584 | 
            +
                    method!(RbPreTokenizer::metaspace_set_replacement, 1),
         | 
| 585 | 
            +
                )?;
         | 
| 507 586 | 
             
                class.define_method("split", method!(RbPreTokenizer::metaspace_split, 0))?;
         | 
| 508 587 | 
             
                class.define_method("split=", method!(RbPreTokenizer::metaspace_set_split, 1))?;
         | 
| 509 588 |  | 
| @@ -1,8 +1,8 @@ | |
| 1 1 | 
             
            use std::sync::Arc;
         | 
| 2 2 |  | 
| 3 3 | 
             
            use magnus::{
         | 
| 4 | 
            -
                data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, | 
| 5 | 
            -
                Ruby, TryConvert, TypedData, Value,
         | 
| 4 | 
            +
                data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
         | 
| 5 | 
            +
                RClass, RModule, Ruby, TryConvert, TypedData, Value,
         | 
| 6 6 | 
             
            };
         | 
| 7 7 | 
             
            use serde::{Deserialize, Serialize};
         | 
| 8 8 | 
             
            use tk::processors::bert::BertProcessing;
         | 
| @@ -12,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template}; | |
| 12 12 | 
             
            use tk::processors::PostProcessorWrapper;
         | 
| 13 13 | 
             
            use tk::{Encoding, PostProcessor};
         | 
| 14 14 |  | 
| 15 | 
            -
            use super::{ | 
| 15 | 
            +
            use super::{RbResult, PROCESSORS};
         | 
| 16 16 |  | 
| 17 17 | 
             
            #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
         | 
| 18 18 | 
             
            pub struct RbPostProcessor {
         | 
| @@ -106,7 +106,6 @@ impl RbByteLevel { | |
| 106 106 | 
             
                    }
         | 
| 107 107 | 
             
                    RbPostProcessor::new(Arc::new(byte_level.into()))
         | 
| 108 108 | 
             
                }
         | 
| 109 | 
            -
             | 
| 110 109 | 
             
            }
         | 
| 111 110 |  | 
| 112 111 | 
             
            pub struct RbRobertaProcessing {}
         | 
| @@ -117,7 +116,7 @@ impl RbRobertaProcessing { | |
| 117 116 | 
             
                    cls: (String, u32),
         | 
| 118 117 | 
             
                    trim_offsets: bool,
         | 
| 119 118 | 
             
                    add_prefix_space: bool,
         | 
| 120 | 
            -
                ) -> | 
| 119 | 
            +
                ) -> RbPostProcessor {
         | 
| 121 120 | 
             
                    let proc = RobertaProcessing::new(sep, cls)
         | 
| 122 121 | 
             
                        .trim_offsets(trim_offsets)
         | 
| 123 122 | 
             
                        .add_prefix_space(add_prefix_space);
         | 
| @@ -153,7 +152,10 @@ impl RbTemplateProcessing { | |
| 153 152 | 
             
            unsafe impl TypedData for RbPostProcessor {
         | 
| 154 153 | 
             
                fn class(ruby: &Ruby) -> RClass {
         | 
| 155 154 | 
             
                    static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 156 | 
            -
                        let class: RClass = ruby | 
| 155 | 
            +
                        let class: RClass = ruby
         | 
| 156 | 
            +
                            .get_inner(&PROCESSORS)
         | 
| 157 | 
            +
                            .const_get("PostProcessor")
         | 
| 158 | 
            +
                            .unwrap();
         | 
| 157 159 | 
             
                        class.undef_default_alloc_func();
         | 
| 158 160 | 
             
                        class
         | 
| 159 161 | 
             
                    });
         | 
| @@ -161,13 +163,17 @@ unsafe impl TypedData for RbPostProcessor { | |
| 161 163 | 
             
                }
         | 
| 162 164 |  | 
| 163 165 | 
             
                fn data_type() -> &'static DataType {
         | 
| 164 | 
            -
                    static DATA_TYPE: DataType = | 
| 166 | 
            +
                    static DATA_TYPE: DataType =
         | 
| 167 | 
            +
                        data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
         | 
| 165 168 | 
             
                    &DATA_TYPE
         | 
| 166 169 | 
             
                }
         | 
| 167 170 |  | 
| 168 171 | 
             
                fn class_for(ruby: &Ruby, value: &Self) -> RClass {
         | 
| 169 172 | 
             
                    static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 170 | 
            -
                        let class: RClass = ruby | 
| 173 | 
            +
                        let class: RClass = ruby
         | 
| 174 | 
            +
                            .get_inner(&PROCESSORS)
         | 
| 175 | 
            +
                            .const_get("BertProcessing")
         | 
| 176 | 
            +
                            .unwrap();
         | 
| 171 177 | 
             
                        class.undef_default_alloc_func();
         | 
| 172 178 | 
             
                        class
         | 
| 173 179 | 
             
                    });
         | 
| @@ -177,12 +183,18 @@ unsafe impl TypedData for RbPostProcessor { | |
| 177 183 | 
             
                        class
         | 
| 178 184 | 
             
                    });
         | 
| 179 185 | 
             
                    static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 180 | 
            -
                        let class: RClass = ruby | 
| 186 | 
            +
                        let class: RClass = ruby
         | 
| 187 | 
            +
                            .get_inner(&PROCESSORS)
         | 
| 188 | 
            +
                            .const_get("RobertaProcessing")
         | 
| 189 | 
            +
                            .unwrap();
         | 
| 181 190 | 
             
                        class.undef_default_alloc_func();
         | 
| 182 191 | 
             
                        class
         | 
| 183 192 | 
             
                    });
         | 
| 184 193 | 
             
                    static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 185 | 
            -
                        let class: RClass = ruby | 
| 194 | 
            +
                        let class: RClass = ruby
         | 
| 195 | 
            +
                            .get_inner(&PROCESSORS)
         | 
| 196 | 
            +
                            .const_get("TemplateProcessing")
         | 
| 197 | 
            +
                            .unwrap();
         | 
| 186 198 | 
             
                        class.undef_default_alloc_func();
         | 
| 187 199 | 
             
                        class
         | 
| 188 200 | 
             
                    });
         | 
| @@ -6,8 +6,8 @@ use std::str::FromStr; | |
| 6 6 | 
             
            use magnus::prelude::*;
         | 
| 7 7 | 
             
            use magnus::{exception, Error, RArray, RHash, RString, Symbol, TryConvert, Value};
         | 
| 8 8 | 
             
            use tk::tokenizer::{
         | 
| 9 | 
            -
                Model, PaddingDirection, PaddingParams, PaddingStrategy,
         | 
| 10 | 
            -
                 | 
| 9 | 
            +
                Model, PaddingDirection, PaddingParams, PaddingStrategy, TokenizerImpl, TruncationDirection,
         | 
| 10 | 
            +
                TruncationParams, TruncationStrategy,
         | 
| 11 11 | 
             
            };
         | 
| 12 12 | 
             
            use tk::AddedToken;
         | 
| 13 13 |  | 
| @@ -284,7 +284,10 @@ impl RbTokenizer { | |
| 284 284 | 
             
                }
         | 
| 285 285 |  | 
| 286 286 | 
             
                pub fn to_str(&self, pretty: bool) -> RbResult<String> {
         | 
| 287 | 
            -
                    self.tokenizer | 
| 287 | 
            +
                    self.tokenizer
         | 
| 288 | 
            +
                        .borrow()
         | 
| 289 | 
            +
                        .to_string(pretty)
         | 
| 290 | 
            +
                        .map_err(RbError::from)
         | 
| 288 291 | 
             
                }
         | 
| 289 292 |  | 
| 290 293 | 
             
                pub fn add_special_tokens(&self, tokens: Vec<String>) -> usize {
         | 
| @@ -383,7 +386,11 @@ impl RbTokenizer { | |
| 383 386 | 
             
                        .map_err(RbError::from)
         | 
| 384 387 | 
             
                }
         | 
| 385 388 |  | 
| 386 | 
            -
                pub fn decode_batch( | 
| 389 | 
            +
                pub fn decode_batch(
         | 
| 390 | 
            +
                    &self,
         | 
| 391 | 
            +
                    sequences: Vec<Vec<u32>>,
         | 
| 392 | 
            +
                    skip_special_tokens: bool,
         | 
| 393 | 
            +
                ) -> RbResult<Vec<String>> {
         | 
| 387 394 | 
             
                    let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         | 
| 388 395 | 
             
                    self.tokenizer
         | 
| 389 396 | 
             
                        .borrow()
         | 
| @@ -455,7 +462,12 @@ impl RbTokenizer { | |
| 455 462 | 
             
                        params.direction = match dir_str.as_str() {
         | 
| 456 463 | 
             
                            "left" => PaddingDirection::Left,
         | 
| 457 464 | 
             
                            "right" => PaddingDirection::Right,
         | 
| 458 | 
            -
                            _ =>  | 
| 465 | 
            +
                            _ => {
         | 
| 466 | 
            +
                                return Err(Error::new(
         | 
| 467 | 
            +
                                    exception::arg_error(),
         | 
| 468 | 
            +
                                    "The direction value must be 'left' or 'right'",
         | 
| 469 | 
            +
                                ))
         | 
| 470 | 
            +
                            }
         | 
| 459 471 | 
             
                        }
         | 
| 460 472 | 
             
                    }
         | 
| 461 473 |  | 
| @@ -501,24 +513,27 @@ impl RbTokenizer { | |
| 501 513 | 
             
                }
         | 
| 502 514 |  | 
| 503 515 | 
             
                pub fn padding(&self) -> RbResult<Option<RHash>> {
         | 
| 504 | 
            -
                    self.tokenizer | 
| 505 | 
            -
                         | 
| 506 | 
            -
             | 
| 507 | 
            -
                         | 
| 508 | 
            -
                             | 
| 509 | 
            -
             | 
| 510 | 
            -
             | 
| 511 | 
            -
                                 | 
| 512 | 
            -
             | 
| 513 | 
            -
             | 
| 514 | 
            -
             | 
| 515 | 
            -
             | 
| 516 | 
            -
             | 
| 517 | 
            -
             | 
| 518 | 
            -
             | 
| 519 | 
            -
             | 
| 520 | 
            -
             | 
| 521 | 
            -
             | 
| 516 | 
            +
                    self.tokenizer
         | 
| 517 | 
            +
                        .borrow()
         | 
| 518 | 
            +
                        .get_padding()
         | 
| 519 | 
            +
                        .map_or(Ok(None), |params| {
         | 
| 520 | 
            +
                            let ret_hash = RHash::new();
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                            ret_hash.aset(
         | 
| 523 | 
            +
                                "length",
         | 
| 524 | 
            +
                                match params.strategy {
         | 
| 525 | 
            +
                                    tk::PaddingStrategy::BatchLongest => None,
         | 
| 526 | 
            +
                                    tk::PaddingStrategy::Fixed(size) => Some(size),
         | 
| 527 | 
            +
                                },
         | 
| 528 | 
            +
                            )?;
         | 
| 529 | 
            +
                            ret_hash.aset("pad_to_multiple_of", params.pad_to_multiple_of)?;
         | 
| 530 | 
            +
                            ret_hash.aset("pad_id", params.pad_id)?;
         | 
| 531 | 
            +
                            ret_hash.aset("pad_token", &*params.pad_token)?;
         | 
| 532 | 
            +
                            ret_hash.aset("pad_type_id", params.pad_type_id)?;
         | 
| 533 | 
            +
                            ret_hash.aset("direction", params.direction.as_ref())?;
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                            Ok(Some(ret_hash))
         | 
| 536 | 
            +
                        })
         | 
| 522 537 | 
             
                }
         | 
| 523 538 |  | 
| 524 539 | 
             
                pub fn enable_truncation(&self, max_length: usize, kwargs: RHash) -> RbResult<()> {
         | 
| @@ -539,7 +554,10 @@ impl RbTokenizer { | |
| 539 554 | 
             
                            "longest_first" => TruncationStrategy::LongestFirst,
         | 
| 540 555 | 
             
                            "only_first" => TruncationStrategy::OnlyFirst,
         | 
| 541 556 | 
             
                            "only_second" => TruncationStrategy::OnlySecond,
         | 
| 542 | 
            -
                            _ => return Err(Error::new( | 
| 557 | 
            +
                            _ => return Err(Error::new(
         | 
| 558 | 
            +
                                exception::arg_error(),
         | 
| 559 | 
            +
                                "The strategy value must be 'longest_first', 'only_first', or 'only_second'",
         | 
| 560 | 
            +
                            )),
         | 
| 543 561 | 
             
                        }
         | 
| 544 562 | 
             
                    }
         | 
| 545 563 |  | 
| @@ -549,7 +567,12 @@ impl RbTokenizer { | |
| 549 567 | 
             
                        params.direction = match dir_str.as_str() {
         | 
| 550 568 | 
             
                            "left" => TruncationDirection::Left,
         | 
| 551 569 | 
             
                            "right" => TruncationDirection::Right,
         | 
| 552 | 
            -
                            _ =>  | 
| 570 | 
            +
                            _ => {
         | 
| 571 | 
            +
                                return Err(Error::new(
         | 
| 572 | 
            +
                                    exception::arg_error(),
         | 
| 573 | 
            +
                                    "The direction value must be 'left' or 'right'",
         | 
| 574 | 
            +
                                ))
         | 
| 575 | 
            +
                            }
         | 
| 553 576 | 
             
                        }
         | 
| 554 577 | 
             
                    }
         | 
| 555 578 |  | 
| @@ -559,7 +582,10 @@ impl RbTokenizer { | |
| 559 582 | 
             
                    }
         | 
| 560 583 |  | 
| 561 584 | 
             
                    if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
         | 
| 562 | 
            -
                        return Err(Error::new( | 
| 585 | 
            +
                        return Err(Error::new(
         | 
| 586 | 
            +
                            exception::arg_error(),
         | 
| 587 | 
            +
                            error_message.to_string(),
         | 
| 588 | 
            +
                        ));
         | 
| 563 589 | 
             
                    }
         | 
| 564 590 |  | 
| 565 591 | 
             
                    Ok(())
         | 
| @@ -573,16 +599,19 @@ impl RbTokenizer { | |
| 573 599 | 
             
                }
         | 
| 574 600 |  | 
| 575 601 | 
             
                pub fn truncation(&self) -> RbResult<Option<RHash>> {
         | 
| 576 | 
            -
                    self.tokenizer | 
| 577 | 
            -
                         | 
| 602 | 
            +
                    self.tokenizer
         | 
| 603 | 
            +
                        .borrow()
         | 
| 604 | 
            +
                        .get_truncation()
         | 
| 605 | 
            +
                        .map_or(Ok(None), |params| {
         | 
| 606 | 
            +
                            let ret_hash = RHash::new();
         | 
| 578 607 |  | 
| 579 | 
            -
             | 
| 580 | 
            -
             | 
| 581 | 
            -
             | 
| 582 | 
            -
             | 
| 608 | 
            +
                            ret_hash.aset("max_length", params.max_length)?;
         | 
| 609 | 
            +
                            ret_hash.aset("stride", params.stride)?;
         | 
| 610 | 
            +
                            ret_hash.aset("strategy", params.strategy.as_ref())?;
         | 
| 611 | 
            +
                            ret_hash.aset("direction", params.direction.as_ref())?;
         | 
| 583 612 |  | 
| 584 | 
            -
             | 
| 585 | 
            -
             | 
| 613 | 
            +
                            Ok(Some(ret_hash))
         | 
| 614 | 
            +
                        })
         | 
| 586 615 | 
             
                }
         | 
| 587 616 |  | 
| 588 617 | 
             
                pub fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
         |