tokenizers 0.5.3 → 0.5.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Cargo.lock +154 -83
- data/ext/tokenizers/Cargo.toml +2 -2
- data/ext/tokenizers/src/decoders.rs +32 -14
- data/ext/tokenizers/src/error.rs +6 -1
- data/ext/tokenizers/src/lib.rs +37 -12
- data/ext/tokenizers/src/models.rs +75 -23
- data/ext/tokenizers/src/normalizers.rs +84 -24
- data/ext/tokenizers/src/pre_tokenizers.rs +121 -42
- data/ext/tokenizers/src/processors.rs +22 -10
- data/ext/tokenizers/src/tokenizer.rs +63 -34
- data/ext/tokenizers/src/trainers.rs +215 -56
- data/ext/tokenizers/src/utils/regex.rs +6 -4
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/version.rb +1 -1
- metadata +3 -7
    
        data/ext/tokenizers/src/lib.rs
    CHANGED
    
    | @@ -22,19 +22,32 @@ use magnus::{function, method, prelude::*, value::Lazy, Error, RModule, Ruby}; | |
| 22 22 |  | 
| 23 23 | 
             
            type RbResult<T> = Result<T, Error>;
         | 
| 24 24 |  | 
| 25 | 
            -
            static TOKENIZERS: Lazy<RModule> = | 
| 25 | 
            +
            static TOKENIZERS: Lazy<RModule> =
         | 
| 26 | 
            +
                Lazy::new(|ruby| ruby.class_object().const_get("Tokenizers").unwrap());
         | 
| 26 27 |  | 
| 27 | 
            -
            static DECODERS: Lazy<RModule> = | 
| 28 | 
            +
            static DECODERS: Lazy<RModule> =
         | 
| 29 | 
            +
                Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Decoders").unwrap());
         | 
| 28 30 |  | 
| 29 | 
            -
            static MODELS: Lazy<RModule> = | 
| 31 | 
            +
            static MODELS: Lazy<RModule> =
         | 
| 32 | 
            +
                Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Models").unwrap());
         | 
| 30 33 |  | 
| 31 | 
            -
            static NORMALIZERS: Lazy<RModule> = Lazy::new(|ruby|  | 
| 34 | 
            +
            static NORMALIZERS: Lazy<RModule> = Lazy::new(|ruby| {
         | 
| 35 | 
            +
                ruby.get_inner(&TOKENIZERS)
         | 
| 36 | 
            +
                    .const_get("Normalizers")
         | 
| 37 | 
            +
                    .unwrap()
         | 
| 38 | 
            +
            });
         | 
| 32 39 |  | 
| 33 | 
            -
            static PRE_TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby|  | 
| 40 | 
            +
            static PRE_TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| {
         | 
| 41 | 
            +
                ruby.get_inner(&TOKENIZERS)
         | 
| 42 | 
            +
                    .const_get("PreTokenizers")
         | 
| 43 | 
            +
                    .unwrap()
         | 
| 44 | 
            +
            });
         | 
| 34 45 |  | 
| 35 | 
            -
            static PROCESSORS: Lazy<RModule> = | 
| 46 | 
            +
            static PROCESSORS: Lazy<RModule> =
         | 
| 47 | 
            +
                Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Processors").unwrap());
         | 
| 36 48 |  | 
| 37 | 
            -
            static TRAINERS: Lazy<RModule> = | 
| 49 | 
            +
            static TRAINERS: Lazy<RModule> =
         | 
| 50 | 
            +
                Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Trainers").unwrap());
         | 
| 38 51 |  | 
| 39 52 | 
             
            #[magnus::init]
         | 
