tokenizers 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/Cargo.lock +160 -96
- data/ext/tokenizers/Cargo.toml +6 -6
- data/ext/tokenizers/src/decoders.rs +149 -39
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +71 -50
- data/ext/tokenizers/src/normalizers.rs +113 -74
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/decoders/strip.rb +9 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/normalizers/prepend.rb +9 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +4 -2
- metadata +6 -4
@@ -1,20 +1,19 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
|
6
|
-
TypedData,
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
|
5
|
+
Ruby, TryConvert, TypedData,
|
7
6
|
};
|
8
7
|
use serde::ser::SerializeStruct;
|
9
8
|
use serde::{Deserialize, Serialize, Serializer};
|
10
9
|
use tk::normalizers::{
|
11
|
-
BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Strip, StripAccents,
|
10
|
+
BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Prepend, Strip, StripAccents,
|
12
11
|
NFC, NFD, NFKC, NFKD,
|
13
12
|
};
|
14
13
|
use tk::{NormalizedString, Normalizer};
|
15
14
|
|
16
15
|
use super::utils::*;
|
17
|
-
use super::{RbError, RbResult};
|
16
|
+
use super::{NORMALIZERS, RbError, RbResult};
|
18
17
|
|
19
18
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
20
19
|
pub struct RbNormalizer {
|
@@ -44,7 +43,7 @@ macro_rules! getter {
|
|
44
43
|
($self: ident, $variant: ident, $name: ident) => {{
|
45
44
|
if let RbNormalizerTypeWrapper::Single(ref norm) = &$self.normalizer {
|
46
45
|
let wrapper = norm.read().unwrap();
|
47
|
-
if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = *wrapper {
|
46
|
+
if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone() {
|
48
47
|
o.$name
|
49
48
|
} else {
|
50
49
|
unreachable!()
|
@@ -105,6 +104,14 @@ impl RbNormalizer {
|
|
105
104
|
setter!(self, BertNormalizer, lowercase, lowercase)
|
106
105
|
}
|
107
106
|
|
107
|
+
fn prepend_prepend(&self) -> String {
|
108
|
+
getter!(self, Prepend, prepend)
|
109
|
+
}
|
110
|
+
|
111
|
+
fn prepend_set_prepend(&self, prepend: String) {
|
112
|
+
setter!(self, Prepend, prepend, prepend)
|
113
|
+
}
|
114
|
+
|
108
115
|
fn strip_left(&self) -> bool {
|
109
116
|
getter!(self, StripNormalizer, strip_left)
|
110
117
|
}
|
@@ -186,6 +193,14 @@ impl RbReplace {
|
|
186
193
|
}
|
187
194
|
}
|
188
195
|
|
196
|
+
pub struct RbPrepend {}
|
197
|
+
|
198
|
+
impl RbPrepend {
|
199
|
+
pub fn new(prepend: String) -> RbNormalizer {
|
200
|
+
Prepend::new(prepend).into()
|
201
|
+
}
|
202
|
+
}
|
203
|
+
|
189
204
|
pub struct RbStrip {}
|
190
205
|
|
191
206
|
impl RbStrip {
|
@@ -208,7 +223,7 @@ impl RbSequence {
|
|
208
223
|
fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
|
209
224
|
let mut sequence = Vec::with_capacity(normalizers.len());
|
210
225
|
for n in normalizers.each() {
|
211
|
-
let normalizer: &RbNormalizer =
|
226
|
+
let normalizer: &RbNormalizer = TryConvert::try_convert(n?)?;
|
212
227
|
match &normalizer.normalizer {
|
213
228
|
RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
|
214
229
|
RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
|
@@ -311,77 +326,96 @@ impl Normalizer for RbNormalizerWrapper {
|
|
311
326
|
}
|
312
327
|
|
313
328
|
unsafe impl TypedData for RbNormalizer {
|
314
|
-
fn class() -> RClass {
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
})
|
329
|
+
fn class(ruby: &Ruby) -> RClass {
|
330
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
331
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Normalizer").unwrap();
|
332
|
+
class.undef_default_alloc_func();
|
333
|
+
class
|
334
|
+
});
|
335
|
+
ruby.get_inner(&CLASS)
|
320
336
|
}
|
321
337
|
|
322
338
|
fn data_type() -> &'static DataType {
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
339
|
+
static DATA_TYPE: DataType = data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
|
340
|
+
&DATA_TYPE
|
341
|
+
}
|
342
|
+
|
343
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
344
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
345
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Sequence").unwrap();
|
346
|
+
class.undef_default_alloc_func();
|
347
|
+
class
|
348
|
+
});
|
349
|
+
static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
|
350
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("BertNormalizer").unwrap();
|
351
|
+
class.undef_default_alloc_func();
|
352
|
+
class
|
353
|
+
});
|
354
|
+
static LOWERCASE: Lazy<RClass> = Lazy::new(|ruby| {
|
355
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Lowercase").unwrap();
|
356
|
+
class.undef_default_alloc_func();
|
357
|
+
class
|
358
|
+
});
|
359
|
+
static NFD: Lazy<RClass> = Lazy::new(|ruby| {
|
360
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFD").unwrap();
|
361
|
+
class.undef_default_alloc_func();
|
362
|
+
class
|
363
|
+
});
|
364
|
+
static NFC: Lazy<RClass> = Lazy::new(|ruby| {
|
365
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFC").unwrap();
|
366
|
+
class.undef_default_alloc_func();
|
367
|
+
class
|
368
|
+
});
|
369
|
+
static NFKC: Lazy<RClass> = Lazy::new(|ruby| {
|
370
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKC").unwrap();
|
371
|
+
class.undef_default_alloc_func();
|
372
|
+
class
|
373
|
+
});
|
374
|
+
static NFKD: Lazy<RClass> = Lazy::new(|ruby| {
|
375
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKD").unwrap();
|
376
|
+
class.undef_default_alloc_func();
|
377
|
+
class
|
378
|
+
});
|
379
|
+
static NMT: Lazy<RClass> = Lazy::new(|ruby| {
|
380
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Nmt").unwrap();
|
381
|
+
class.undef_default_alloc_func();
|
382
|
+
class
|
383
|
+
});
|
384
|
+
static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
|
385
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
|
386
|
+
class.undef_default_alloc_func();
|
387
|
+
class
|
388
|
+
});
|
389
|
+
static PREPEND: Lazy<RClass> = Lazy::new(|ruby| {
|
390
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Prepend").unwrap();
|
391
|
+
class.undef_default_alloc_func();
|
392
|
+
class
|
393
|
+
});
|
394
|
+
static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
|
395
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Strip").unwrap();
|
396
|
+
class.undef_default_alloc_func();
|
397
|
+
class
|
398
|
+
});
|
399
|
+
static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
|
400
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("StripAccents").unwrap();
|
401
|
+
class.undef_default_alloc_func();
|
402
|
+
class
|
403
|
+
});
|
327
404
|
match &value.normalizer {
|
328
|
-
RbNormalizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
|
329
|
-
let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
|
330
|
-
class.undef_alloc_func();
|
331
|
-
class
|
332
|
-
}),
|
405
|
+
RbNormalizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
|
333
406
|
RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
334
407
|
RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
|
335
|
-
NormalizerWrapper::BertNormalizer(_) =>
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
NormalizerWrapper::
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
NormalizerWrapper::NFD(_) => *memoize!(RClass: {
|
346
|
-
let class: RClass = crate::normalizers().const_get("NFD").unwrap();
|
347
|
-
class.undef_alloc_func();
|
348
|
-
class
|
349
|
-
}),
|
350
|
-
NormalizerWrapper::NFC(_) => *memoize!(RClass: {
|
351
|
-
let class: RClass = crate::normalizers().const_get("NFC").unwrap();
|
352
|
-
class.undef_alloc_func();
|
353
|
-
class
|
354
|
-
}),
|
355
|
-
NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
|
356
|
-
let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
|
357
|
-
class.undef_alloc_func();
|
358
|
-
class
|
359
|
-
}),
|
360
|
-
NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
|
361
|
-
let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
|
362
|
-
class.undef_alloc_func();
|
363
|
-
class
|
364
|
-
}),
|
365
|
-
NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
|
366
|
-
let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
|
367
|
-
class.undef_alloc_func();
|
368
|
-
class
|
369
|
-
}),
|
370
|
-
NormalizerWrapper::Replace(_) => *memoize!(RClass: {
|
371
|
-
let class: RClass = crate::normalizers().const_get("Replace").unwrap();
|
372
|
-
class.undef_alloc_func();
|
373
|
-
class
|
374
|
-
}),
|
375
|
-
NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
|
376
|
-
let class: RClass = crate::normalizers().const_get("Strip").unwrap();
|
377
|
-
class.undef_alloc_func();
|
378
|
-
class
|
379
|
-
}),
|
380
|
-
NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
|
381
|
-
let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
|
382
|
-
class.undef_alloc_func();
|
383
|
-
class
|
384
|
-
}),
|
408
|
+
NormalizerWrapper::BertNormalizer(_) => ruby.get_inner(&BERT_NORMALIZER),
|
409
|
+
NormalizerWrapper::Lowercase(_) => ruby.get_inner(&LOWERCASE),
|
410
|
+
NormalizerWrapper::NFD(_) => ruby.get_inner(&NFD),
|
411
|
+
NormalizerWrapper::NFC(_) => ruby.get_inner(&NFC),
|
412
|
+
NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
|
413
|
+
NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
|
414
|
+
NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
|
415
|
+
NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
|
416
|
+
NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
|
417
|
+
NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
|
418
|
+
NormalizerWrapper::StripAccents(_) => ruby.get_inner(&STRIP_ACCENTS),
|
385
419
|
_ => todo!(),
|
386
420
|
},
|
387
421
|
},
|
@@ -389,8 +423,8 @@ unsafe impl TypedData for RbNormalizer {
|
|
389
423
|
}
|
390
424
|
}
|
391
425
|
|
392
|
-
pub fn normalizers(module: &RModule) -> RbResult<()> {
|
393
|
-
let normalizer = module.define_class("Normalizer",
|
426
|
+
pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
427
|
+
let normalizer = module.define_class("Normalizer", ruby.class_object())?;
|
394
428
|
normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
|
395
429
|
|
396
430
|
let class = module.define_class("Sequence", normalizer)?;
|
@@ -428,6 +462,11 @@ pub fn normalizers(module: &RModule) -> RbResult<()> {
|
|
428
462
|
let class = module.define_class("Replace", normalizer)?;
|
429
463
|
class.define_singleton_method("new", function!(RbReplace::new, 2))?;
|
430
464
|
|
465
|
+
let class = module.define_class("Prepend", normalizer)?;
|
466
|
+
class.define_singleton_method("_new", function!(RbPrepend::new, 1))?;
|
467
|
+
class.define_method("prepend", method!(RbNormalizer::prepend_prepend, 0))?;
|
468
|
+
class.define_method("prepend=", method!(RbNormalizer::prepend_set_prepend, 1))?;
|
469
|
+
|
431
470
|
let class = module.define_class("Strip", normalizer)?;
|
432
471
|
class.define_singleton_method("_new", function!(RbStrip::new, 2))?;
|
433
472
|
class.define_method("left", method!(RbNormalizer::strip_left, 0))?;
|
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
|
6
|
-
RArray, RClass, RModule, TypedData,
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
|
5
|
+
RArray, RClass, RModule, Ruby, TryConvert, TypedData,
|
7
6
|
};
|
8
7
|
|
9
8
|
use serde::ser::SerializeStruct;
|
@@ -23,7 +22,7 @@ use tk::tokenizer::Offsets;
|
|
23
22
|
use tk::{PreTokenizedString, PreTokenizer};
|
24
23
|
|
25
24
|
use super::utils::*;
|
26
|
-
use super::{RbError, RbResult};
|
25
|
+
use super::{PRE_TOKENIZERS, RbError, RbResult};
|
27
26
|
|
28
27
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
29
28
|
pub struct RbPreTokenizer {
|
@@ -215,7 +214,7 @@ pub struct RbWhitespace {}
|
|
215
214
|
|
216
215
|
impl RbWhitespace {
|
217
216
|
pub fn new() -> RbPreTokenizer {
|
218
|
-
Whitespace
|
217
|
+
Whitespace.into()
|
219
218
|
}
|
220
219
|
}
|
221
220
|
|
@@ -241,7 +240,7 @@ impl RbSequence {
|
|
241
240
|
fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
|
242
241
|
let mut sequence = Vec::with_capacity(pre_tokenizers.len());
|
243
242
|
for n in pre_tokenizers.each() {
|
244
|
-
let pretokenizer: &RbPreTokenizer =
|
243
|
+
let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n?)?;
|
245
244
|
match &pretokenizer.pretok {
|
246
245
|
RbPreTokenizerTypeWrapper::Sequence(inner) => {
|
247
246
|
sequence.extend(inner.iter().cloned())
|
@@ -346,77 +345,90 @@ impl PreTokenizer for RbPreTokenizerWrapper {
|
|
346
345
|
}
|
347
346
|
|
348
347
|
unsafe impl TypedData for RbPreTokenizer {
|
349
|
-
fn class() -> RClass {
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
})
|
348
|
+
fn class(ruby: &Ruby) -> RClass {
|
349
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
350
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
|
351
|
+
class.undef_default_alloc_func();
|
352
|
+
class
|
353
|
+
});
|
354
|
+
ruby.get_inner(&CLASS)
|
355
355
|
}
|
356
356
|
|
357
357
|
fn data_type() -> &'static DataType {
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
358
|
+
static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
|
359
|
+
&DATA_TYPE
|
360
|
+
}
|
361
|
+
|
362
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
363
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
364
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
|
365
|
+
class.undef_default_alloc_func();
|
366
|
+
class
|
367
|
+
});
|
368
|
+
static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
|
369
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
|
370
|
+
class.undef_default_alloc_func();
|
371
|
+
class
|
372
|
+
});
|
373
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
374
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
|
375
|
+
class.undef_default_alloc_func();
|
376
|
+
class
|
377
|
+
});
|
378
|
+
static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
379
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
|
380
|
+
class.undef_default_alloc_func();
|
381
|
+
class
|
382
|
+
});
|
383
|
+
static DIGITS: Lazy<RClass> = Lazy::new(|ruby| {
|
384
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Digits").unwrap();
|
385
|
+
class.undef_default_alloc_func();
|
386
|
+
class
|
387
|
+
});
|
388
|
+
static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
389
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
|
390
|
+
class.undef_default_alloc_func();
|
391
|
+
class
|
392
|
+
});
|
393
|
+
static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
|
394
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
|
395
|
+
class.undef_default_alloc_func();
|
396
|
+
class
|
397
|
+
});
|
398
|
+
static SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
399
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Split").unwrap();
|
400
|
+
class.undef_default_alloc_func();
|
401
|
+
class
|
402
|
+
});
|
403
|
+
static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
|
404
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
|
405
|
+
class.undef_default_alloc_func();
|
406
|
+
class
|
407
|
+
});
|
408
|
+
static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
409
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
|
410
|
+
class.undef_default_alloc_func();
|
411
|
+
class
|
412
|
+
});
|
413
|
+
static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
414
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
|
415
|
+
class.undef_default_alloc_func();
|
416
|
+
class
|
417
|
+
});
|
362
418
|
match &value.pretok {
|
363
|
-
RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
|
364
|
-
let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
|
365
|
-
class.undef_alloc_func();
|
366
|
-
class
|
367
|
-
}),
|
419
|
+
RbPreTokenizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
|
368
420
|
RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
369
421
|
RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
|
370
|
-
PreTokenizerWrapper::BertPreTokenizer(_) =>
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
PreTokenizerWrapper::
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
|
381
|
-
let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
|
382
|
-
class.undef_alloc_func();
|
383
|
-
class
|
384
|
-
}),
|
385
|
-
PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
|
386
|
-
let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
|
387
|
-
class.undef_alloc_func();
|
388
|
-
class
|
389
|
-
}),
|
390
|
-
PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
|
391
|
-
let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
|
392
|
-
class.undef_alloc_func();
|
393
|
-
class
|
394
|
-
}),
|
395
|
-
PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
|
396
|
-
let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
|
397
|
-
class.undef_alloc_func();
|
398
|
-
class
|
399
|
-
}),
|
400
|
-
PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
|
401
|
-
let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
|
402
|
-
class.undef_alloc_func();
|
403
|
-
class
|
404
|
-
}),
|
405
|
-
PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
|
406
|
-
let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
|
407
|
-
class.undef_alloc_func();
|
408
|
-
class
|
409
|
-
}),
|
410
|
-
PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
|
411
|
-
let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
|
412
|
-
class.undef_alloc_func();
|
413
|
-
class
|
414
|
-
}),
|
415
|
-
PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
|
416
|
-
let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
|
417
|
-
class.undef_alloc_func();
|
418
|
-
class
|
419
|
-
}),
|
422
|
+
PreTokenizerWrapper::BertPreTokenizer(_) => ruby.get_inner(&BERT_PRE_TOKENIZER),
|
423
|
+
PreTokenizerWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
424
|
+
PreTokenizerWrapper::Delimiter(_) => ruby.get_inner(&CHAR_DELIMITER_SPLIT),
|
425
|
+
PreTokenizerWrapper::Digits(_) => ruby.get_inner(&DIGITS),
|
426
|
+
PreTokenizerWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
|
427
|
+
PreTokenizerWrapper::Punctuation(_) => ruby.get_inner(&PUNCTUATION),
|
428
|
+
PreTokenizerWrapper::Split(_) => ruby.get_inner(&SPLIT),
|
429
|
+
PreTokenizerWrapper::UnicodeScripts(_) => ruby.get_inner(&UNICODE_SCRIPTS),
|
430
|
+
PreTokenizerWrapper::Whitespace(_) => ruby.get_inner(&WHITESPACE),
|
431
|
+
PreTokenizerWrapper::WhitespaceSplit(_) => ruby.get_inner(&WHITESPACE_SPLIT),
|
420
432
|
_ => todo!(),
|
421
433
|
},
|
422
434
|
},
|
@@ -424,8 +436,8 @@ unsafe impl TypedData for RbPreTokenizer {
|
|
424
436
|
}
|
425
437
|
}
|
426
438
|
|
427
|
-
pub fn
|
428
|
-
let pre_tokenizer = module.define_class("PreTokenizer",
|
439
|
+
pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
440
|
+
let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
|
429
441
|
pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
|
430
442
|
|
431
443
|
let class = module.define_class("Sequence", pre_tokenizer)?;
|
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::Arc;
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
function,
|
6
|
-
TryConvert, TypedData, Value,
|
4
|
+
data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
|
5
|
+
Ruby, TryConvert, TypedData, Value,
|
7
6
|
};
|
8
7
|
use serde::{Deserialize, Serialize};
|
9
8
|
use tk::processors::bert::BertProcessing;
|
@@ -13,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
|
|
13
12
|
use tk::processors::PostProcessorWrapper;
|
14
13
|
use tk::{Encoding, PostProcessor};
|
15
14
|
|
16
|
-
use super::RbResult;
|
15
|
+
use super::{PROCESSORS, RbResult};
|
17
16
|
|
18
17
|
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
19
18
|
pub struct RbPostProcessor {
|
@@ -53,9 +52,9 @@ impl From<RbSpecialToken> for SpecialToken {
|
|
53
52
|
|
54
53
|
impl TryConvert for RbSpecialToken {
|
55
54
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
56
|
-
if let Ok(v) =
|
55
|
+
if let Ok(v) = <(String, u32)>::try_convert(ob) {
|
57
56
|
Ok(Self(v.into()))
|
58
|
-
} else if let Ok(v) =
|
57
|
+
} else if let Ok(v) = <(u32, String)>::try_convert(ob) {
|
59
58
|
Ok(Self(v.into()))
|
60
59
|
} else {
|
61
60
|
todo!()
|
@@ -74,11 +73,11 @@ impl From<RbTemplate> for Template {
|
|
74
73
|
|
75
74
|
impl TryConvert for RbTemplate {
|
76
75
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
77
|
-
if let Ok(s) =
|
76
|
+
if let Ok(s) = String::try_convert(ob) {
|
78
77
|
Ok(Self(
|
79
78
|
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
80
79
|
))
|
81
|
-
} else if let Ok(s) =
|
80
|
+
} else if let Ok(s) = <Vec<String>>::try_convert(ob) {
|
82
81
|
Ok(Self(
|
83
82
|
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
84
83
|
))
|
@@ -152,47 +151,53 @@ impl RbTemplateProcessing {
|
|
152
151
|
}
|
153
152
|
|
154
153
|
unsafe impl TypedData for RbPostProcessor {
|
155
|
-
fn class() -> RClass {
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
})
|
154
|
+
fn class(ruby: &Ruby) -> RClass {
|
155
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
156
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
|
157
|
+
class.undef_default_alloc_func();
|
158
|
+
class
|
159
|
+
});
|
160
|
+
ruby.get_inner(&CLASS)
|
161
161
|
}
|
162
162
|
|
163
163
|
fn data_type() -> &'static DataType {
|
164
|
-
|
164
|
+
static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
|
165
|
+
&DATA_TYPE
|
165
166
|
}
|
166
167
|
|
167
|
-
fn class_for(value: &Self) -> RClass {
|
168
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
169
|
+
static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
170
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
|
171
|
+
class.undef_default_alloc_func();
|
172
|
+
class
|
173
|
+
});
|
174
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
175
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("ByteLevel").unwrap();
|
176
|
+
class.undef_default_alloc_func();
|
177
|
+
class
|
178
|
+
});
|
179
|
+
static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
180
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
|
181
|
+
class.undef_default_alloc_func();
|
182
|
+
class
|
183
|
+
});
|
184
|
+
static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
185
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
|
186
|
+
class.undef_default_alloc_func();
|
187
|
+
class
|
188
|
+
});
|
168
189
|
match *value.processor {
|
169
|
-
PostProcessorWrapper::Bert(_) =>
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
}),
|
174
|
-
PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
|
175
|
-
let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
|
176
|
-
class.undef_alloc_func();
|
177
|
-
class
|
178
|
-
}),
|
179
|
-
PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
|
180
|
-
let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
|
181
|
-
class.undef_alloc_func();
|
182
|
-
class
|
183
|
-
}),
|
184
|
-
PostProcessorWrapper::Template(_) => *memoize!(RClass: {
|
185
|
-
let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
|
186
|
-
class.undef_alloc_func();
|
187
|
-
class
|
188
|
-
}),
|
190
|
+
PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
|
191
|
+
PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
192
|
+
PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
|
193
|
+
PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
|
189
194
|
_ => todo!(),
|
190
195
|
}
|
191
196
|
}
|
192
197
|
}
|
193
198
|
|
194
|
-
pub fn
|
195
|
-
let post_processor = module.define_class("PostProcessor",
|
199
|
+
pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
200
|
+
let post_processor = module.define_class("PostProcessor", ruby.class_object())?;
|
196
201
|
|
197
202
|
let class = module.define_class("BertProcessing", post_processor)?;
|
198
203
|
class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
|