tokenizers 0.3.3 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Cargo.lock +52 -23
- data/ext/tokenizers/Cargo.toml +4 -3
- data/ext/tokenizers/src/decoders.rs +72 -61
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +57 -51
- data/ext/tokenizers/src/normalizers.rs +90 -77
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +2 -2
- metadata +3 -3
@@ -3,10 +3,10 @@ use std::path::{Path, PathBuf};
|
|
3
3
|
use std::sync::{Arc, RwLock};
|
4
4
|
|
5
5
|
use crate::trainers::RbTrainer;
|
6
|
-
use magnus::
|
6
|
+
use magnus::prelude::*;
|
7
7
|
use magnus::{
|
8
|
-
exception, function,
|
9
|
-
RClass, RHash, RModule, Symbol, TypedData, Value,
|
8
|
+
data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
|
9
|
+
RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
|
10
10
|
};
|
11
11
|
use serde::{Deserialize, Serialize};
|
12
12
|
use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
|
@@ -16,7 +16,7 @@ use tk::models::wordlevel::WordLevel;
|
|
16
16
|
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
|
17
17
|
use tk::{Model, Token};
|
18
18
|
|
19
|
-
use super::{RbError, RbResult};
|
19
|
+
use super::{MODELS, RbError, RbResult};
|
20
20
|
|
21
21
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
22
22
|
pub struct RbModel {
|
@@ -73,37 +73,37 @@ impl RbBPE {
|
|
73
73
|
fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
74
74
|
let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
|
75
75
|
if !value.is_nil() {
|
76
|
-
builder = builder.cache_capacity(
|
76
|
+
builder = builder.cache_capacity(TryConvert::try_convert(value)?);
|
77
77
|
}
|
78
78
|
|
79
79
|
let value: Value = kwargs.delete(Symbol::new("dropout"))?;
|
80
80
|
if !value.is_nil() {
|
81
|
-
builder = builder.dropout(
|
81
|
+
builder = builder.dropout(TryConvert::try_convert(value)?);
|
82
82
|
}
|
83
83
|
|
84
84
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
85
85
|
if !value.is_nil() {
|
86
|
-
builder = builder.unk_token(
|
86
|
+
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
87
87
|
}
|
88
88
|
|
89
89
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
90
90
|
if !value.is_nil() {
|
91
|
-
builder = builder.continuing_subword_prefix(
|
91
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
92
92
|
}
|
93
93
|
|
94
94
|
let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
|
95
95
|
if !value.is_nil() {
|
96
|
-
builder = builder.end_of_word_suffix(
|
96
|
+
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
97
97
|
}
|
98
98
|
|
99
99
|
let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
|
100
100
|
if !value.is_nil() {
|
101
|
-
builder = builder.fuse_unk(
|
101
|
+
builder = builder.fuse_unk(TryConvert::try_convert(value)?);
|
102
102
|
}
|
103
103
|
|
104
104
|
let value: Value = kwargs.delete(Symbol::new("byte_fallback"))?;
|
105
105
|
if !value.is_nil() {
|
106
|
-
builder = builder.byte_fallback(
|
106
|
+
builder = builder.byte_fallback(TryConvert::try_convert(value)?);
|
107
107
|
}
|
108
108
|
|
109
109
|
if !kwargs.is_empty() {
|
@@ -234,13 +234,13 @@ impl RbModel {
|
|
234
234
|
pub struct RbUnigram {}
|
235
235
|
|
236
236
|
impl RbUnigram {
|
237
|
-
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
|
238
|
-
match (vocab, unk_id) {
|
239
|
-
(Some(vocab), unk_id) => {
|
240
|
-
let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
|
237
|
+
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>, byte_fallback: Option<bool>) -> RbResult<RbModel> {
|
238
|
+
match (vocab, unk_id, byte_fallback) {
|
239
|
+
(Some(vocab), unk_id, byte_fallback) => {
|
240
|
+
let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(RbError::from)?;
|
241
241
|
Ok(model.into())
|
242
242
|
}
|
243
|
-
(None, None) => Ok(Unigram::default().into()),
|
243
|
+
(None, None, _) => Ok(Unigram::default().into()),
|
244
244
|
_ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
|
245
245
|
}
|
246
246
|
}
|
@@ -277,17 +277,17 @@ impl RbWordPiece {
|
|
277
277
|
fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
278
278
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
279
279
|
if !value.is_nil() {
|
280
|
-
builder = builder.unk_token(
|
280
|
+
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
281
281
|
}
|
282
282
|
|
283
283
|
let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
|
284
284
|
if !value.is_nil() {
|
285
|
-
builder = builder.max_input_chars_per_word(
|
285
|
+
builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
|
286
286
|
}
|
287
287
|
|
288
288
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
289
289
|
if !value.is_nil() {
|
290
|
-
builder = builder.continuing_subword_prefix(
|
290
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
291
291
|
}
|
292
292
|
|
293
293
|
if !kwargs.is_empty() {
|
@@ -314,46 +314,52 @@ impl RbWordPiece {
|
|
314
314
|
}
|
315
315
|
|
316
316
|
unsafe impl TypedData for RbModel {
|
317
|
-
fn class() -> RClass {
|
318
|
-
|
319
|
-
let class: RClass =
|
320
|
-
class.
|
317
|
+
fn class(ruby: &Ruby) -> RClass {
|
318
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
319
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
|
320
|
+
class.undef_default_alloc_func();
|
321
321
|
class
|
322
|
-
})
|
322
|
+
});
|
323
|
+
ruby.get_inner(&CLASS)
|
323
324
|
}
|
324
325
|
|
325
326
|
fn data_type() -> &'static DataType {
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
327
|
+
static DATA_TYPE: DataType = data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
|
328
|
+
&DATA_TYPE
|
329
|
+
}
|
330
|
+
|
331
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
332
|
+
static BPE: Lazy<RClass> = Lazy::new(|ruby| {
|
333
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("BPE").unwrap();
|
334
|
+
class.undef_default_alloc_func();
|
335
|
+
class
|
336
|
+
});
|
337
|
+
static UNIGRAM: Lazy<RClass> = Lazy::new(|ruby| {
|
338
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("Unigram").unwrap();
|
339
|
+
class.undef_default_alloc_func();
|
340
|
+
class
|
341
|
+
});
|
342
|
+
static WORD_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
343
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("WordLevel").unwrap();
|
344
|
+
class.undef_default_alloc_func();
|
345
|
+
class
|
346
|
+
});
|
347
|
+
static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
|
348
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("WordPiece").unwrap();
|
349
|
+
class.undef_default_alloc_func();
|
350
|
+
class
|
351
|
+
});
|
330
352
|
match *value.model.read().unwrap() {
|
331
|
-
ModelWrapper::BPE(_) =>
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
}),
|
336
|
-
ModelWrapper::Unigram(_) => *memoize!(RClass: {
|
337
|
-
let class: RClass = crate::models().const_get("Unigram").unwrap();
|
338
|
-
class.undef_alloc_func();
|
339
|
-
class
|
340
|
-
}),
|
341
|
-
ModelWrapper::WordLevel(_) => *memoize!(RClass: {
|
342
|
-
let class: RClass = crate::models().const_get("WordLevel").unwrap();
|
343
|
-
class.undef_alloc_func();
|
344
|
-
class
|
345
|
-
}),
|
346
|
-
ModelWrapper::WordPiece(_) => *memoize!(RClass: {
|
347
|
-
let class: RClass = crate::models().const_get("WordPiece").unwrap();
|
348
|
-
class.undef_alloc_func();
|
349
|
-
class
|
350
|
-
}),
|
353
|
+
ModelWrapper::BPE(_) => ruby.get_inner(&BPE),
|
354
|
+
ModelWrapper::Unigram(_) => ruby.get_inner(&UNIGRAM),
|
355
|
+
ModelWrapper::WordLevel(_) => ruby.get_inner(&WORD_LEVEL),
|
356
|
+
ModelWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
|
351
357
|
}
|
352
358
|
}
|
353
359
|
}
|
354
360
|
|
355
|
-
pub fn
|
356
|
-
let model = module.define_class("Model",
|
361
|
+
pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
362
|
+
let model = module.define_class("Model", ruby.class_object())?;
|
357
363
|
|
358
364
|
let class = module.define_class("BPE", model)?;
|
359
365
|
class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
|
@@ -372,7 +378,7 @@ pub fn models(module: &RModule) -> RbResult<()> {
|
|
372
378
|
class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
|
373
379
|
|
374
380
|
let class = module.define_class("Unigram", model)?;
|
375
|
-
class.define_singleton_method("_new", function!(RbUnigram::new,
|
381
|
+
class.define_singleton_method("_new", function!(RbUnigram::new, 3))?;
|
376
382
|
|
377
383
|
let class = module.define_class("WordLevel", model)?;
|
378
384
|
class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
|
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
|
6
|
-
TypedData,
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
|
5
|
+
Ruby, TryConvert, TypedData,
|
7
6
|
};
|
8
7
|
use serde::ser::SerializeStruct;
|
9
8
|
use serde::{Deserialize, Serialize, Serializer};
|
@@ -14,7 +13,7 @@ use tk::normalizers::{
|
|
14
13
|
use tk::{NormalizedString, Normalizer};
|
15
14
|
|
16
15
|
use super::utils::*;
|
17
|
-
use super::{RbError, RbResult};
|
16
|
+
use super::{NORMALIZERS, RbError, RbResult};
|
18
17
|
|
19
18
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
20
19
|
pub struct RbNormalizer {
|
@@ -224,7 +223,7 @@ impl RbSequence {
|
|
224
223
|
fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
|
225
224
|
let mut sequence = Vec::with_capacity(normalizers.len());
|
226
225
|
for n in normalizers.each() {
|
227
|
-
let normalizer: &RbNormalizer =
|
226
|
+
let normalizer: &RbNormalizer = TryConvert::try_convert(n?)?;
|
228
227
|
match &normalizer.normalizer {
|
229
228
|
RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
|
230
229
|
RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
|
@@ -327,82 +326,96 @@ impl Normalizer for RbNormalizerWrapper {
|
|
327
326
|
}
|
328
327
|
|
329
328
|
unsafe impl TypedData for RbNormalizer {
|
330
|
-
fn class() -> RClass {
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
})
|
329
|
+
fn class(ruby: &Ruby) -> RClass {
|
330
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
331
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Normalizer").unwrap();
|
332
|
+
class.undef_default_alloc_func();
|
333
|
+
class
|
334
|
+
});
|
335
|
+
ruby.get_inner(&CLASS)
|
336
336
|
}
|
337
337
|
|
338
338
|
fn data_type() -> &'static DataType {
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
339
|
+
static DATA_TYPE: DataType = data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
|
340
|
+
&DATA_TYPE
|
341
|
+
}
|
342
|
+
|
343
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
344
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
345
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Sequence").unwrap();
|
346
|
+
class.undef_default_alloc_func();
|
347
|
+
class
|
348
|
+
});
|
349
|
+
static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
|
350
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("BertNormalizer").unwrap();
|
351
|
+
class.undef_default_alloc_func();
|
352
|
+
class
|
353
|
+
});
|
354
|
+
static LOWERCASE: Lazy<RClass> = Lazy::new(|ruby| {
|
355
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Lowercase").unwrap();
|
356
|
+
class.undef_default_alloc_func();
|
357
|
+
class
|
358
|
+
});
|
359
|
+
static NFD: Lazy<RClass> = Lazy::new(|ruby| {
|
360
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFD").unwrap();
|
361
|
+
class.undef_default_alloc_func();
|
362
|
+
class
|
363
|
+
});
|
364
|
+
static NFC: Lazy<RClass> = Lazy::new(|ruby| {
|
365
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFC").unwrap();
|
366
|
+
class.undef_default_alloc_func();
|
367
|
+
class
|
368
|
+
});
|
369
|
+
static NFKC: Lazy<RClass> = Lazy::new(|ruby| {
|
370
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKC").unwrap();
|
371
|
+
class.undef_default_alloc_func();
|
372
|
+
class
|
373
|
+
});
|
374
|
+
static NFKD: Lazy<RClass> = Lazy::new(|ruby| {
|
375
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKD").unwrap();
|
376
|
+
class.undef_default_alloc_func();
|
377
|
+
class
|
378
|
+
});
|
379
|
+
static NMT: Lazy<RClass> = Lazy::new(|ruby| {
|
380
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Nmt").unwrap();
|
381
|
+
class.undef_default_alloc_func();
|
382
|
+
class
|
383
|
+
});
|
384
|
+
static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
|
385
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
|
386
|
+
class.undef_default_alloc_func();
|
387
|
+
class
|
388
|
+
});
|
389
|
+
static PREPEND: Lazy<RClass> = Lazy::new(|ruby| {
|
390
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Prepend").unwrap();
|
391
|
+
class.undef_default_alloc_func();
|
392
|
+
class
|
393
|
+
});
|
394
|
+
static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
|
395
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Strip").unwrap();
|
396
|
+
class.undef_default_alloc_func();
|
397
|
+
class
|
398
|
+
});
|
399
|
+
static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
|
400
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("StripAccents").unwrap();
|
401
|
+
class.undef_default_alloc_func();
|
402
|
+
class
|
403
|
+
});
|
343
404
|
match &value.normalizer {
|
344
|
-
RbNormalizerTypeWrapper::Sequence(_seq) =>
|
345
|
-
let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
|
346
|
-
class.undef_alloc_func();
|
347
|
-
class
|
348
|
-
}),
|
405
|
+
RbNormalizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
|
349
406
|
RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
350
407
|
RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
|
351
|
-
NormalizerWrapper::BertNormalizer(_) =>
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
NormalizerWrapper::
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
NormalizerWrapper::
|
362
|
-
let class: RClass = crate::normalizers().const_get("NFD").unwrap();
|
363
|
-
class.undef_alloc_func();
|
364
|
-
class
|
365
|
-
}),
|
366
|
-
NormalizerWrapper::NFC(_) => *memoize!(RClass: {
|
367
|
-
let class: RClass = crate::normalizers().const_get("NFC").unwrap();
|
368
|
-
class.undef_alloc_func();
|
369
|
-
class
|
370
|
-
}),
|
371
|
-
NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
|
372
|
-
let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
|
373
|
-
class.undef_alloc_func();
|
374
|
-
class
|
375
|
-
}),
|
376
|
-
NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
|
377
|
-
let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
|
378
|
-
class.undef_alloc_func();
|
379
|
-
class
|
380
|
-
}),
|
381
|
-
NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
|
382
|
-
let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
|
383
|
-
class.undef_alloc_func();
|
384
|
-
class
|
385
|
-
}),
|
386
|
-
NormalizerWrapper::Replace(_) => *memoize!(RClass: {
|
387
|
-
let class: RClass = crate::normalizers().const_get("Replace").unwrap();
|
388
|
-
class.undef_alloc_func();
|
389
|
-
class
|
390
|
-
}),
|
391
|
-
NormalizerWrapper::Prepend(_) => *memoize!(RClass: {
|
392
|
-
let class: RClass = crate::normalizers().const_get("Prepend").unwrap();
|
393
|
-
class.undef_alloc_func();
|
394
|
-
class
|
395
|
-
}),
|
396
|
-
NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
|
397
|
-
let class: RClass = crate::normalizers().const_get("Strip").unwrap();
|
398
|
-
class.undef_alloc_func();
|
399
|
-
class
|
400
|
-
}),
|
401
|
-
NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
|
402
|
-
let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
|
403
|
-
class.undef_alloc_func();
|
404
|
-
class
|
405
|
-
}),
|
408
|
+
NormalizerWrapper::BertNormalizer(_) => ruby.get_inner(&BERT_NORMALIZER),
|
409
|
+
NormalizerWrapper::Lowercase(_) => ruby.get_inner(&LOWERCASE),
|
410
|
+
NormalizerWrapper::NFD(_) => ruby.get_inner(&NFD),
|
411
|
+
NormalizerWrapper::NFC(_) => ruby.get_inner(&NFC),
|
412
|
+
NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
|
413
|
+
NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
|
414
|
+
NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
|
415
|
+
NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
|
416
|
+
NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
|
417
|
+
NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
|
418
|
+
NormalizerWrapper::StripAccents(_) => ruby.get_inner(&STRIP_ACCENTS),
|
406
419
|
_ => todo!(),
|
407
420
|
},
|
408
421
|
},
|
@@ -410,8 +423,8 @@ unsafe impl TypedData for RbNormalizer {
|
|
410
423
|
}
|
411
424
|
}
|
412
425
|
|
413
|
-
pub fn
|
414
|
-
let normalizer = module.define_class("Normalizer",
|
426
|
+
pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
427
|
+
let normalizer = module.define_class("Normalizer", ruby.class_object())?;
|
415
428
|
normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
|
416
429
|
|
417
430
|
let class = module.define_class("Sequence", normalizer)?;
|
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
|
6
|
-
RArray, RClass, RModule, TypedData,
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
|
5
|
+
RArray, RClass, RModule, Ruby, TryConvert, TypedData,
|
7
6
|
};
|
8
7
|
|
9
8
|
use serde::ser::SerializeStruct;
|
@@ -23,7 +22,7 @@ use tk::tokenizer::Offsets;
|
|
23
22
|
use tk::{PreTokenizedString, PreTokenizer};
|
24
23
|
|
25
24
|
use super::utils::*;
|
26
|
-
use super::{RbError, RbResult};
|
25
|
+
use super::{PRE_TOKENIZERS, RbError, RbResult};
|
27
26
|
|
28
27
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
29
28
|
pub struct RbPreTokenizer {
|
@@ -215,7 +214,7 @@ pub struct RbWhitespace {}
|
|
215
214
|
|
216
215
|
impl RbWhitespace {
|
217
216
|
pub fn new() -> RbPreTokenizer {
|
218
|
-
Whitespace
|
217
|
+
Whitespace.into()
|
219
218
|
}
|
220
219
|
}
|
221
220
|
|
@@ -241,7 +240,7 @@ impl RbSequence {
|
|
241
240
|
fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
|
242
241
|
let mut sequence = Vec::with_capacity(pre_tokenizers.len());
|
243
242
|
for n in pre_tokenizers.each() {
|
244
|
-
let pretokenizer: &RbPreTokenizer =
|
243
|
+
let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n?)?;
|
245
244
|
match &pretokenizer.pretok {
|
246
245
|
RbPreTokenizerTypeWrapper::Sequence(inner) => {
|
247
246
|
sequence.extend(inner.iter().cloned())
|
@@ -346,77 +345,90 @@ impl PreTokenizer for RbPreTokenizerWrapper {
|
|
346
345
|
}
|
347
346
|
|
348
347
|
unsafe impl TypedData for RbPreTokenizer {
|
349
|
-
fn class() -> RClass {
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
})
|
348
|
+
fn class(ruby: &Ruby) -> RClass {
|
349
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
350
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
|
351
|
+
class.undef_default_alloc_func();
|
352
|
+
class
|
353
|
+
});
|
354
|
+
ruby.get_inner(&CLASS)
|
355
355
|
}
|
356
356
|
|
357
357
|
fn data_type() -> &'static DataType {
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
358
|
+
static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
|
359
|
+
&DATA_TYPE
|
360
|
+
}
|
361
|
+
|
362
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
363
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
364
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
|
365
|
+
class.undef_default_alloc_func();
|
366
|
+
class
|
367
|
+
});
|
368
|
+
static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
|
369
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
|
370
|
+
class.undef_default_alloc_func();
|
371
|
+
class
|
372
|
+
});
|
373
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
374
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
|
375
|
+
class.undef_default_alloc_func();
|
376
|
+
class
|
377
|
+
});
|
378
|
+
static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
379
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
|
380
|
+
class.undef_default_alloc_func();
|
381
|
+
class
|
382
|
+
});
|
383
|
+
static DIGITS: Lazy<RClass> = Lazy::new(|ruby| {
|
384
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Digits").unwrap();
|
385
|
+
class.undef_default_alloc_func();
|
386
|
+
class
|
387
|
+
});
|
388
|
+
static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
389
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
|
390
|
+
class.undef_default_alloc_func();
|
391
|
+
class
|
392
|
+
});
|
393
|
+
static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
|
394
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
|
395
|
+
class.undef_default_alloc_func();
|
396
|
+
class
|
397
|
+
});
|
398
|
+
static SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
399
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Split").unwrap();
|
400
|
+
class.undef_default_alloc_func();
|
401
|
+
class
|
402
|
+
});
|
403
|
+
static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
|
404
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
|
405
|
+
class.undef_default_alloc_func();
|
406
|
+
class
|
407
|
+
});
|
408
|
+
static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
409
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
|
410
|
+
class.undef_default_alloc_func();
|
411
|
+
class
|
412
|
+
});
|
413
|
+
static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
414
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
|
415
|
+
class.undef_default_alloc_func();
|
416
|
+
class
|
417
|
+
});
|
362
418
|
match &value.pretok {
|
363
|
-
RbPreTokenizerTypeWrapper::Sequence(_seq) =>
|
364
|
-
let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
|
365
|
-
class.undef_alloc_func();
|
366
|
-
class
|
367
|
-
}),
|
419
|
+
RbPreTokenizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
|
368
420
|
RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
369
421
|
RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
|
370
|
-
PreTokenizerWrapper::BertPreTokenizer(_) =>
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
PreTokenizerWrapper::
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
|
381
|
-
let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
|
382
|
-
class.undef_alloc_func();
|
383
|
-
class
|
384
|
-
}),
|
385
|
-
PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
|
386
|
-
let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
|
387
|
-
class.undef_alloc_func();
|
388
|
-
class
|
389
|
-
}),
|
390
|
-
PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
|
391
|
-
let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
|
392
|
-
class.undef_alloc_func();
|
393
|
-
class
|
394
|
-
}),
|
395
|
-
PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
|
396
|
-
let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
|
397
|
-
class.undef_alloc_func();
|
398
|
-
class
|
399
|
-
}),
|
400
|
-
PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
|
401
|
-
let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
|
402
|
-
class.undef_alloc_func();
|
403
|
-
class
|
404
|
-
}),
|
405
|
-
PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
|
406
|
-
let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
|
407
|
-
class.undef_alloc_func();
|
408
|
-
class
|
409
|
-
}),
|
410
|
-
PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
|
411
|
-
let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
|
412
|
-
class.undef_alloc_func();
|
413
|
-
class
|
414
|
-
}),
|
415
|
-
PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
|
416
|
-
let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
|
417
|
-
class.undef_alloc_func();
|
418
|
-
class
|
419
|
-
}),
|
422
|
+
PreTokenizerWrapper::BertPreTokenizer(_) => ruby.get_inner(&BERT_PRE_TOKENIZER),
|
423
|
+
PreTokenizerWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
424
|
+
PreTokenizerWrapper::Delimiter(_) => ruby.get_inner(&CHAR_DELIMITER_SPLIT),
|
425
|
+
PreTokenizerWrapper::Digits(_) => ruby.get_inner(&DIGITS),
|
426
|
+
PreTokenizerWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
|
427
|
+
PreTokenizerWrapper::Punctuation(_) => ruby.get_inner(&PUNCTUATION),
|
428
|
+
PreTokenizerWrapper::Split(_) => ruby.get_inner(&SPLIT),
|
429
|
+
PreTokenizerWrapper::UnicodeScripts(_) => ruby.get_inner(&UNICODE_SCRIPTS),
|
430
|
+
PreTokenizerWrapper::Whitespace(_) => ruby.get_inner(&WHITESPACE),
|
431
|
+
PreTokenizerWrapper::WhitespaceSplit(_) => ruby.get_inner(&WHITESPACE_SPLIT),
|
420
432
|
_ => todo!(),
|
421
433
|
},
|
422
434
|
},
|
@@ -424,8 +436,8 @@ unsafe impl TypedData for RbPreTokenizer {
|
|
424
436
|
}
|
425
437
|
}
|
426
438
|
|
427
|
-
pub fn
|
428
|
-
let pre_tokenizer = module.define_class("PreTokenizer",
|
439
|
+
pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
440
|
+
let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
|
429
441
|
pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
|
430
442
|
|
431
443
|
let class = module.define_class("Sequence", pre_tokenizer)?;
|