tokenizers 0.3.3 → 0.4.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/Cargo.lock +52 -23
- data/ext/tokenizers/Cargo.toml +4 -3
- data/ext/tokenizers/src/decoders.rs +72 -61
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +57 -51
- data/ext/tokenizers/src/normalizers.rs +90 -77
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +2 -2
- metadata +3 -3
@@ -3,10 +3,10 @@ use std::path::{Path, PathBuf};
|
|
3
3
|
use std::sync::{Arc, RwLock};
|
4
4
|
|
5
5
|
use crate::trainers::RbTrainer;
|
6
|
-
use magnus::
|
6
|
+
use magnus::prelude::*;
|
7
7
|
use magnus::{
|
8
|
-
exception, function,
|
9
|
-
RClass, RHash, RModule, Symbol, TypedData, Value,
|
8
|
+
data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
|
9
|
+
RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
|
10
10
|
};
|
11
11
|
use serde::{Deserialize, Serialize};
|
12
12
|
use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
|
@@ -16,7 +16,7 @@ use tk::models::wordlevel::WordLevel;
|
|
16
16
|
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
|
17
17
|
use tk::{Model, Token};
|
18
18
|
|
19
|
-
use super::{RbError, RbResult};
|
19
|
+
use super::{MODELS, RbError, RbResult};
|
20
20
|
|
21
21
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
22
22
|
pub struct RbModel {
|
@@ -73,37 +73,37 @@ impl RbBPE {
|
|
73
73
|
fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
74
74
|
let value: Value = kwargs.delete(Symbol::new("cache_capacity"))?;
|
75
75
|
if !value.is_nil() {
|
76
|
-
builder = builder.cache_capacity(
|
76
|
+
builder = builder.cache_capacity(TryConvert::try_convert(value)?);
|
77
77
|
}
|
78
78
|
|
79
79
|
let value: Value = kwargs.delete(Symbol::new("dropout"))?;
|
80
80
|
if !value.is_nil() {
|
81
|
-
builder = builder.dropout(
|
81
|
+
builder = builder.dropout(TryConvert::try_convert(value)?);
|
82
82
|
}
|
83
83
|
|
84
84
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
85
85
|
if !value.is_nil() {
|
86
|
-
builder = builder.unk_token(
|
86
|
+
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
87
87
|
}
|
88
88
|
|
89
89
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
90
90
|
if !value.is_nil() {
|
91
|
-
builder = builder.continuing_subword_prefix(
|
91
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
92
92
|
}
|
93
93
|
|
94
94
|
let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
|
95
95
|
if !value.is_nil() {
|
96
|
-
builder = builder.end_of_word_suffix(
|
96
|
+
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
97
97
|
}
|
98
98
|
|
99
99
|
let value: Value = kwargs.delete(Symbol::new("fuse_unk"))?;
|
100
100
|
if !value.is_nil() {
|
101
|
-
builder = builder.fuse_unk(
|
101
|
+
builder = builder.fuse_unk(TryConvert::try_convert(value)?);
|
102
102
|
}
|
103
103
|
|
104
104
|
let value: Value = kwargs.delete(Symbol::new("byte_fallback"))?;
|
105
105
|
if !value.is_nil() {
|
106
|
-
builder = builder.byte_fallback(
|
106
|
+
builder = builder.byte_fallback(TryConvert::try_convert(value)?);
|
107
107
|
}
|
108
108
|
|
109
109
|
if !kwargs.is_empty() {
|
@@ -234,13 +234,13 @@ impl RbModel {
|
|
234
234
|
pub struct RbUnigram {}
|
235
235
|
|
236
236
|
impl RbUnigram {
|
237
|
-
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> RbResult<RbModel> {
|
238
|
-
match (vocab, unk_id) {
|
239
|
-
(Some(vocab), unk_id) => {
|
240
|
-
let model = Unigram::from(vocab, unk_id).map_err(RbError::from)?;
|
237
|
+
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>, byte_fallback: Option<bool>) -> RbResult<RbModel> {
|
238
|
+
match (vocab, unk_id, byte_fallback) {
|
239
|
+
(Some(vocab), unk_id, byte_fallback) => {
|
240
|
+
let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(RbError::from)?;
|
241
241
|
Ok(model.into())
|
242
242
|
}
|
243
|
-
(None, None) => Ok(Unigram::default().into()),
|
243
|
+
(None, None, _) => Ok(Unigram::default().into()),
|
244
244
|
_ => Err(Error::new(exception::arg_error(), "`vocab` and `unk_id` must be both specified")),
|
245
245
|
}
|
246
246
|
}
|
@@ -277,17 +277,17 @@ impl RbWordPiece {
|
|
277
277
|
fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
278
278
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
279
279
|
if !value.is_nil() {
|
280
|
-
builder = builder.unk_token(
|
280
|
+
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
281
281
|
}
|
282
282
|
|
283
283
|
let value: Value = kwargs.delete(Symbol::new("max_input_chars_per_word"))?;
|
284
284
|
if !value.is_nil() {
|
285
|
-
builder = builder.max_input_chars_per_word(
|
285
|
+
builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
|
286
286
|
}
|
287
287
|
|
288
288
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
289
289
|
if !value.is_nil() {
|
290
|
-
builder = builder.continuing_subword_prefix(
|
290
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
291
291
|
}
|
292
292
|
|
293
293
|
if !kwargs.is_empty() {
|
@@ -314,46 +314,52 @@ impl RbWordPiece {
|
|
314
314
|
}
|
315
315
|
|
316
316
|
unsafe impl TypedData for RbModel {
|
317
|
-
fn class() -> RClass {
|
318
|
-
|
319
|
-
let class: RClass =
|
320
|
-
class.
|
317
|
+
fn class(ruby: &Ruby) -> RClass {
|
318
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
319
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("Model").unwrap();
|
320
|
+
class.undef_default_alloc_func();
|
321
321
|
class
|
322
|
-
})
|
322
|
+
});
|
323
|
+
ruby.get_inner(&CLASS)
|
323
324
|
}
|
324
325
|
|
325
326
|
fn data_type() -> &'static DataType {
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
327
|
+
static DATA_TYPE: DataType = data_type_builder!(RbModel, "Tokenizers::Models::Model").build();
|
328
|
+
&DATA_TYPE
|
329
|
+
}
|
330
|
+
|
331
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
332
|
+
static BPE: Lazy<RClass> = Lazy::new(|ruby| {
|
333
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("BPE").unwrap();
|
334
|
+
class.undef_default_alloc_func();
|
335
|
+
class
|
336
|
+
});
|
337
|
+
static UNIGRAM: Lazy<RClass> = Lazy::new(|ruby| {
|
338
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("Unigram").unwrap();
|
339
|
+
class.undef_default_alloc_func();
|
340
|
+
class
|
341
|
+
});
|
342
|
+
static WORD_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
343
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("WordLevel").unwrap();
|
344
|
+
class.undef_default_alloc_func();
|
345
|
+
class
|
346
|
+
});
|
347
|
+
static WORD_PIECE: Lazy<RClass> = Lazy::new(|ruby| {
|
348
|
+
let class: RClass = ruby.get_inner(&MODELS).const_get("WordPiece").unwrap();
|
349
|
+
class.undef_default_alloc_func();
|
350
|
+
class
|
351
|
+
});
|
330
352
|
match *value.model.read().unwrap() {
|
331
|
-
ModelWrapper::BPE(_) =>
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
}),
|
336
|
-
ModelWrapper::Unigram(_) => *memoize!(RClass: {
|
337
|
-
let class: RClass = crate::models().const_get("Unigram").unwrap();
|
338
|
-
class.undef_alloc_func();
|
339
|
-
class
|
340
|
-
}),
|
341
|
-
ModelWrapper::WordLevel(_) => *memoize!(RClass: {
|
342
|
-
let class: RClass = crate::models().const_get("WordLevel").unwrap();
|
343
|
-
class.undef_alloc_func();
|
344
|
-
class
|
345
|
-
}),
|
346
|
-
ModelWrapper::WordPiece(_) => *memoize!(RClass: {
|
347
|
-
let class: RClass = crate::models().const_get("WordPiece").unwrap();
|
348
|
-
class.undef_alloc_func();
|
349
|
-
class
|
350
|
-
}),
|
353
|
+
ModelWrapper::BPE(_) => ruby.get_inner(&BPE),
|
354
|
+
ModelWrapper::Unigram(_) => ruby.get_inner(&UNIGRAM),
|
355
|
+
ModelWrapper::WordLevel(_) => ruby.get_inner(&WORD_LEVEL),
|
356
|
+
ModelWrapper::WordPiece(_) => ruby.get_inner(&WORD_PIECE),
|
351
357
|
}
|
352
358
|
}
|
353
359
|
}
|
354
360
|
|
355
|
-
pub fn
|
356
|
-
let model = module.define_class("Model",
|
361
|
+
pub fn init_models(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
362
|
+
let model = module.define_class("Model", ruby.class_object())?;
|
357
363
|
|
358
364
|
let class = module.define_class("BPE", model)?;
|
359
365
|
class.define_singleton_method("_new", function!(RbBPE::new, 3))?;
|
@@ -372,7 +378,7 @@ pub fn models(module: &RModule) -> RbResult<()> {
|
|
372
378
|
class.define_method("byte_fallback=", method!(RbModel::bpe_set_byte_fallback, 1))?;
|
373
379
|
|
374
380
|
let class = module.define_class("Unigram", model)?;
|
375
|
-
class.define_singleton_method("_new", function!(RbUnigram::new,
|
381
|
+
class.define_singleton_method("_new", function!(RbUnigram::new, 3))?;
|
376
382
|
|
377
383
|
let class = module.define_class("WordLevel", model)?;
|
378
384
|
class.define_singleton_method("_new", function!(RbWordLevel::new, 2))?;
|
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
|
6
|
-
TypedData,
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object, RArray, RClass, RModule,
|
5
|
+
Ruby, TryConvert, TypedData,
|
7
6
|
};
|
8
7
|
use serde::ser::SerializeStruct;
|
9
8
|
use serde::{Deserialize, Serialize, Serializer};
|
@@ -14,7 +13,7 @@ use tk::normalizers::{
|
|
14
13
|
use tk::{NormalizedString, Normalizer};
|
15
14
|
|
16
15
|
use super::utils::*;
|
17
|
-
use super::{RbError, RbResult};
|
16
|
+
use super::{NORMALIZERS, RbError, RbResult};
|
18
17
|
|
19
18
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
20
19
|
pub struct RbNormalizer {
|
@@ -224,7 +223,7 @@ impl RbSequence {
|
|
224
223
|
fn new(normalizers: RArray) -> RbResult<RbNormalizer> {
|
225
224
|
let mut sequence = Vec::with_capacity(normalizers.len());
|
226
225
|
for n in normalizers.each() {
|
227
|
-
let normalizer: &RbNormalizer =
|
226
|
+
let normalizer: &RbNormalizer = TryConvert::try_convert(n?)?;
|
228
227
|
match &normalizer.normalizer {
|
229
228
|
RbNormalizerTypeWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
|
230
229
|
RbNormalizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
|
@@ -327,82 +326,96 @@ impl Normalizer for RbNormalizerWrapper {
|
|
327
326
|
}
|
328
327
|
|
329
328
|
unsafe impl TypedData for RbNormalizer {
|
330
|
-
fn class() -> RClass {
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
})
|
329
|
+
fn class(ruby: &Ruby) -> RClass {
|
330
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
331
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Normalizer").unwrap();
|
332
|
+
class.undef_default_alloc_func();
|
333
|
+
class
|
334
|
+
});
|
335
|
+
ruby.get_inner(&CLASS)
|
336
336
|
}
|
337
337
|
|
338
338
|
fn data_type() -> &'static DataType {
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
339
|
+
static DATA_TYPE: DataType = data_type_builder!(RbNormalizer, "Tokenizers::Normalizers::Normalizer").build();
|
340
|
+
&DATA_TYPE
|
341
|
+
}
|
342
|
+
|
343
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
344
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
345
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Sequence").unwrap();
|
346
|
+
class.undef_default_alloc_func();
|
347
|
+
class
|
348
|
+
});
|
349
|
+
static BERT_NORMALIZER: Lazy<RClass> = Lazy::new(|ruby| {
|
350
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("BertNormalizer").unwrap();
|
351
|
+
class.undef_default_alloc_func();
|
352
|
+
class
|
353
|
+
});
|
354
|
+
static LOWERCASE: Lazy<RClass> = Lazy::new(|ruby| {
|
355
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Lowercase").unwrap();
|
356
|
+
class.undef_default_alloc_func();
|
357
|
+
class
|
358
|
+
});
|
359
|
+
static NFD: Lazy<RClass> = Lazy::new(|ruby| {
|
360
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFD").unwrap();
|
361
|
+
class.undef_default_alloc_func();
|
362
|
+
class
|
363
|
+
});
|
364
|
+
static NFC: Lazy<RClass> = Lazy::new(|ruby| {
|
365
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFC").unwrap();
|
366
|
+
class.undef_default_alloc_func();
|
367
|
+
class
|
368
|
+
});
|
369
|
+
static NFKC: Lazy<RClass> = Lazy::new(|ruby| {
|
370
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKC").unwrap();
|
371
|
+
class.undef_default_alloc_func();
|
372
|
+
class
|
373
|
+
});
|
374
|
+
static NFKD: Lazy<RClass> = Lazy::new(|ruby| {
|
375
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("NFKD").unwrap();
|
376
|
+
class.undef_default_alloc_func();
|
377
|
+
class
|
378
|
+
});
|
379
|
+
static NMT: Lazy<RClass> = Lazy::new(|ruby| {
|
380
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Nmt").unwrap();
|
381
|
+
class.undef_default_alloc_func();
|
382
|
+
class
|
383
|
+
});
|
384
|
+
static REPLACE: Lazy<RClass> = Lazy::new(|ruby| {
|
385
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Replace").unwrap();
|
386
|
+
class.undef_default_alloc_func();
|
387
|
+
class
|
388
|
+
});
|
389
|
+
static PREPEND: Lazy<RClass> = Lazy::new(|ruby| {
|
390
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Prepend").unwrap();
|
391
|
+
class.undef_default_alloc_func();
|
392
|
+
class
|
393
|
+
});
|
394
|
+
static STRIP: Lazy<RClass> = Lazy::new(|ruby| {
|
395
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("Strip").unwrap();
|
396
|
+
class.undef_default_alloc_func();
|
397
|
+
class
|
398
|
+
});
|
399
|
+
static STRIP_ACCENTS: Lazy<RClass> = Lazy::new(|ruby| {
|
400
|
+
let class: RClass = ruby.get_inner(&NORMALIZERS).const_get("StripAccents").unwrap();
|
401
|
+
class.undef_default_alloc_func();
|
402
|
+
class
|
403
|
+
});
|
343
404
|
match &value.normalizer {
|
344
|
-
RbNormalizerTypeWrapper::Sequence(_seq) =>
|
345
|
-
let class: RClass = crate::normalizers().const_get("Sequence").unwrap();
|
346
|
-
class.undef_alloc_func();
|
347
|
-
class
|
348
|
-
}),
|
405
|
+
RbNormalizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
|
349
406
|
RbNormalizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
350
407
|
RbNormalizerWrapper::Wrapped(wrapped) => match &wrapped {
|
351
|
-
NormalizerWrapper::BertNormalizer(_) =>
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
NormalizerWrapper::
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
NormalizerWrapper::
|
362
|
-
let class: RClass = crate::normalizers().const_get("NFD").unwrap();
|
363
|
-
class.undef_alloc_func();
|
364
|
-
class
|
365
|
-
}),
|
366
|
-
NormalizerWrapper::NFC(_) => *memoize!(RClass: {
|
367
|
-
let class: RClass = crate::normalizers().const_get("NFC").unwrap();
|
368
|
-
class.undef_alloc_func();
|
369
|
-
class
|
370
|
-
}),
|
371
|
-
NormalizerWrapper::NFKC(_) => *memoize!(RClass: {
|
372
|
-
let class: RClass = crate::normalizers().const_get("NFKC").unwrap();
|
373
|
-
class.undef_alloc_func();
|
374
|
-
class
|
375
|
-
}),
|
376
|
-
NormalizerWrapper::NFKD(_) => *memoize!(RClass: {
|
377
|
-
let class: RClass = crate::normalizers().const_get("NFKD").unwrap();
|
378
|
-
class.undef_alloc_func();
|
379
|
-
class
|
380
|
-
}),
|
381
|
-
NormalizerWrapper::Nmt(_) => *memoize!(RClass: {
|
382
|
-
let class: RClass = crate::normalizers().const_get("Nmt").unwrap();
|
383
|
-
class.undef_alloc_func();
|
384
|
-
class
|
385
|
-
}),
|
386
|
-
NormalizerWrapper::Replace(_) => *memoize!(RClass: {
|
387
|
-
let class: RClass = crate::normalizers().const_get("Replace").unwrap();
|
388
|
-
class.undef_alloc_func();
|
389
|
-
class
|
390
|
-
}),
|
391
|
-
NormalizerWrapper::Prepend(_) => *memoize!(RClass: {
|
392
|
-
let class: RClass = crate::normalizers().const_get("Prepend").unwrap();
|
393
|
-
class.undef_alloc_func();
|
394
|
-
class
|
395
|
-
}),
|
396
|
-
NormalizerWrapper::StripNormalizer(_) => *memoize!(RClass: {
|
397
|
-
let class: RClass = crate::normalizers().const_get("Strip").unwrap();
|
398
|
-
class.undef_alloc_func();
|
399
|
-
class
|
400
|
-
}),
|
401
|
-
NormalizerWrapper::StripAccents(_) => *memoize!(RClass: {
|
402
|
-
let class: RClass = crate::normalizers().const_get("StripAccents").unwrap();
|
403
|
-
class.undef_alloc_func();
|
404
|
-
class
|
405
|
-
}),
|
408
|
+
NormalizerWrapper::BertNormalizer(_) => ruby.get_inner(&BERT_NORMALIZER),
|
409
|
+
NormalizerWrapper::Lowercase(_) => ruby.get_inner(&LOWERCASE),
|
410
|
+
NormalizerWrapper::NFD(_) => ruby.get_inner(&NFD),
|
411
|
+
NormalizerWrapper::NFC(_) => ruby.get_inner(&NFC),
|
412
|
+
NormalizerWrapper::NFKC(_) => ruby.get_inner(&NFKC),
|
413
|
+
NormalizerWrapper::NFKD(_) => ruby.get_inner(&NFKD),
|
414
|
+
NormalizerWrapper::Nmt(_) => ruby.get_inner(&NMT),
|
415
|
+
NormalizerWrapper::Replace(_) => ruby.get_inner(&REPLACE),
|
416
|
+
NormalizerWrapper::Prepend(_) => ruby.get_inner(&PREPEND),
|
417
|
+
NormalizerWrapper::StripNormalizer(_) => ruby.get_inner(&STRIP),
|
418
|
+
NormalizerWrapper::StripAccents(_) => ruby.get_inner(&STRIP_ACCENTS),
|
406
419
|
_ => todo!(),
|
407
420
|
},
|
408
421
|
},
|
@@ -410,8 +423,8 @@ unsafe impl TypedData for RbNormalizer {
|
|
410
423
|
}
|
411
424
|
}
|
412
425
|
|
413
|
-
pub fn
|
414
|
-
let normalizer = module.define_class("Normalizer",
|
426
|
+
pub fn init_normalizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
427
|
+
let normalizer = module.define_class("Normalizer", ruby.class_object())?;
|
415
428
|
normalizer.define_method("normalize_str", method!(RbNormalizer::normalize_str, 1))?;
|
416
429
|
|
417
430
|
let class = module.define_class("Sequence", normalizer)?;
|
@@ -1,9 +1,8 @@
|
|
1
1
|
use std::sync::{Arc, RwLock};
|
2
2
|
|
3
|
-
use magnus::typed_data::DataTypeBuilder;
|
4
3
|
use magnus::{
|
5
|
-
|
6
|
-
RArray, RClass, RModule, TypedData,
|
4
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Module, Object,
|
5
|
+
RArray, RClass, RModule, Ruby, TryConvert, TypedData,
|
7
6
|
};
|
8
7
|
|
9
8
|
use serde::ser::SerializeStruct;
|
@@ -23,7 +22,7 @@ use tk::tokenizer::Offsets;
|
|
23
22
|
use tk::{PreTokenizedString, PreTokenizer};
|
24
23
|
|
25
24
|
use super::utils::*;
|
26
|
-
use super::{RbError, RbResult};
|
25
|
+
use super::{PRE_TOKENIZERS, RbError, RbResult};
|
27
26
|
|
28
27
|
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
29
28
|
pub struct RbPreTokenizer {
|
@@ -215,7 +214,7 @@ pub struct RbWhitespace {}
|
|
215
214
|
|
216
215
|
impl RbWhitespace {
|
217
216
|
pub fn new() -> RbPreTokenizer {
|
218
|
-
Whitespace
|
217
|
+
Whitespace.into()
|
219
218
|
}
|
220
219
|
}
|
221
220
|
|
@@ -241,7 +240,7 @@ impl RbSequence {
|
|
241
240
|
fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
|
242
241
|
let mut sequence = Vec::with_capacity(pre_tokenizers.len());
|
243
242
|
for n in pre_tokenizers.each() {
|
244
|
-
let pretokenizer: &RbPreTokenizer =
|
243
|
+
let pretokenizer: &RbPreTokenizer = TryConvert::try_convert(n?)?;
|
245
244
|
match &pretokenizer.pretok {
|
246
245
|
RbPreTokenizerTypeWrapper::Sequence(inner) => {
|
247
246
|
sequence.extend(inner.iter().cloned())
|
@@ -346,77 +345,90 @@ impl PreTokenizer for RbPreTokenizerWrapper {
|
|
346
345
|
}
|
347
346
|
|
348
347
|
unsafe impl TypedData for RbPreTokenizer {
|
349
|
-
fn class() -> RClass {
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
})
|
348
|
+
fn class(ruby: &Ruby) -> RClass {
|
349
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
350
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("PreTokenizer").unwrap();
|
351
|
+
class.undef_default_alloc_func();
|
352
|
+
class
|
353
|
+
});
|
354
|
+
ruby.get_inner(&CLASS)
|
355
355
|
}
|
356
356
|
|
357
357
|
fn data_type() -> &'static DataType {
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
358
|
+
static DATA_TYPE: DataType = data_type_builder!(RbPreTokenizer, "Tokenizers::PreTokenizers::PreTokenizer").build();
|
359
|
+
&DATA_TYPE
|
360
|
+
}
|
361
|
+
|
362
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
363
|
+
static SEQUENCE: Lazy<RClass> = Lazy::new(|ruby| {
|
364
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Sequence").unwrap();
|
365
|
+
class.undef_default_alloc_func();
|
366
|
+
class
|
367
|
+
});
|
368
|
+
static BERT_PRE_TOKENIZER: Lazy<RClass> = Lazy::new(|ruby| {
|
369
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("BertPreTokenizer").unwrap();
|
370
|
+
class.undef_default_alloc_func();
|
371
|
+
class
|
372
|
+
});
|
373
|
+
static BYTE_LEVEL: Lazy<RClass> = Lazy::new(|ruby| {
|
374
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("ByteLevel").unwrap();
|
375
|
+
class.undef_default_alloc_func();
|
376
|
+
class
|
377
|
+
});
|
378
|
+
static CHAR_DELIMITER_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
379
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("CharDelimiterSplit").unwrap();
|
380
|
+
class.undef_default_alloc_func();
|
381
|
+
class
|
382
|
+
});
|
383
|
+
static DIGITS: Lazy<RClass> = Lazy::new(|ruby| {
|
384
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Digits").unwrap();
|
385
|
+
class.undef_default_alloc_func();
|
386
|
+
class
|
387
|
+
});
|
388
|
+
static METASPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
389
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Metaspace").unwrap();
|
390
|
+
class.undef_default_alloc_func();
|
391
|
+
class
|
392
|
+
});
|
393
|
+
static PUNCTUATION: Lazy<RClass> = Lazy::new(|ruby| {
|
394
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Punctuation").unwrap();
|
395
|
+
class.undef_default_alloc_func();
|
396
|
+
class
|
397
|
+
});
|
398
|
+
static SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
399
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Split").unwrap();
|
400
|
+
class.undef_default_alloc_func();
|
401
|
+
class
|
402
|
+
});
|
403
|
+
static UNICODE_SCRIPTS: Lazy<RClass> = Lazy::new(|ruby| {
|
404
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("UnicodeScripts").unwrap();
|
405
|
+
class.undef_default_alloc_func();
|
406
|
+
class
|
407
|
+
});
|
408
|
+
static WHITESPACE: Lazy<RClass> = Lazy::new(|ruby| {
|
409
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("Whitespace").unwrap();
|
410
|
+
class.undef_default_alloc_func();
|
411
|
+
class
|
412
|
+
});
|
413
|
+
static WHITESPACE_SPLIT: Lazy<RClass> = Lazy::new(|ruby| {
|
414
|
+
let class: RClass = ruby.get_inner(&PRE_TOKENIZERS).const_get("WhitespaceSplit").unwrap();
|
415
|
+
class.undef_default_alloc_func();
|
416
|
+
class
|
417
|
+
});
|
362
418
|
match &value.pretok {
|
363
|
-
RbPreTokenizerTypeWrapper::Sequence(_seq) =>
|
364
|
-
let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
|
365
|
-
class.undef_alloc_func();
|
366
|
-
class
|
367
|
-
}),
|
419
|
+
RbPreTokenizerTypeWrapper::Sequence(_seq) => ruby.get_inner(&SEQUENCE),
|
368
420
|
RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
369
421
|
RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
|
370
|
-
PreTokenizerWrapper::BertPreTokenizer(_) =>
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
PreTokenizerWrapper::
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
|
381
|
-
let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
|
382
|
-
class.undef_alloc_func();
|
383
|
-
class
|
384
|
-
}),
|
385
|
-
PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
|
386
|
-
let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
|
387
|
-
class.undef_alloc_func();
|
388
|
-
class
|
389
|
-
}),
|
390
|
-
PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
|
391
|
-
let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
|
392
|
-
class.undef_alloc_func();
|
393
|
-
class
|
394
|
-
}),
|
395
|
-
PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
|
396
|
-
let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
|
397
|
-
class.undef_alloc_func();
|
398
|
-
class
|
399
|
-
}),
|
400
|
-
PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
|
401
|
-
let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
|
402
|
-
class.undef_alloc_func();
|
403
|
-
class
|
404
|
-
}),
|
405
|
-
PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
|
406
|
-
let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
|
407
|
-
class.undef_alloc_func();
|
408
|
-
class
|
409
|
-
}),
|
410
|
-
PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
|
411
|
-
let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
|
412
|
-
class.undef_alloc_func();
|
413
|
-
class
|
414
|
-
}),
|
415
|
-
PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
|
416
|
-
let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
|
417
|
-
class.undef_alloc_func();
|
418
|
-
class
|
419
|
-
}),
|
422
|
+
PreTokenizerWrapper::BertPreTokenizer(_) => ruby.get_inner(&BERT_PRE_TOKENIZER),
|
423
|
+
PreTokenizerWrapper::ByteLevel(_) => ruby.get_inner(&BYTE_LEVEL),
|
424
|
+
PreTokenizerWrapper::Delimiter(_) => ruby.get_inner(&CHAR_DELIMITER_SPLIT),
|
425
|
+
PreTokenizerWrapper::Digits(_) => ruby.get_inner(&DIGITS),
|
426
|
+
PreTokenizerWrapper::Metaspace(_) => ruby.get_inner(&METASPACE),
|
427
|
+
PreTokenizerWrapper::Punctuation(_) => ruby.get_inner(&PUNCTUATION),
|
428
|
+
PreTokenizerWrapper::Split(_) => ruby.get_inner(&SPLIT),
|
429
|
+
PreTokenizerWrapper::UnicodeScripts(_) => ruby.get_inner(&UNICODE_SCRIPTS),
|
430
|
+
PreTokenizerWrapper::Whitespace(_) => ruby.get_inner(&WHITESPACE),
|
431
|
+
PreTokenizerWrapper::WhitespaceSplit(_) => ruby.get_inner(&WHITESPACE_SPLIT),
|
420
432
|
_ => todo!(),
|
421
433
|
},
|
422
434
|
},
|
@@ -424,8 +436,8 @@ unsafe impl TypedData for RbPreTokenizer {
|
|
424
436
|
}
|
425
437
|
}
|
426
438
|
|
427
|
-
pub fn
|
428
|
-
let pre_tokenizer = module.define_class("PreTokenizer",
|
439
|
+
pub fn init_pre_tokenizers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
440
|
+
let pre_tokenizer = module.define_class("PreTokenizer", ruby.class_object())?;
|
429
441
|
pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
|
430
442
|
|
431
443
|
let class = module.define_class("Sequence", pre_tokenizer)?;
|