tokenizers 0.3.2 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/Cargo.lock +160 -96
- data/ext/tokenizers/Cargo.toml +6 -6
- data/ext/tokenizers/src/decoders.rs +149 -39
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +71 -50
- data/ext/tokenizers/src/normalizers.rs +113 -74
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/decoders/strip.rb +9 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/normalizers/prepend.rb +9 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +4 -2
- metadata +6 -4
@@ -1,20 +1,19 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
|
6
|
-
TypedData,
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
|
5
|
+
Ruby, TryConvert, TypedData,
|
7
6
|
};
|
8
7
|
use serde::ser::SerializeStruct;
|
9
8
|
use serde::{Deserialize, Serialize, Serializer};
|
10
9
|
use tk::normalizers::{
|
11
|
-
BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Strip, StripAccents,
|
10
|
+
BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Replace, Prepend, Strip, StripAccents,
|
12
11
|
NFC, NFD, NFKC, NFKD,
|
13
12
|
};
|
14
13
|
use tk::{NormalizedString, Normalizer};
|
15
14
|
|
16
15
|
use super::utils::*;
|
17
|
-
use super::{RbError, RbResult};
|
16
|
+
use super::{NORMALIZERS, RbError, RbResult};
|
18
17
|
|
19
18
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
20
19
|
pub struct RbNormalizer {
|
@@ -44,7 +43,7 @@ macro_rules! getter {
|
|
44
43
|
($self: ident, $variant: ident, $name: ident) => {{
|
45
44
|
if let RbNormalizerTypeWrapper::Single(ref norm) = &$self.normalizer {
|
46
45
|
let wrapper = norm.read().unwrap();
|
47
|
-
if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = *wrapper {
|
46
|
+
if let RbNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone() {
|
48
47
|
o.$name
|
49
48
|
} else {
|
50
49
|
unreachable!()
|
@@ -105,6 +104,14 @@ impl RbNormalizer {
|
|
105
104
|
setter!(self, BertNormalizer, lowercase, lowercase)
|
106
105
|
}
|
107
106
|
|
107
|
+
fn prepend_prepend(&self) -> String {
|
108
|
+
getter!(self, Prepend, prepend)
|
109
|
+
}
|
110
|
+
|
111
|
+
fn prepend_set_prepend(&self, prepend: String) {
|
112
|
+
setter!(self, Prepend, prepend, prepend)
|
113
|
+
}
|
114
|
+
|
108
115
|
fn strip_left(&self) -> bool {
|
109
116
|
getter!(self, StripNormalizer, strip_left)
|
110
117
|
}
|
@@ -186,6 +193,14 @@ impl RbReplace {
|
|
186
193
|
}
|
187
194
|
}
|
188
195
|
|
196
|
+
pub struct RbPrepend {}
|
197
|
+
|
198
|
+
impl RbPrepend {
|
199
|
+
pub fn new(prepend: String) -> RbNormalizer {
|
200
|
+
Prepend::new(prepend).into()
|
201
|
+
}
|
202
|
+
}
|
203
|
+
|
189
204
|
pub struct RbStrip {}
|
190
205
|
|
191
206
|
impl RbStrip {
|
@@ -208,7 +223,7 @@ impl RbSequence {
|
|
208
223
|
fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
|
209
224
|
let mut sequence = Vec::with_capacity(normalizers.len());
|
210
225
|
for n in normalizers.each() {
|
211
|
-
let normalizer: &RbNormalizer =
|
226
|
+
let normalizer: &RbNormalizer = TryConvert::try_convert(n?)?;
|
212
227
|
match &normalizer.normalizer {
|
213
228
|
RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
|
214
229
|
RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
|
@@ -311,77 +326,96 @@ impl Normalizer for RbNormalizerWrapper {
|
|
311
326
|
}
|
312
327
|
|
313
328
|
unsafe impl TypedData for RbNormalizer {
|
314
|
-
fn class() -> RClass {
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
})
|
329
|
+
fn class(ruby: &Ruby) -> RClass {
|
330
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
331
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Normalizer").unwrap();
|
332
|
+
class.undef_default_alloc_func();
|
333
|
+
class
|
334
|
+
});
|
335
|
+
ruby.get_inner(&CLASS)
|
320
336
|
}
|
321
337
|
|
322
338
|
fn data_type() -> &'static DataType {
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
339
|
+
static DATA_TYPE: DataType = data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
|
340
|
+
&DATA_TYPE
|
341
|
+
}
|
342
|
+
|
343
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
344
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
345
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Sequence").unwrap();
|
346
|
+
class.undef_default_alloc_func();
|
347
|
+
class
|
348
|
+
});
|
349
|
+
static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
|
350
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("BertNormalizer").unwrap();
|
351
|
+
class.undef_default_alloc_func();
|
352
|
+
class
|
353
|
+
});
|
354
|
+
static LOWERCASE: Lazy<RClass> = Lazy::new(|ruby| {
|
355
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Lowercase").unwrap();
|
356
|
+
class.undef_default_alloc_func();
|
357
|
+
class
|
358
|
+
});
|
359
|
+
static NFD: Lazy<RClass> = Lazy::new(|ruby| {
|
360
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFD").unwrap();
|
361
|
+
class.undef_default_alloc_func();
|
362
|
+
class
|
363
|
+
});
|
364
|
+
static NFC: Lazy<RClass> = Lazy::new(|ruby| {
|
365
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFC").unwrap();
|
366
|
+
class.undef_default_alloc_func();
|
367
|
+
class
|
368
|
+
});
|
369
|
+
static NFKC: Lazy<RClass> = Lazy::new(|ruby| {
|
370
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKC").unwrap();
|
371
|
+
class.undef_default_alloc_func();
|
372
|
+
class
|
373
|
+
});
|
374
|
+
static NFKD: Lazy<RClass> = Lazy::new(|ruby| {
|
375
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKD").unwrap();
|
376
|
+
class.undef_default_alloc_func();
|
377
|
+
class
|
378
|
+
});
|
379
|
+
static NMT: Lazy<RClass> = Lazy::new(|ruby| {
|
380
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Nmt").unwrap();
|
381
|
+
class.undef_default_alloc_func();
|
382
|
+
class
|
383
|
+
});
|
384
|
+
static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
|
385
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
|
386
|
+
class.undef_default_alloc_func();
|
387
|
+
class
|
388
|
+
});
|
389
|
+
static PREPEND: Lazy<RClass> = Lazy::new(|ruby| {
|
390
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Prepend").unwrap();
|
391
|
+
class.undef_default_alloc_func();
|
392
|
+
class
|
393
|
+
});
|
394
|
+
static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
|
395
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Strip").unwrap();
|
396
|
+
class.undef_default_alloc_func();
|
397
|
+
class
|
398
|
+
});
|
399
|
+
static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
|
400
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("StripAccents").unwrap();
|
401
|
+
class.undef_default_alloc_func();
|
402
|
+
class
|
403
|
+
});
|
327
404
|
match &value.normalizer {
|
328
|
-
RbNormalizerTypeWrapper::Sequence(_seq) =>
|
329
|
-
let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
|
330
|
-
class.undef_alloc_func();
|
331
|
-
class
|
332
|
-
}),
|
405
|
+
RbNormalizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
|
333
406
|
RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
334
407
|
RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
|
335
|
-
NormalizerWrapper::BertNormalizer(_) =>
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
NormalizerWrapper::
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
NormalizerWrapper::
|
346
|
-
let class: RClass = crate::normalizers().const_get("NFD").unwrap();
|
347
|
-
class.undef_alloc_func();
|
348
|
-
class
|
349
|
-
}),
|
350
|
-
NormalizerWrapper::NFC(_) => *memoize!(RClass: {
|
351
|
-
let class: RClass = crate::normalizers().const_get("NFC").unwrap();
|
352
|
-
class.undef_alloc_func();
|
353
|
-
class
|
354
|
-
}),
|
355
|
-
NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
|
356
|
-
let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
|
357
|
-
class.undef_alloc_func();
|
358
|
-
class
|
359
|
-
}),
|
360
|
-
NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
|
361
|
-
let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
|
362
|
-
class.undef_alloc_func();
|
363
|
-
class
|
364
|
-
}),
|
365
|
-
NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
|
366
|
-
let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
|
367
|
-
class.undef_alloc_func();
|
368
|
-
class
|
369
|
-
}),
|
370
|
-
NormalizerWrapper::Replace(_) => *memoize!(RClass: {
|
371
|
-
let class: RClass = crate::normalizers().const_get("Replace").unwrap();
|
372
|
-
class.undef_alloc_func();
|
373
|
-
class
|
374
|
-
}),
|
375
|
-
NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
|
376
|
-
let class: RClass = crate::normalizers().const_get("Strip").unwrap();
|
377
|
-
class.undef_alloc_func();
|
378
|
-
class
|
379
|
-
}),
|
380
|
-
NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
|
381
|
-
let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
|
382
|
-
class.undef_alloc_func();
|
383
|
-
class
|
384
|
-
}),
|
408
|
+
NormalizerWrapper::BertNormalizer(_) => ruby.get_inner(&BERT_NORMALIZER),
|
409
|
+
NormalizerWrapper::Lowercase(_) => ruby.get_inner(&LOWERCASE),
|
410
|
+
NormalizerWrapper::NFD(_) => ruby.get_inner(&NFD),
|
411
|
+
NormalizerWrapper::NFC(_) => ruby.get_inner(&NFC),
|
412
|
+
NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
|
413
|
+
NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
|
414
|
+
NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
|
415
|
+
NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
|
416
|
+
NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
|
417
|
+
NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
|
418
|
+
NormalizerWrapper::StripAccents(_) => ruby.get_inner(&STRIP_ACCENTS),
|
385
419
|
_ => todo!(),
|
386
420
|
},
|
387
421
|
},
|
@@ -389,8 +423,8 @@ unsafe impl TypedData for RbNormalizer {
|
|
389
423
|
}
|
390
424
|
}
|
391
425
|
|
392
|
-
pub fn
|
393
|
-
let normalizer = module.define_class("Normalizer",
|
426
|
+
pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
427
|
+
let normalizer = module.define_class("Normalizer", ruby.class_object())?;
|
394
428
|
normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
|
395
429
|
|
396
430
|
let class = module.define_class("Sequence", normalizer)?;
|
@@ -428,6 +462,11 @@ pub fn normalizers(module: &RModule) -> RbResult<()> {
|
|
428
462
|
let class = module.define_class("Replace", normalizer)?;
|
429
463
|
class.define_singleton_method("new", function!(RbReplace::new, 2))?;
|
430
464
|
|
465
|
+
let class = module.define_class("Prepend", normalizer)?;
|
466
|
+
class.define_singleton_method("_new", function!(RbPrepend::new, 1))?;
|
467
|
+
class.define_method("prepend", method!(RbNormalizer::prepend_prepend, 0))?;
|
468
|
+
class.define_method("prepend=", method!(RbNormalizer::prepend_set_prepend, 1))?;
|
469
|
+
|
431
470
|
let class = module.define_class("Strip", normalizer)?;
|
432
471
|
class.define_singleton_method("_new", function!(RbStrip::new, 2))?;
|
433
472
|
class.define_method("left", method!(RbNormalizer::strip_left, 0))?;
|
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
|
6
|
-
RArray, RClass, RModule, TypedData,
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
|
5
|
+
RArray, RClass, RModule, Ruby, TryConvert, TypedData,
|
7
6
|
};
|
8
7
|
|
9
8
|
use serde::ser::SerializeStruct;
|
@@ -23,7 +22,7 @@ use tk::tokenizer::Offsets;
|
|
23
22
|
use tk::{PreTokenizedString, PreTokenizer};
|
24
23
|
|
25
24
|
use super::utils::*;
|
26
|
-
use super::{RbError, RbResult};
|
25
|
+
use super::{PRE_TOKENIZERS, RbError, RbResult};
|
27
26
|
|
28
27
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
29
28
|
pub struct RbPreTokenizer {
|
@@ -215,7 +214,7 @@ pub struct RbWhitespace {}
|
|
215
214
|
|
216
215
|
impl RbWhitespace {
|
217
216
|
pub fn new() -> RbPreTokenizer {
|
218
|
-
Whitespace
|
217
|
+
Whitespace.into()
|
219
218
|
}
|
220
219
|
}
|
221
220
|
|
@@ -241,7 +240,7 @@ impl RbSequence {
|
|
241
240
|
fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
|
242
241
|
let mut sequence = Vec::with_capacity(pre_tokenizers.len());
|
243
242
|
for n in pre_tokenizers.each() {
|
244
|
-
let pretokenizer: &RbPreTokenizer =
|
243
|
+
let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n?)?;
|
245
244
|
match &pretokenizer.pretok {
|
246
245
|
RbPreTokenizerTypeWrapper::Sequence(inner) => {
|
247
246
|
sequence.extend(inner.iter().cloned())
|
@@ -346,77 +345,90 @@ impl PreTokenizer for RbPreTokenizerWrapper {
|
|
346
345
|
}
|
347
346
|
|
348
347
|
unsafe impl TypedData for RbPreTokenizer {
|
349
|
-
fn class() -> RClass {
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
})
|
348
|
+
fn class(ruby: &Ruby) -> RClass {
|
349
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
350
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
|
351
|
+
class.undef_default_alloc_func();
|
352
|
+
class
|
353
|
+
});
|
354
|
+
ruby.get_inner(&CLASS)
|
355
355
|
}
|
356
356
|
|
357
357
|
fn data_type() -> &'static DataType {
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
358
|
+
static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
|
359
|
+
&DATA_TYPE
|
360
|
+
}
|
361
|
+
|
362
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
363
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
364
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
|
365
|
+
class.undef_default_alloc_func();
|
366
|
+
class
|
367
|
+
});
|
368
|
+
static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
|
369
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
|
370
|
+
class.undef_default_alloc_func();
|
371
|
+
class
|
372
|
+
});
|
373
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
374
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
|
375
|
+
class.undef_default_alloc_func();
|
376
|
+
class
|
377
|
+
});
|
378
|
+
static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
379
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
|
380
|
+
class.undef_default_alloc_func();
|
381
|
+
class
|
382
|
+
});
|
383
|
+
static DIGITS: Lazy<RClass> = Lazy::new(|ruby| {
|
384
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Digits").unwrap();
|
385
|
+
class.undef_default_alloc_func();
|
386
|
+
class
|
387
|
+
});
|
388
|
+
static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
389
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
|
390
|
+
class.undef_default_alloc_func();
|
391
|
+
class
|
392
|
+
});
|
393
|
+
static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
|
394
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
|
395
|
+
class.undef_default_alloc_func();
|
396
|
+
class
|
397
|
+
});
|
398
|
+
static SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
399
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Split").unwrap();
|
400
|
+
class.undef_default_alloc_func();
|
401
|
+
class
|
402
|
+
});
|
403
|
+
static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
|
404
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
|
405
|
+
class.undef_default_alloc_func();
|
406
|
+
class
|
407
|
+
});
|
408
|
+
static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
409
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
|
410
|
+
class.undef_default_alloc_func();
|
411
|
+
class
|
412
|
+
});
|
413
|
+
static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
414
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
|
415
|
+
class.undef_default_alloc_func();
|
416
|
+
class
|
417
|
+
});
|
362
418
|
match &value.pretok {
|
363
|
-
RbPreTokenizerTypeWrapper::Sequence(_seq) =>
|
364
|
-
let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
|
365
|
-
class.undef_alloc_func();
|
366
|
-
class
|
367
|
-
}),
|
419
|
+
RbPreTokenizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
|
368
420
|
RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
369
421
|
RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
|
370
|
-
PreTokenizerWrapper::BertPreTokenizer(_) =>
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
PreTokenizerWrapper::
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
|
381
|
-
let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
|
382
|
-
class.undef_alloc_func();
|
383
|
-
class
|
384
|
-
}),
|
385
|
-
PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
|
386
|
-
let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
|
387
|
-
class.undef_alloc_func();
|
388
|
-
class
|
389
|
-
}),
|
390
|
-
PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
|
391
|
-
let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
|
392
|
-
class.undef_alloc_func();
|
393
|
-
class
|
394
|
-
}),
|
395
|
-
PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
|
396
|
-
let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
|
397
|
-
class.undef_alloc_func();
|
398
|
-
class
|
399
|
-
}),
|
400
|
-
PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
|
401
|
-
let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
|
402
|
-
class.undef_alloc_func();
|
403
|
-
class
|
404
|
-
}),
|
405
|
-
PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
|
406
|
-
let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
|
407
|
-
class.undef_alloc_func();
|
408
|
-
class
|
409
|
-
}),
|
410
|
-
PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
|
411
|
-
let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
|
412
|
-
class.undef_alloc_func();
|
413
|
-
class
|
414
|
-
}),
|
415
|
-
PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
|
416
|
-
let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
|
417
|
-
class.undef_alloc_func();
|
418
|
-
class
|
419
|
-
}),
|
422
|
+
PreTokenizerWrapper::BertPreTokenizer(_) => ruby.get_inner(&BERT_PRE_TOKENIZER),
|
423
|
+
PreTokenizerWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
424
|
+
PreTokenizerWrapper::Delimiter(_) => ruby.get_inner(&CHAR_DELIMITER_SPLIT),
|
425
|
+
PreTokenizerWrapper::Digits(_) => ruby.get_inner(&DIGITS),
|
426
|
+
PreTokenizerWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
|
427
|
+
PreTokenizerWrapper::Punctuation(_) => ruby.get_inner(&PUNCTUATION),
|
428
|
+
PreTokenizerWrapper::Split(_) => ruby.get_inner(&SPLIT),
|
429
|
+
PreTokenizerWrapper::UnicodeScripts(_) => ruby.get_inner(&UNICODE_SCRIPTS),
|
430
|
+
PreTokenizerWrapper::Whitespace(_) => ruby.get_inner(&WHITESPACE),
|
431
|
+
PreTokenizerWrapper::WhitespaceSplit(_) => ruby.get_inner(&WHITESPACE_SPLIT),
|
420
432
|
_ => todo!(),
|
421
433
|
},
|
422
434
|
},
|
@@ -424,8 +436,8 @@ unsafe impl TypedData for RbPreTokenizer {
|
|
424
436
|
}
|
425
437
|
}
|
426
438
|
|
427
|
-
pub fn
|
428
|
-
let pre_tokenizer = module.define_class("PreTokenizer",
|
439
|
+
pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
440
|
+
let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
|
429
441
|
pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
|
430
442
|
|
431
443
|
let class = module.define_class("Sequence", pre_tokenizer)?;
|
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::Arc;
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
function,
|
6
|
-
TryConvert, TypedData, Value,
|
4
|
+
data_type_builder, function, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
|
5
|
+
Ruby, TryConvert, TypedData, Value,
|
7
6
|
};
|
8
7
|
use serde::{Deserialize, Serialize};
|
9
8
|
use tk::processors::bert::BertProcessing;
|
@@ -13,7 +12,7 @@ use tk::processors::template::{SpecialToken, Template};
|
|
13
12
|
use tk::processors::PostProcessorWrapper;
|
14
13
|
use tk::{Encoding, PostProcessor};
|
15
14
|
|
16
|
-
use super::RbResult;
|
15
|
+
use super::{PROCESSORS, RbResult};
|
17
16
|
|
18
17
|
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
19
18
|
pub struct RbPostProcessor {
|
@@ -53,9 +52,9 @@ impl From<RbSpecialToken> for SpecialToken {
|
|
53
52
|
|
54
53
|
impl TryConvert for RbSpecialToken {
|
55
54
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
56
|
-
if let Ok(v) =
|
55
|
+
if let Ok(v) = <(String, u32)>::try_convert(ob) {
|
57
56
|
Ok(Self(v.into()))
|
58
|
-
} else if let Ok(v) =
|
57
|
+
} else if let Ok(v) = <(u32, String)>::try_convert(ob) {
|
59
58
|
Ok(Self(v.into()))
|
60
59
|
} else {
|
61
60
|
todo!()
|
@@ -74,11 +73,11 @@ impl From<RbTemplate> for Template {
|
|
74
73
|
|
75
74
|
impl TryConvert for RbTemplate {
|
76
75
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
77
|
-
if let Ok(s) =
|
76
|
+
if let Ok(s) = String::try_convert(ob) {
|
78
77
|
Ok(Self(
|
79
78
|
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
80
79
|
))
|
81
|
-
} else if let Ok(s) =
|
80
|
+
} else if let Ok(s) = <Vec<String>>::try_convert(ob) {
|
82
81
|
Ok(Self(
|
83
82
|
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
84
83
|
))
|
@@ -152,47 +151,53 @@ impl RbTemplateProcessing {
|
|
152
151
|
}
|
153
152
|
|
154
153
|
unsafe impl TypedData for RbPostProcessor {
|
155
|
-
fn class() -> RClass {
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
})
|
154
|
+
fn class(ruby: &Ruby) -> RClass {
|
155
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
156
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("PostProcessor").unwrap();
|
157
|
+
class.undef_default_alloc_func();
|
158
|
+
class
|
159
|
+
});
|
160
|
+
ruby.get_inner(&CLASS)
|
161
161
|
}
|
162
162
|
|
163
163
|
fn data_type() -> &'static DataType {
|
164
|
-
|
164
|
+
static DATA_TYPE: DataType = data_type_builder!(RbPostProcessor, "Tokenizers::Processors::PostProcessor").build();
|
165
|
+
&DATA_TYPE
|
165
166
|
}
|
166
167
|
|
167
|
-
fn class_for(value: &Self) -> RClass {
|
168
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
169
|
+
static BERT_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
170
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("BertProcessing").unwrap();
|
171
|
+
class.undef_default_alloc_func();
|
172
|
+
class
|
173
|
+
});
|
174
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
175
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("ByteLevel").unwrap();
|
176
|
+
class.undef_default_alloc_func();
|
177
|
+
class
|
178
|
+
});
|
179
|
+
static ROBERTA_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
180
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("RobertaProcessing").unwrap();
|
181
|
+
class.undef_default_alloc_func();
|
182
|
+
class
|
183
|
+
});
|
184
|
+
static TEMPLATE_PROCESSING: Lazy<RClass> = Lazy::new(|ruby| {
|
185
|
+
let class: RClass = ruby.get_inner(&PROCESSORS).const_get("TemplateProcessing").unwrap();
|
186
|
+
class.undef_default_alloc_func();
|
187
|
+
class
|
188
|
+
});
|
168
189
|
match *value.processor {
|
169
|
-
PostProcessorWrapper::Bert(_) =>
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
}),
|
174
|
-
PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
|
175
|
-
let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
|
176
|
-
class.undef_alloc_func();
|
177
|
-
class
|
178
|
-
}),
|
179
|
-
PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
|
180
|
-
let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
|
181
|
-
class.undef_alloc_func();
|
182
|
-
class
|
183
|
-
}),
|
184
|
-
PostProcessorWrapper::Template(_) => *memoize!(RClass: {
|
185
|
-
let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
|
186
|
-
class.undef_alloc_func();
|
187
|
-
class
|
188
|
-
}),
|
190
|
+
PostProcessorWrapper::Bert(_) => ruby.get_inner(&BERT_PROCESSING),
|
191
|
+
PostProcessorWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
192
|
+
PostProcessorWrapper::Roberta(_) => ruby.get_inner(&ROBERTA_PROCESSING),
|
193
|
+
PostProcessorWrapper::Template(_) => ruby.get_inner(&TEMPLATE_PROCESSING),
|
189
194
|
_ => todo!(),
|
190
195
|
}
|
191
196
|
}
|
192
197
|
}
|
193
198
|
|
194
|
-
pub fn
|
195
|
-
let post_processor = module.define_class("PostProcessor",
|
199
|
+
pub fn init_processors(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
200
|
+
let post_processor = module.define_class("PostProcessor", ruby.class_object())?;
|
196
201
|
|
197
202
|
let class = module.define_class("BertProcessing", post_processor)?;
|
198
203
|
class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
|