| 40 53 | 
             
            fn init(ruby: &Ruby) -> RbResult<()> {
         | 
| @@ -56,12 +69,15 @@ fn init(ruby: &Ruby) -> RbResult<()> { | |
| 56 69 | 
             
                class.define_method("_decode", method!(RbTokenizer::decode, 2))?;
         | 
| 57 70 | 
             
                class.define_method("_decode_batch", method!(RbTokenizer::decode_batch, 2))?;
         | 
| 58 71 | 
             
                class.define_method("model", method!(RbTokenizer::get_model, 0))?;
         | 
| 59 | 
            -
                class.define_method("model=", method!(RbTokenizer::set_model,1))?;
         | 
| 72 | 
            +
                class.define_method("model=", method!(RbTokenizer::set_model, 1))?;
         | 
| 60 73 | 
             
                class.define_method("decoder", method!(RbTokenizer::get_decoder, 0))?;
         | 
| 61 74 | 
             
                class.define_method("decoder=", method!(RbTokenizer::set_decoder, 1))?;
         | 
| 62 75 | 
             
                class.define_method("pre_tokenizer", method!(RbTokenizer::get_pre_tokenizer, 0))?;
         | 
| 63 76 | 
             
                class.define_method("pre_tokenizer=", method!(RbTokenizer::set_pre_tokenizer, 1))?;
         | 
| 64 | 
            -
                class.define_method( | 
| 77 | 
            +
                class.define_method(
         | 
| 78 | 
            +
                    "post_processor",
         | 
| 79 | 
            +
                    method!(RbTokenizer::get_post_processor, 0),
         | 
| 80 | 
            +
                )?;
         | 
| 65 81 | 
             
                class.define_method(
         | 
| 66 82 | 
             
                    "post_processor=",
         | 
| 67 83 | 
             
                    method!(RbTokenizer::set_post_processor, 1),
         | 
| @@ -73,13 +89,22 @@ fn init(ruby: &Ruby) -> RbResult<()> { | |
| 73 89 | 
             
                class.define_method("_enable_padding", method!(RbTokenizer::enable_padding, 1))?;
         | 
| 74 90 | 
             
                class.define_method("padding", method!(RbTokenizer::padding, 0))?;
         | 
| 75 91 | 
             
                class.define_method("no_padding", method!(RbTokenizer::no_padding, 0))?;
         | 
| 76 | 
            -
                class.define_method( | 
| 92 | 
            +
                class.define_method(
         | 
| 93 | 
            +
                    "_enable_truncation",
         | 
| 94 | 
            +
                    method!(RbTokenizer::enable_truncation, 2),
         | 
| 95 | 
            +
                )?;
         | 
| 77 96 | 
             
                class.define_method("truncation", method!(RbTokenizer::truncation, 0))?;
         | 
| 78 97 | 
             
                class.define_method("no_truncation", method!(RbTokenizer::no_truncation, 0))?;
         | 
| 79 | 
            -
                class.define_method( | 
| 98 | 
            +
                class.define_method(
         | 
| 99 | 
            +
                    "num_special_tokens_to_add",
         | 
| 100 | 
            +
                    method!(RbTokenizer::num_special_tokens_to_add, 1),
         | 
| 101 | 
            +
                )?;
         | 
| 80 102 | 
             
                class.define_method("_vocab", method!(RbTokenizer::vocab, 1))?;
         | 
| 81 103 | 
             
                class.define_method("_vocab_size", method!(RbTokenizer::vocab_size, 1))?;
         | 
| 82 | 
            -
                class.define_method( | 
| 104 | 
            +
                class.define_method(
         | 
| 105 | 
            +
                    "added_tokens_decoder",
         | 
| 106 | 
            +
                    method!(RbTokenizer::get_added_tokens_decoder, 0),
         | 
| 107 | 
            +
                )?;
         | 
| 83 108 | 
             
                class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;
         | 
| 84 109 |  | 
| 85 110 | 
             
                let class = module.define_class("Encoding", ruby.class_object())?;
         | 
        data/ext/tokenizers/src/models.rs
    CHANGED
    
| @@ -5,18 +5,19 @@ use std::sync::{Arc, RwLock}; | |
| 5 5 | 
             
            use crate::trainers::RbTrainer;
         | 
| 6 6 | 
             
            use magnus::prelude::*;
         | 
| 7 7 | 
             
            use magnus::{
         | 
| 8 | 
            -
                data_type_builder, exception, function, method, value::Lazy, Class, DataType, | 
| 9 | 
            -
                RClass, RHash, RModule, Ruby, Symbol, TryConvert, | 
| 8 | 
            +
                data_type_builder, exception, function, method, value::Lazy, Class, DataType,
         | 
| 9 | 
            +
                DataTypeFunctions, Error, Module, Object, RClass, RHash, RModule, Ruby, Symbol, TryConvert,
         | 
| 10 | 
            +
                TypedData, Value,
         | 
| 10 11 | 
             
            };
         | 
| 11 12 | 
             
            use serde::{Deserialize, Serialize};
         | 
| 12 13 | 
             
            use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
         | 
| 13 | 
            -
            use tk::models::ModelWrapper;
         | 
| 14 14 | 
             
            use tk::models::unigram::Unigram;
         | 
| 15 15 | 
             
            use tk::models::wordlevel::WordLevel;
         | 
| 16 16 | 
             
            use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
         | 
| 17 | 
            +
            use tk::models::ModelWrapper;
         | 
| 17 18 | 
             
            use tk::{Model, Token};
         | 
| 18 19 |  | 
| 19 | 
            -
            use super::{ | 
| 20 | 
            +
            use super::{RbError, RbResult, MODELS};
         | 
| 20 21 |  | 
| 21 22 | 
             
            #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
         | 
| 22 23 | 
             
            pub struct RbModel {
         | 
| @@ -187,7 +188,12 @@ impl RbModel { | |
| 187 188 | 
             
                }
         | 
| 188 189 |  | 
| 189 190 | 
             
                pub fn bpe_set_continuing_subword_prefix(&self, continuing_subword_prefix: Option<String>) {
         | 
| 190 | 
            -
                    setter!( | 
| 191 | 
            +
                    setter!(
         | 
| 192 | 
            +
                        self,
         | 
| 193 | 
            +
                        BPE,
         | 
| 194 | 
            +
                        continuing_subword_prefix,
         | 
| 195 | 
            +
                        continuing_subword_prefix
         | 
| 196 | 
            +
                    );
         | 
| 191 197 | 
             
                }
         | 
| 192 198 |  | 
| 193 199 | 
             
                pub fn bpe_end_of_word_suffix(&self) -> Option<String> {
         | 
| @@ -219,7 +225,12 @@ impl RbModel { | |
| 219 225 | 
             
                }
         | 
| 220 226 |  | 
| 221 227 | 
             
                pub fn word_piece_set_continuing_subword_prefix(&self, continuing_subword_prefix: String) {
         | 
| 222 | 
            -
                    setter!( | 
| 228 | 
            +
                    setter!(
         | 
| 229 | 
            +
                        self,
         | 
| 230 | 
            +
                        WordPiece,
         | 
| 231 | 
            +
                        continuing_subword_prefix,
         | 
| 232 | 
            +
                        continuing_subword_prefix
         | 
| 233 | 
            +
                    );
         | 
| 223 234 | 
             
                }
         | 
| 224 235 |  | 
| 225 236 | 
             
                pub fn word_piece_max_input_chars_per_word(&self) -> usize {
         | 
| @@ -227,21 +238,34 @@ impl RbModel { | |
| 227 238 | 
             
                }
         | 
| 228 239 |  | 
| 229 240 | 
             
                pub fn word_piece_set_max_input_chars_per_word(&self, max_input_chars_per_word: usize) {
         | 
| 230 | 
            -
                    setter!( | 
| 241 | 
            +
                    setter!(
         | 
| 242 | 
            +
                        self,
         | 
| 243 | 
            +
                        WordPiece,
         | 
| 244 | 
            +
                        max_input_chars_per_word,
         | 
| 245 | 
            +
                        max_input_chars_per_word
         | 
| 246 | 
            +
                    );
         | 
| 231 247 | 
             
                }
         | 
| 232 248 | 
             
            }
         | 
| 233 249 |  | 
| 234 250 | 
             
            pub struct RbUnigram {}
         | 
| 235 251 |  | 
| 236 252 | 
             
            impl RbUnigram {
         | 
| 237 | 
            -
                fn new( | 
| 253 | 
            +
                fn new(
         | 
| 254 | 
            +
                    vocab: Option<Vec<(String, f64)>>,
         | 
| 255 | 
            +
                    unk_id: Option<usize>,
         | 
| 256 | 
            +
                    byte_fallback: Option<bool>,
         | 
| 257 | 
            +
                ) -> RbResult<RbModel> {
         | 
| 238 258 | 
             
                    match (vocab, unk_id, byte_fallback) {
         | 
| 239 259 | 
             
                        (Some(vocab), unk_id, byte_fallback) => {
         | 
| 240 | 
            -
                            let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)) | 
| 260 | 
            +
                            let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false))
         | 
| 261 | 
            +
                                .map_err(RbError::from)?;
         | 
| 241 262 | 
             
                            Ok(model.into())
         | 
| 242 263 | 
             
                        }
         | 
| 243 264 | 
             
                        (None, None, _) => Ok(Unigram::default().into()),
         | 
| 244 | 
            -
                        _ => Err(Error::new( | 
| 265 | 
            +
                        _ => Err(Error::new(
         | 
| 266 | 
            +
                            exception::arg_error(),
         | 
| 267 | 
            +
                            "`vocab` and `unk_id` must be both specified",
         | 
| 268 | 
            +
                        )),
         | 
| 245 269 | 
             
                    }
         | 
| 246 270 | 
             
                }
         | 
| 247 271 | 
             
            }
         | 
| @@ -249,7 +273,10 @@ impl RbUnigram { | |
| 249 273 | 
             
            pub struct RbWordLevel {}
         | 
| 250 274 |  | 
| 251 275 | 
             
            impl RbWordLevel {
         | 
| 252 | 
            -
                pub fn new( | 
| 276 | 
            +
                pub fn new(
         | 
| 277 | 
            +
                    vocab: Option<HashMap<String, u32>>,
         | 
| 278 | 
            +
                    unk_token: Option<String>,
         | 
| 279 | 
            +
                ) -> RbResult<RbModel> {
         | 
| 253 280 | 
             
                    let mut builder = WordLevel::builder();
         | 
| 254 281 | 
             
                    if let Some(vocab) = vocab {
         | 
| 255 282 | 
             
                        builder = builder.vocab(vocab);
         | 
| @@ -316,15 +343,16 @@ impl RbWordPiece { | |
| 316 343 | 
             
            unsafe impl TypedData for RbModel {
         | 
| 317 344 | 
             
                fn class(ruby: &Ruby) -> RClass {
         | 
| 318 345 | 
             
                    static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 319 | 
            -
             | 
| 320 | 
            -
             | 
| 321 | 
            -
             | 
| 346 | 
            +
                        let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
         | 
| 347 | 
            +
                        class.undef_default_alloc_func();
         | 
| 348 | 
            +
                        class
         | 
| 322 349 | 
             
                    });
         | 
| 323 350 | 
             
                    ruby.get_inner(&CLASS)
         | 
| 324 351 | 
             
                }
         | 
| 325 352 |  | 
| 326 353 | 
             
                fn data_type() -> &'static DataType {
         | 
| 327 | 
            -
                    static DATA_TYPE: DataType = | 
| 354 | 
            +
                    static DATA_TYPE: DataType =
         | 
| 355 | 
            +
                        data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
         | 
| 328 356 | 
             
                    &DATA_TYPE
         | 
| 329 357 | 
             
                }
         | 
| 330 358 |  | 
| @@ -368,10 +396,22 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> { | |
| 368 396 | 
             
                class.define_method("dropout=", method!(RbModel::bpe_set_dropout, 1))?;
         | 
| 369 397 | 
             
                class.define_method("unk_token", method!(RbModel::bpe_unk_token, 0))?;
         | 
| 370 398 | 
             
                class.define_method("unk_token=", method!(RbModel::bpe_set_unk_token, 1))?;
         | 
| 371 | 
            -
                class.define_method( | 
| 372 | 
            -
             | 
| 373 | 
            -
             | 
| 374 | 
            -
                 | 
| 399 | 
            +
                class.define_method(
         | 
| 400 | 
            +
                    "continuing_subword_prefix",
         | 
| 401 | 
            +
                    method!(RbModel::bpe_continuing_subword_prefix, 0),
         | 
| 402 | 
            +
                )?;
         | 
| 403 | 
            +
                class.define_method(
         | 
| 404 | 
            +
                    "continuing_subword_prefix=",
         | 
| 405 | 
            +
                    method!(RbModel::bpe_set_continuing_subword_prefix, 1),
         | 
| 406 | 
            +
                )?;
         | 
| 407 | 
            +
                class.define_method(
         | 
| 408 | 
            +
                    "end_of_word_suffix",
         | 
| 409 | 
            +
                    method!(RbModel::bpe_end_of_word_suffix, 0),
         | 
| 410 | 
            +
                )?;
         | 
| 411 | 
            +
                class.define_method(
         | 
| 412 | 
            +
                    "end_of_word_suffix=",
         | 
| 413 | 
            +
                    method!(RbModel::bpe_set_end_of_word_suffix, 1),
         | 
| 414 | 
            +
                )?;
         | 
| 375 415 | 
             
                class.define_method("fuse_unk", method!(RbModel::bpe_fuse_unk, 0))?;
         | 
| 376 416 | 
             
                class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
         | 
| 377 417 | 
             
                class.define_method("byte_fallback", method!(RbModel::bpe_byte_fallback, 0))?;
         | 
| @@ -392,10 +432,22 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> { | |
| 392 432 | 
             
                class.define_singleton_method("_from_file", function!(RbWordPiece::from_file, 2))?;
         | 
| 393 433 | 
             
                class.define_method("unk_token", method!(RbModel::word_piece_unk_token, 0))?;
         | 
| 394 434 | 
             
                class.define_method("unk_token=", method!(RbModel::word_piece_set_unk_token, 1))?;
         | 
| 395 | 
            -
                class.define_method( | 
| 396 | 
            -
             | 
| 397 | 
            -
             | 
| 398 | 
            -
                 | 
| 435 | 
            +
                class.define_method(
         | 
| 436 | 
            +
                    "continuing_subword_prefix",
         | 
| 437 | 
            +
                    method!(RbModel::word_piece_continuing_subword_prefix, 0),
         | 
| 438 | 
            +
                )?;
         | 
| 439 | 
            +
                class.define_method(
         | 
| 440 | 
            +
                    "continuing_subword_prefix=",
         | 
| 441 | 
            +
                    method!(RbModel::word_piece_set_continuing_subword_prefix, 1),
         | 
| 442 | 
            +
                )?;
         | 
| 443 | 
            +
                class.define_method(
         | 
| 444 | 
            +
                    "max_input_chars_per_word",
         | 
| 445 | 
            +
                    method!(RbModel::word_piece_max_input_chars_per_word, 0),
         | 
| 446 | 
            +
                )?;
         | 
| 447 | 
            +
                class.define_method(
         | 
| 448 | 
            +
                    "max_input_chars_per_word=",
         | 
| 449 | 
            +
                    method!(RbModel::word_piece_set_max_input_chars_per_word, 1),
         | 
| 450 | 
            +
                )?;
         | 
| 399 451 |  | 
| 400 452 | 
             
                Ok(())
         | 
| 401 453 | 
             
            }
         | 
        data/ext/tokenizers/src/normalizers.rs
    CHANGED
    
| @@ -1,19 +1,19 @@ | |
| 1 1 | 
             
            use std::sync::{Arc, RwLock};
         | 
| 2 2 |  | 
| 3 3 | 
             
            use magnus::{
         | 
| 4 | 
            -
                data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, | 
| 5 | 
            -
                Ruby, TryConvert, TypedData,
         | 
| 4 | 
            +
                data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module,
         | 
| 5 | 
            +
                Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
         | 
| 6 6 | 
             
            };
         | 
| 7 7 | 
             
            use serde::ser::SerializeStruct;
         | 
| 8 8 | 
             
            use serde::{Deserialize, Serialize, Serializer};
         | 
| 9 9 | 
             
            use tk::normalizers::{
         | 
| 10 | 
            -
                BertNormalizer, Lowercase, Nmt, NormalizerWrapper,  | 
| 11 | 
            -
                NFC, NFD, NFKC, NFKD,
         | 
| 10 | 
            +
                BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace, Strip,
         | 
| 11 | 
            +
                StripAccents, NFC, NFD, NFKC, NFKD,
         | 
| 12 12 | 
             
            };
         | 
| 13 13 | 
             
            use tk::{NormalizedString, Normalizer};
         | 
| 14 14 |  | 
| 15 15 | 
             
            use super::utils::*;
         | 
| 16 | 
            -
            use super::{ | 
| 16 | 
            +
            use super::{RbError, RbResult, NORMALIZERS};
         | 
| 17 17 |  | 
| 18 18 | 
             
            #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
         | 
| 19 19 | 
             
            pub struct RbNormalizer {
         | 
| @@ -28,7 +28,9 @@ impl RbNormalizer { | |
| 28 28 |  | 
| 29 29 | 
             
                pub fn normalize_str(&self, sequence: String) -> RbResult<String> {
         | 
| 30 30 | 
             
                    let mut normalized = NormalizedString::from(sequence);
         | 
| 31 | 
            -
                    self.normalizer | 
| 31 | 
            +
                    self.normalizer
         | 
| 32 | 
            +
                        .normalize(&mut normalized)
         | 
| 33 | 
            +
                        .map_err(RbError::from)?;
         | 
| 32 34 | 
             
                    Ok(normalized.get().to_owned())
         | 
| 33 35 | 
             
                }
         | 
| 34 36 | 
             
            }
         | 
| @@ -43,7 +45,8 @@ macro_rules! getter { | |
| 43 45 | 
             
                ($self: ident, $variant: ident, $name: ident) => {{
         | 
| 44 46 | 
             
                    if let RbNormalizerTypeWrapper::Single(ref norm) = &$self.normalizer {
         | 
| 45 47 | 
             
                        let wrapper = norm.read().unwrap();
         | 
| 46 | 
            -
                        if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone() | 
| 48 | 
            +
                        if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone()
         | 
| 49 | 
            +
                        {
         | 
| 47 50 | 
             
                            o.$name
         | 
| 48 51 | 
             
                        } else {
         | 
| 49 52 | 
             
                            unreachable!()
         | 
| @@ -66,7 +69,6 @@ macro_rules! setter { | |
| 66 69 | 
             
            }
         | 
| 67 70 |  | 
| 68 71 | 
             
            impl RbNormalizer {
         | 
| 69 | 
            -
             | 
| 70 72 | 
             
                fn bert_clean_text(&self) -> bool {
         | 
| 71 73 | 
             
                    getter!(self, BertNormalizer, clean_text)
         | 
| 72 74 | 
             
                }
         | 
| @@ -101,7 +103,7 @@ impl RbNormalizer { | |
| 101 103 | 
             
                }
         | 
| 102 104 |  | 
| 103 105 | 
             
                fn bert_set_lowercase(&self, lowercase: bool) {
         | 
| 104 | 
            -
                    setter!(self, BertNormalizer, lowercase, lowercase)
         | 
| 106 | 
            +
                    setter!(self, BertNormalizer, lowercase, lowercase);
         | 
| 105 107 | 
             
                }
         | 
| 106 108 |  | 
| 107 109 | 
             
                fn prepend_prepend(&self) -> String {
         | 
| @@ -109,7 +111,7 @@ impl RbNormalizer { | |
| 109 111 | 
             
                }
         | 
| 110 112 |  | 
| 111 113 | 
             
                fn prepend_set_prepend(&self, prepend: String) {
         | 
| 112 | 
            -
                    setter!(self, Prepend, prepend, prepend)
         | 
| 114 | 
            +
                    setter!(self, Prepend, prepend, prepend);
         | 
| 113 115 | 
             
                }
         | 
| 114 116 |  | 
| 115 117 | 
             
                fn strip_left(&self) -> bool {
         | 
| @@ -117,7 +119,7 @@ impl RbNormalizer { | |
| 117 119 | 
             
                }
         | 
| 118 120 |  | 
| 119 121 | 
             
                fn strip_set_left(&self, left: bool) {
         | 
| 120 | 
            -
                    setter!(self, StripNormalizer, strip_left, left)
         | 
| 122 | 
            +
                    setter!(self, StripNormalizer, strip_left, left);
         | 
| 121 123 | 
             
                }
         | 
| 122 124 |  | 
| 123 125 | 
             
                fn strip_right(&self) -> bool {
         | 
| @@ -125,14 +127,19 @@ impl RbNormalizer { | |
| 125 127 | 
             
                }
         | 
| 126 128 |  | 
| 127 129 | 
             
                fn strip_set_right(&self, right: bool) {
         | 
| 128 | 
            -
                    setter!(self, StripNormalizer, strip_right, right)
         | 
| 130 | 
            +
                    setter!(self, StripNormalizer, strip_right, right);
         | 
| 129 131 | 
             
                }
         | 
| 130 132 | 
             
            }
         | 
| 131 133 |  | 
| 132 134 | 
             
            pub struct RbBertNormalizer {}
         | 
| 133 135 |  | 
| 134 136 | 
             
            impl RbBertNormalizer {
         | 
| 135 | 
            -
                pub fn new( | 
| 137 | 
            +
                pub fn new(
         | 
| 138 | 
            +
                    clean_text: bool,
         | 
| 139 | 
            +
                    handle_chinese_chars: bool,
         | 
| 140 | 
            +
                    strip_accents: Option<bool>,
         | 
| 141 | 
            +
                    lowercase: bool,
         | 
| 142 | 
            +
                ) -> RbNormalizer {
         | 
| 136 143 | 
             
                    BertNormalizer::new(clean_text, handle_chinese_chars, strip_accents, lowercase).into()
         | 
| 137 144 | 
             
                }
         | 
| 138 145 | 
             
            }
         | 
| @@ -185,11 +192,28 @@ impl RbNmt { | |
| 185 192 | 
             
                }
         | 
| 186 193 | 
             
            }
         | 
| 187 194 |  | 
| 195 | 
            +
            pub struct RbPrecompiled {}
         | 
| 196 | 
            +
             | 
| 197 | 
            +
            impl RbPrecompiled {
         | 
| 198 | 
            +
                pub fn new(precompiled_charsmap: Vec<u8>) -> RbResult<RbNormalizer> {
         | 
| 199 | 
            +
                    Precompiled::from(&precompiled_charsmap)
         | 
| 200 | 
            +
                        .map_err(|e| {
         | 
| 201 | 
            +
                            RbError::new_err(format!(
         | 
| 202 | 
            +
                                "Error while attempting to build Precompiled normalizer: {}",
         | 
| 203 | 
            +
                                e
         | 
| 204 | 
            +
                            ))
         | 
| 205 | 
            +
                        })
         | 
| 206 | 
            +
                        .map(|v| v.into())
         | 
| 207 | 
            +
                }
         | 
| 208 | 
            +
            }
         | 
| 209 | 
            +
             | 
| 188 210 | 
             
            pub struct RbReplace {}
         | 
| 189 211 |  | 
| 190 212 | 
             
            impl RbReplace {
         | 
| 191 213 | 
             
                pub fn new(pattern: RbPattern, content: String) -> RbResult<RbNormalizer> {
         | 
| 192 | 
            -
                    Replace::new(pattern, content) | 
| 214 | 
            +
                    Replace::new(pattern, content)
         | 
| 215 | 
            +
                        .map(|v| v.into())
         | 
| 216 | 
            +
                        .map_err(RbError::from)
         | 
| 193 217 | 
             
                }
         | 
| 194 218 | 
             
            }
         | 
| 195 219 |  | 
| @@ -222,14 +246,16 @@ pub struct RbSequence {} | |
| 222 246 | 
             
            impl RbSequence {
         | 
| 223 247 | 
             
                fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
         | 
| 224 248 | 
             
                    let mut sequence = Vec::with_capacity(normalizers.len());
         | 
| 225 | 
            -
                    for n in normalizers | 
| 249 | 
            +
                    for n in normalizers {
         | 
| 226 250 | 
             
                        let normalizer: &RbNormalizer = TryConvert::try_convert(n)?;
         | 
| 227 251 | 
             
                        match &normalizer.normalizer {
         | 
| 228 252 | 
             
                            RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
         | 
| 229 253 | 
             
                            RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
         | 
| 230 254 | 
             
                        }
         | 
| 231 255 | 
             
                    }
         | 
| 232 | 
            -
                    Ok(RbNormalizer::new(RbNormalizerTypeWrapper::Sequence( | 
| 256 | 
            +
                    Ok(RbNormalizer::new(RbNormalizerTypeWrapper::Sequence(
         | 
| 257 | 
            +
                        sequence,
         | 
| 258 | 
            +
                    )))
         | 
| 233 259 | 
             
                }
         | 
| 234 260 | 
             
            }
         | 
| 235 261 |  | 
| @@ -328,7 +354,10 @@ impl Normalizer for RbNormalizerWrapper { | |
| 328 354 | 
             
            unsafe impl TypedData for RbNormalizer {
         | 
| 329 355 | 
             
                fn class(ruby: &Ruby) -> RClass {
         | 
| 330 356 | 
             
                    static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 331 | 
            -
                        let class: RClass = ruby | 
| 357 | 
            +
                        let class: RClass = ruby
         | 
| 358 | 
            +
                            .get_inner(&NORMALIZERS)
         | 
| 359 | 
            +
                            .const_get("Normalizer")
         | 
| 360 | 
            +
                            .unwrap();
         | 
| 332 361 | 
             
                        class.undef_default_alloc_func();
         | 
| 333 362 | 
             
                        class
         | 
| 334 363 | 
             
                    });
         | 
| @@ -336,7 +365,8 @@ unsafe impl TypedData for RbNormalizer { | |
| 336 365 | 
             
                }
         | 
| 337 366 |  | 
| 338 367 | 
             
                fn data_type() -> &'static DataType {
         | 
| 339 | 
            -
                    static DATA_TYPE: DataType = | 
| 368 | 
            +
                    static DATA_TYPE: DataType =
         | 
| 369 | 
            +
                        data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
         | 
| 340 370 | 
             
                    &DATA_TYPE
         | 
| 341 371 | 
             
                }
         | 
| 342 372 |  | 
| @@ -347,7 +377,10 @@ unsafe impl TypedData for RbNormalizer { | |
| 347 377 | 
             
                        class
         | 
| 348 378 | 
             
                    });
         | 
| 349 379 | 
             
                    static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 350 | 
            -
                        let class: RClass = ruby | 
| 380 | 
            +
                        let class: RClass = ruby
         | 
| 381 | 
            +
                            .get_inner(&NORMALIZERS)
         | 
| 382 | 
            +
                            .const_get("BertNormalizer")
         | 
| 383 | 
            +
                            .unwrap();
         | 
| 351 384 | 
             
                        class.undef_default_alloc_func();
         | 
| 352 385 | 
             
                        class
         | 
| 353 386 | 
             
                    });
         | 
| @@ -381,6 +414,14 @@ unsafe impl TypedData for RbNormalizer { | |
| 381 414 | 
             
                        class.undef_default_alloc_func();
         | 
| 382 415 | 
             
                        class
         | 
| 383 416 | 
             
                    });
         | 
| 417 | 
            +
                    static PRECOMPILED: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 418 | 
            +
                        let class: RClass = ruby
         | 
| 419 | 
            +
                            .get_inner(&NORMALIZERS)
         | 
| 420 | 
            +
                            .const_get("Precompiled")
         | 
| 421 | 
            +
                            .unwrap();
         | 
| 422 | 
            +
                        class.undef_default_alloc_func();
         | 
| 423 | 
            +
                        class
         | 
| 424 | 
            +
                    });
         | 
| 384 425 | 
             
                    static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 385 426 | 
             
                        let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
         | 
| 386 427 | 
             
                        class.undef_default_alloc_func();
         | 
| @@ -397,7 +438,10 @@ unsafe impl TypedData for RbNormalizer { | |
| 397 438 | 
             
                        class
         | 
| 398 439 | 
             
                    });
         | 
| 399 440 | 
             
                    static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
         | 
| 400 | 
            -
                        let class: RClass = ruby | 
| 441 | 
            +
                        let class: RClass = ruby
         | 
| 442 | 
            +
                            .get_inner(&NORMALIZERS)
         | 
| 443 | 
            +
                            .const_get("StripAccents")
         | 
| 444 | 
            +
                            .unwrap();
         | 
| 401 445 | 
             
                        class.undef_default_alloc_func();
         | 
| 402 446 | 
             
                        class
         | 
| 403 447 | 
             
                    });
         | 
| @@ -412,6 +456,7 @@ unsafe impl TypedData for RbNormalizer { | |
| 412 456 | 
             
                                NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
         | 
| 413 457 | 
             
                                NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
         | 
| 414 458 | 
             
                                NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
         | 
| 459 | 
            +
                                NormalizerWrapper::Precompiled(_) => ruby.get_inner(&PRECOMPILED),
         | 
| 415 460 | 
             
                                NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
         | 
| 416 461 | 
             
                                NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
         | 
| 417 462 | 
             
                                NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
         | 
| @@ -434,10 +479,22 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> { | |
| 434 479 | 
             
                class.define_singleton_method("_new", function!(RbBertNormalizer::new, 4))?;
         | 
| 435 480 | 
             
                class.define_method("clean_text", method!(RbNormalizer::bert_clean_text, 0))?;
         | 
| 436 481 | 
             
                class.define_method("clean_text=", method!(RbNormalizer::bert_set_clean_text, 1))?;
         | 
| 437 | 
            -
                class.define_method( | 
| 438 | 
            -
             | 
| 439 | 
            -
             | 
| 440 | 
            -
                 | 
| 482 | 
            +
                class.define_method(
         | 
| 483 | 
            +
                    "handle_chinese_chars",
         | 
| 484 | 
            +
                    method!(RbNormalizer::bert_handle_chinese_chars, 0),
         | 
| 485 | 
            +
                )?;
         | 
| 486 | 
            +
                class.define_method(
         | 
| 487 | 
            +
                    "handle_chinese_chars=",
         | 
| 488 | 
            +
                    method!(RbNormalizer::bert_set_handle_chinese_chars, 1),
         | 
| 489 | 
            +
                )?;
         | 
| 490 | 
            +
                class.define_method(
         | 
| 491 | 
            +
                    "strip_accents",
         | 
| 492 | 
            +
                    method!(RbNormalizer::bert_strip_accents, 0),
         | 
| 493 | 
            +
                )?;
         | 
| 494 | 
            +
                class.define_method(
         | 
| 495 | 
            +
                    "strip_accents=",
         | 
| 496 | 
            +
                    method!(RbNormalizer::bert_set_strip_accents, 1),
         | 
| 497 | 
            +
                )?;
         | 
| 441 498 | 
             
                class.define_method("lowercase", method!(RbNormalizer::bert_lowercase, 0))?;
         | 
| 442 499 | 
             
                class.define_method("lowercase=", method!(RbNormalizer::bert_set_lowercase, 1))?;
         | 
| 443 500 |  | 
| @@ -459,6 +516,9 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> { | |
| 459 516 | 
             
                let class = module.define_class("Nmt", normalizer)?;
         | 
| 460 517 | 
             
                class.define_singleton_method("new", function!(RbNmt::new, 0))?;
         | 
| 461 518 |  | 
| 519 | 
            +
                let class = module.define_class("Precompiled", normalizer)?;
         | 
| 520 | 
            +
                class.define_singleton_method("new", function!(RbPrecompiled::new, 1))?;
         | 
| 521 | 
            +
             | 
| 462 522 | 
             
                let class = module.define_class("Replace", normalizer)?;
         | 
| 463 523 | 
             
                class.define_singleton_method("new", function!(RbReplace::new, 2))?;
         | 
| 464 524 |  |