tokenizers 0.5.2 → 0.5.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/Cargo.lock +154 -83
- data/ext/tokenizers/Cargo.toml +2 -2
- data/ext/tokenizers/src/decoders.rs +32 -14
- data/ext/tokenizers/src/error.rs +6 -1
- data/ext/tokenizers/src/lib.rs +47 -12
- data/ext/tokenizers/src/models.rs +75 -23
- data/ext/tokenizers/src/normalizers.rs +84 -24
- data/ext/tokenizers/src/pre_tokenizers.rs +121 -42
- data/ext/tokenizers/src/processors.rs +22 -10
- data/ext/tokenizers/src/tokenizer.rs +141 -39
- data/ext/tokenizers/src/trainers.rs +215 -56
- data/ext/tokenizers/src/utils/regex.rs +6 -4
- data/lib/tokenizers/added_token.rb +7 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +1 -0
- metadata +4 -7
data/ext/tokenizers/src/error.rs
CHANGED
@@ -9,9 +9,14 @@ impl RbError {
     pub fn from(e: Box<dyn std::error::Error + Send + Sync>) -> Error {
         Error::new(error(), e.to_string())
     }
+
+    pub fn new_err(s: String) -> Error {
+        Error::new(error(), s)
+    }
 }

-static ERROR: Lazy<ExceptionClass> =
+static ERROR: Lazy<ExceptionClass> =
+    Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Error").unwrap());

 fn error() -> ExceptionClass {
     Ruby::get().unwrap().get_inner(&ERROR)
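For context, a minimal sketch of how the two constructors divide the work inside the extension crate: `from` wraps an upstream error unchanged, while the new `new_err` raises `Tokenizers::Error` with a custom message. The helper function, path argument, and message text below are hypothetical and not part of the gem.

```rust
// Hypothetical helper (not in the gem): illustrates RbError::from vs. RbError::new_err.
use magnus::Error;

use crate::error::RbError;

fn vocab_line_count(path: &str) -> Result<usize, Error> {
    // An upstream std error is wrapped unchanged via `from`.
    let contents = std::fs::read_to_string(path).map_err(|e| RbError::from(Box::new(e)))?;
    if contents.is_empty() {
        // A domain-specific failure gets a custom message via `new_err`.
        return Err(RbError::new_err(format!("empty vocab file: {}", path)));
    }
    Ok(contents.lines().count())
}
```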
data/ext/tokenizers/src/lib.rs
CHANGED
@@ -15,26 +15,39 @@ mod utils;

 use encoding::RbEncoding;
 use error::RbError;
-use tokenizer::RbTokenizer;
+use tokenizer::{RbAddedToken, RbTokenizer};
 use utils::RbRegex;

 use magnus::{function, method, prelude::*, value::Lazy, Error, RModule, Ruby};

 type RbResult<T> = Result<T, Error>;

-static TOKENIZERS: Lazy<RModule> =
+static TOKENIZERS: Lazy<RModule> =
+    Lazy::new(|ruby| ruby.class_object().const_get("Tokenizers").unwrap());

-static DECODERS: Lazy<RModule> =
+static DECODERS: Lazy<RModule> =
+    Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Decoders").unwrap());

-static MODELS: Lazy<RModule> =
+static MODELS: Lazy<RModule> =
+    Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Models").unwrap());

-static NORMALIZERS: Lazy<RModule> = Lazy::new(|ruby|
+static NORMALIZERS: Lazy<RModule> = Lazy::new(|ruby| {
+    ruby.get_inner(&TOKENIZERS)
+        .const_get("Normalizers")
+        .unwrap()
+});

-static PRE_TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby|
+static PRE_TOKENIZERS: Lazy<RModule> = Lazy::new(|ruby| {
+    ruby.get_inner(&TOKENIZERS)
+        .const_get("PreTokenizers")
+        .unwrap()
+});

-static PROCESSORS: Lazy<RModule> =
+static PROCESSORS: Lazy<RModule> =
+    Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Processors").unwrap());

-static TRAINERS: Lazy<RModule> =
+static TRAINERS: Lazy<RModule> =
+    Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Trainers").unwrap());

 #[magnus::init]
 fn init(ruby: &Ruby) -> RbResult<()> {
@@ -56,12 +69,15 @@ fn init(ruby: &Ruby) -> RbResult<()> {
     class.define_method("_decode", method!(RbTokenizer::decode, 2))?;
     class.define_method("_decode_batch", method!(RbTokenizer::decode_batch, 2))?;
     class.define_method("model", method!(RbTokenizer::get_model, 0))?;
-    class.define_method("model=", method!(RbTokenizer::set_model,1))?;
+    class.define_method("model=", method!(RbTokenizer::set_model, 1))?;
     class.define_method("decoder", method!(RbTokenizer::get_decoder, 0))?;
     class.define_method("decoder=", method!(RbTokenizer::set_decoder, 1))?;
     class.define_method("pre_tokenizer", method!(RbTokenizer::get_pre_tokenizer, 0))?;
     class.define_method("pre_tokenizer=", method!(RbTokenizer::set_pre_tokenizer, 1))?;
-    class.define_method(
+    class.define_method(
+        "post_processor",
+        method!(RbTokenizer::get_post_processor, 0),
+    )?;
     class.define_method(
         "post_processor=",
         method!(RbTokenizer::set_post_processor, 1),
@@ -73,12 +89,22 @@ fn init(ruby: &Ruby) -> RbResult<()> {
     class.define_method("_enable_padding", method!(RbTokenizer::enable_padding, 1))?;
     class.define_method("padding", method!(RbTokenizer::padding, 0))?;
     class.define_method("no_padding", method!(RbTokenizer::no_padding, 0))?;
-    class.define_method(
+    class.define_method(
+        "_enable_truncation",
+        method!(RbTokenizer::enable_truncation, 2),
+    )?;
     class.define_method("truncation", method!(RbTokenizer::truncation, 0))?;
     class.define_method("no_truncation", method!(RbTokenizer::no_truncation, 0))?;
-    class.define_method(
+    class.define_method(
+        "num_special_tokens_to_add",
+        method!(RbTokenizer::num_special_tokens_to_add, 1),
+    )?;
     class.define_method("_vocab", method!(RbTokenizer::vocab, 1))?;
     class.define_method("_vocab_size", method!(RbTokenizer::vocab_size, 1))?;
+    class.define_method(
+        "added_tokens_decoder",
+        method!(RbTokenizer::get_added_tokens_decoder, 0),
+    )?;
     class.define_method("_to_s", method!(RbTokenizer::to_str, 1))?;

     let class = module.define_class("Encoding", ruby.class_object())?;
@@ -109,6 +135,15 @@ fn init(ruby: &Ruby) -> RbResult<()> {
     let class = module.define_class("Regex", ruby.class_object())?;
     class.define_singleton_method("new", function!(RbRegex::new, 1))?;

+    let class = module.define_class("AddedToken", ruby.class_object())?;
+    class.define_singleton_method("_new", function!(RbAddedToken::new, 2))?;
+    class.define_method("content", method!(RbAddedToken::get_content, 0))?;
+    class.define_method("rstrip", method!(RbAddedToken::get_rstrip, 0))?;
+    class.define_method("lstrip", method!(RbAddedToken::get_lstrip, 0))?;
+    class.define_method("single_word", method!(RbAddedToken::get_single_word, 0))?;
+    class.define_method("normalized", method!(RbAddedToken::get_normalized, 0))?;
+    class.define_method("special", method!(RbAddedToken::get_special, 0))?;
+
     let models = module.define_module("Models")?;
     let pre_tokenizers = module.define_module("PreTokenizers")?;
     let decoders = module.define_module("Decoders")?;
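The `AddedToken` class registered above exposes read-only accessors. Its Rust side (`RbAddedToken`) lives in tokenizer.rs, which this diff does not show; as a rough sketch under the assumption that it wraps `tk::AddedToken`, the getters could be plain field reads. The struct layout and the omitted magnus glue below are assumptions, not the gem's actual code.

```rust
// Sketch only: assumed shape of the RbAddedToken getters registered above.
// The real struct, its 2-argument constructor, and the magnus TypedData/wrap
// glue are in tokenizer.rs, which is not part of this diff.
pub struct RbAddedToken {
    token: tk::AddedToken,
}

impl RbAddedToken {
    pub fn get_content(&self) -> String {
        self.token.content.clone()
    }

    pub fn get_special(&self) -> bool {
        self.token.special
    }

    // get_rstrip, get_lstrip, get_single_word, and get_normalized would
    // follow the same field-read pattern on tk::AddedToken.
}
```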
data/ext/tokenizers/src/models.rs
CHANGED
@@ -5,18 +5,19 @@ use std::sync::{Arc, RwLock};
 use crate::trainers::RbTrainer;
 use magnus::prelude::*;
 use magnus::{
-    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
-    RClass, RHash, RModule, Ruby, Symbol, TryConvert,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType,
+    DataTypeFunctions, Error, Module, Object, RClass, RHash, RModule, Ruby, Symbol, TryConvert,
+    TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
-use tk::models::ModelWrapper;
 use tk::models::unigram::Unigram;
 use tk::models::wordlevel::WordLevel;
 use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
+use tk::models::ModelWrapper;
 use tk::{Model, Token};

-use super::{
+use super::{RbError, RbResult, MODELS};

 #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
 pub struct RbModel {
@@ -187,7 +188,12 @@ impl RbModel {
     }

     pub fn bpe_set_continuing_subword_prefix(&self, continuing_subword_prefix: Option<String>) {
-        setter!(
+        setter!(
+            self,
+            BPE,
+            continuing_subword_prefix,
+            continuing_subword_prefix
+        );
     }

     pub fn bpe_end_of_word_suffix(&self) -> Option<String> {
@@ -219,7 +225,12 @@ impl RbModel {
     }

     pub fn word_piece_set_continuing_subword_prefix(&self, continuing_subword_prefix: String) {
-        setter!(
+        setter!(
+            self,
+            WordPiece,
+            continuing_subword_prefix,
+            continuing_subword_prefix
+        );
     }

     pub fn word_piece_max_input_chars_per_word(&self) -> usize {
@@ -227,21 +238,34 @@ impl RbModel {
     }

     pub fn word_piece_set_max_input_chars_per_word(&self, max_input_chars_per_word: usize) {
-        setter!(
+        setter!(
+            self,
+            WordPiece,
+            max_input_chars_per_word,
+            max_input_chars_per_word
+        );
     }
 }

 pub struct RbUnigram {}

 impl RbUnigram {
-    fn new(
+    fn new(
+        vocab: Option<Vec<(String, f64)>>,
+        unk_id: Option<usize>,
+        byte_fallback: Option<bool>,
+    ) -> RbResult<RbModel> {
         match (vocab, unk_id, byte_fallback) {
             (Some(vocab), unk_id, byte_fallback) => {
-                let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false))
+                let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false))
+                    .map_err(RbError::from)?;
                 Ok(model.into())
             }
             (None, None, _) => Ok(Unigram::default().into()),
-            _ => Err(Error::new(
+            _ => Err(Error::new(
+                exception::arg_error(),
+                "`vocab` and `unk_id` must be both specified",
+            )),
         }
     }
 }
@@ -249,7 +273,10 @@ impl RbUnigram {
 pub struct RbWordLevel {}

 impl RbWordLevel {
-    pub fn new(
+    pub fn new(
+        vocab: Option<HashMap<String, u32>>,
+        unk_token: Option<String>,
+    ) -> RbResult<RbModel> {
         let mut builder = WordLevel::builder();
         if let Some(vocab) = vocab {
             builder = builder.vocab(vocab);
@@ -316,15 +343,16 @@ impl RbWordPiece {
 unsafe impl TypedData for RbModel {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-
-
-
+            let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
+            class.undef_default_alloc_func();
+            class
         });
         ruby.get_inner(&CLASS)
     }

     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType =
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
         &DATA_TYPE
     }

@@ -368,10 +396,22 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     class.define_method("dropout=", method!(RbModel::bpe_set_dropout, 1))?;
     class.define_method("unk_token", method!(RbModel::bpe_unk_token, 0))?;
     class.define_method("unk_token=", method!(RbModel::bpe_set_unk_token, 1))?;
-    class.define_method(
-
-
-
+    class.define_method(
+        "continuing_subword_prefix",
+        method!(RbModel::bpe_continuing_subword_prefix, 0),
+    )?;
+    class.define_method(
+        "continuing_subword_prefix=",
+        method!(RbModel::bpe_set_continuing_subword_prefix, 1),
+    )?;
+    class.define_method(
+        "end_of_word_suffix",
+        method!(RbModel::bpe_end_of_word_suffix, 0),
+    )?;
+    class.define_method(
+        "end_of_word_suffix=",
+        method!(RbModel::bpe_set_end_of_word_suffix, 1),
+    )?;
     class.define_method("fuse_unk", method!(RbModel::bpe_fuse_unk, 0))?;
     class.define_method("fuse_unk=", method!(RbModel::bpe_set_fuse_unk, 1))?;
     class.define_method("byte_fallback", method!(RbModel::bpe_byte_fallback, 0))?;
@@ -392,10 +432,22 @@ pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     class.define_singleton_method("_from_file", function!(RbWordPiece::from_file, 2))?;
     class.define_method("unk_token", method!(RbModel::word_piece_unk_token, 0))?;
     class.define_method("unk_token=", method!(RbModel::word_piece_set_unk_token, 1))?;
-    class.define_method(
-
-
-
+    class.define_method(
+        "continuing_subword_prefix",
+        method!(RbModel::word_piece_continuing_subword_prefix, 0),
+    )?;
+    class.define_method(
+        "continuing_subword_prefix=",
+        method!(RbModel::word_piece_set_continuing_subword_prefix, 1),
+    )?;
+    class.define_method(
+        "max_input_chars_per_word",
+        method!(RbModel::word_piece_max_input_chars_per_word, 0),
+    )?;
+    class.define_method(
+        "max_input_chars_per_word=",
+        method!(RbModel::word_piece_set_max_input_chars_per_word, 1),
+    )?;

     Ok(())
 }
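The reformatted `RbUnigram::new` above forwards to `Unigram::from` in the `tk` (tokenizers) crate and maps failures through `RbError::from`. A standalone sketch of that underlying call, with a made-up toy vocabulary:

```rust
// Toy example of the tk call that RbUnigram::new wraps above; the vocabulary
// entries and scores are invented for illustration only.
use tk::models::unigram::Unigram;

fn main() {
    let vocab = vec![
        ("<unk>".to_string(), 0.0),
        ("hello".to_string(), -1.5),
        ("world".to_string(), -2.0),
    ];
    // unk_id = Some(0) points at "<unk>"; the binding defaults byte_fallback to false.
    match Unigram::from(vocab, Some(0), false) {
        Ok(_model) => println!("built a Unigram model"),
        Err(e) => eprintln!("failed to build Unigram: {}", e),
    }
}
```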
data/ext/tokenizers/src/normalizers.rs
CHANGED
@@ -1,19 +1,19 @@
 use std::sync::{Arc, RwLock};

 use magnus::{
-    data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module,
-    Ruby, TryConvert, TypedData,
+    data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module,
+    Object, RArray, RClass, RModule, Ruby, TryConvert, TypedData,
 };
 use serde::ser::SerializeStruct;
 use serde::{Deserialize, Serialize, Serializer};
 use tk::normalizers::{
-    BertNormalizer, Lowercase, Nmt, NormalizerWrapper,
-    NFC, NFD, NFKC, NFKD,
+    BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace, Strip,
+    StripAccents, NFC, NFD, NFKC, NFKD,
 };
 use tk::{NormalizedString, Normalizer};

 use super::utils::*;
-use super::{
+use super::{RbError, RbResult, NORMALIZERS};

 #[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
 pub struct RbNormalizer {
@@ -28,7 +28,9 @@ impl RbNormalizer {

     pub fn normalize_str(&self, sequence: String) -> RbResult<String> {
         let mut normalized = NormalizedString::from(sequence);
-        self.normalizer
+        self.normalizer
+            .normalize(&mut normalized)
+            .map_err(RbError::from)?;
         Ok(normalized.get().to_owned())
     }
 }
@@ -43,7 +45,8 @@ macro_rules! getter {
     ($self: ident, $variant: ident, $name: ident) => {{
         if let RbNormalizerTypeWrapper::Single(ref norm) = &$self.normalizer {
             let wrapper = norm.read().unwrap();
-            if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone()
+            if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone()
+            {
                 o.$name
             } else {
                 unreachable!()
@@ -66,7 +69,6 @@ macro_rules! setter {
 }

 impl RbNormalizer {
-
     fn bert_clean_text(&self) -> bool {
         getter!(self, BertNormalizer, clean_text)
     }
@@ -101,7 +103,7 @@ impl RbNormalizer {
     }

     fn bert_set_lowercase(&self, lowercase: bool) {
-        setter!(self, BertNormalizer, lowercase, lowercase)
+        setter!(self, BertNormalizer, lowercase, lowercase);
     }

     fn prepend_prepend(&self) -> String {
@@ -109,7 +111,7 @@ impl RbNormalizer {
     }

     fn prepend_set_prepend(&self, prepend: String) {
-        setter!(self, Prepend, prepend, prepend)
+        setter!(self, Prepend, prepend, prepend);
     }

     fn strip_left(&self) -> bool {
@@ -117,7 +119,7 @@ impl RbNormalizer {
     }

     fn strip_set_left(&self, left: bool) {
-        setter!(self, StripNormalizer, strip_left, left)
+        setter!(self, StripNormalizer, strip_left, left);
     }

     fn strip_right(&self) -> bool {
@@ -125,14 +127,19 @@ impl RbNormalizer {
     }

     fn strip_set_right(&self, right: bool) {
-        setter!(self, StripNormalizer, strip_right, right)
+        setter!(self, StripNormalizer, strip_right, right);
     }
 }

 pub struct RbBertNormalizer {}

 impl RbBertNormalizer {
-    pub fn new(
+    pub fn new(
+        clean_text: bool,
+        handle_chinese_chars: bool,
+        strip_accents: Option<bool>,
+        lowercase: bool,
+    ) -> RbNormalizer {
         BertNormalizer::new(clean_text, handle_chinese_chars, strip_accents, lowercase).into()
     }
 }
@@ -185,11 +192,28 @@ impl RbNmt {
     }
 }

+pub struct RbPrecompiled {}
+
+impl RbPrecompiled {
+    pub fn new(precompiled_charsmap: Vec<u8>) -> RbResult<RbNormalizer> {
+        Precompiled::from(&precompiled_charsmap)
+            .map_err(|e| {
+                RbError::new_err(format!(
+                    "Error while attempting to build Precompiled normalizer: {}",
+                    e
+                ))
+            })
+            .map(|v| v.into())
+    }
+}
+
 pub struct RbReplace {}

 impl RbReplace {
     pub fn new(pattern: RbPattern, content: String) -> RbResult<RbNormalizer> {
-        Replace::new(pattern, content)
+        Replace::new(pattern, content)
+            .map(|v| v.into())
+            .map_err(RbError::from)
     }
 }

@@ -222,14 +246,16 @@ pub struct RbSequence {}
 impl RbSequence {
     fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
         let mut sequence = Vec::with_capacity(normalizers.len());
-        for n in normalizers
+        for n in normalizers {
             let normalizer: &RbNormalizer = TryConvert::try_convert(n)?;
             match &normalizer.normalizer {
                 RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
                 RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
-        Ok(RbNormalizer::new(RbNormalizerTypeWrapper::Sequence(
+        Ok(RbNormalizer::new(RbNormalizerTypeWrapper::Sequence(
+            sequence,
+        )))
     }
 }

@@ -328,7 +354,10 @@ impl Normalizer for RbNormalizerWrapper {
 unsafe impl TypedData for RbNormalizer {
     fn class(ruby: &Ruby) -> RClass {
         static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&NORMALIZERS)
+                .const_get("Normalizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -336,7 +365,8 @@ unsafe impl TypedData for RbNormalizer {
     }

     fn data_type() -> &'static DataType {
-        static DATA_TYPE: DataType =
+        static DATA_TYPE: DataType =
+            data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
         &DATA_TYPE
     }

@@ -347,7 +377,10 @@ unsafe impl TypedData for RbNormalizer {
             class
         });
         static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&NORMALIZERS)
+                .const_get("BertNormalizer")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -381,6 +414,14 @@ unsafe impl TypedData for RbNormalizer {
             class.undef_default_alloc_func();
             class
         });
+        static PRECOMPILED: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby
+                .get_inner(&NORMALIZERS)
+                .const_get("Precompiled")
+                .unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
         static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
             let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
             class.undef_default_alloc_func();
@@ -397,7 +438,10 @@ unsafe impl TypedData for RbNormalizer {
             class
         });
         static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
-            let class: RClass = ruby
+            let class: RClass = ruby
+                .get_inner(&NORMALIZERS)
+                .const_get("StripAccents")
+                .unwrap();
             class.undef_default_alloc_func();
             class
         });
@@ -412,6 +456,7 @@ unsafe impl TypedData for RbNormalizer {
             NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
             NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
             NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
+            NormalizerWrapper::Precompiled(_) => ruby.get_inner(&PRECOMPILED),
             NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
             NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
             NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
@@ -434,10 +479,22 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     class.define_singleton_method("_new", function!(RbBertNormalizer::new, 4))?;
     class.define_method("clean_text", method!(RbNormalizer::bert_clean_text, 0))?;
     class.define_method("clean_text=", method!(RbNormalizer::bert_set_clean_text, 1))?;
-    class.define_method(
-
-
-
+    class.define_method(
+        "handle_chinese_chars",
+        method!(RbNormalizer::bert_handle_chinese_chars, 0),
+    )?;
+    class.define_method(
+        "handle_chinese_chars=",
+        method!(RbNormalizer::bert_set_handle_chinese_chars, 1),
+    )?;
+    class.define_method(
+        "strip_accents",
+        method!(RbNormalizer::bert_strip_accents, 0),
+    )?;
+    class.define_method(
+        "strip_accents=",
+        method!(RbNormalizer::bert_set_strip_accents, 1),
+    )?;
     class.define_method("lowercase", method!(RbNormalizer::bert_lowercase, 0))?;
     class.define_method("lowercase=", method!(RbNormalizer::bert_set_lowercase, 1))?;

@@ -459,6 +516,9 @@ pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
     let class = module.define_class("Nmt", normalizer)?;
     class.define_singleton_method("new", function!(RbNmt::new, 0))?;

+    let class = module.define_class("Precompiled", normalizer)?;
+    class.define_singleton_method("new", function!(RbPrecompiled::new, 1))?;
+
     let class = module.define_class("Replace", normalizer)?;
     class.define_singleton_method("new", function!(RbReplace::new, 2))?;

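The new `Precompiled` normalizer is built from a SentencePiece `precompiled_charsmap` blob, which cannot be usefully faked here, but the `normalize_str` flow shown in this file applies to every `tk` normalizer in the same way. A small sketch using `Lowercase`, with an arbitrarily chosen input string:

```rust
// Mirrors RbNormalizer::normalize_str above with a concrete tk normalizer.
use tk::normalizers::Lowercase;
use tk::{NormalizedString, Normalizer};

fn main() {
    let mut normalized = NormalizedString::from("Héllo WORLD");
    Lowercase
        .normalize(&mut normalized)
        .expect("normalization failed");
    // Prints the normalized text, e.g. "héllo world".
    println!("{}", normalized.get());
}
```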