tokenizers 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/Cargo.lock +160 -96
- data/ext/tokenizers/Cargo.toml +6 -6
- data/ext/tokenizers/src/decoders.rs +149 -39
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +71 -50
- data/ext/tokenizers/src/normalizers.rs +113 -74
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/decoders/strip.rb +9 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/normalizers/prepend.rb +9 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +4 -2
- metadata +6 -4
@@ -2,6 +2,7 @@ use std::cell::RefCell;
|
|
2
2
|
use std::collections::HashMap;
|
3
3
|
use std::path::PathBuf;
|
4
4
|
|
5
|
+
use magnus::prelude::*;
|
5
6
|
use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
|
6
7
|
use tk::tokenizer::{
|
7
8
|
Model, PaddingDirection, PaddingParams, PaddingStrategy,
|
@@ -78,7 +79,7 @@ struct TextInputSequence<'s>(tk::InputSequence<'s>);
|
|
78
79
|
|
79
80
|
impl<'s> TryConvert for TextInputSequence<'s> {
|
80
81
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
81
|
-
Ok(Self(
|
82
|
+
Ok(Self(String::try_convert(ob)?.into()))
|
82
83
|
}
|
83
84
|
}
|
84
85
|
|
@@ -92,7 +93,7 @@ struct RbArrayStr(Vec<String>);
|
|
92
93
|
|
93
94
|
impl TryConvert for RbArrayStr {
|
94
95
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
95
|
-
let seq =
|
96
|
+
let seq = <Vec<String>>::try_convert(ob)?;
|
96
97
|
Ok(Self(seq))
|
97
98
|
}
|
98
99
|
}
|
@@ -107,7 +108,7 @@ struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
|
107
108
|
|
108
109
|
impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
|
109
110
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
110
|
-
if let Ok(seq) =
|
111
|
+
if let Ok(seq) = RbArrayStr::try_convert(ob) {
|
111
112
|
return Ok(Self(seq.into()));
|
112
113
|
}
|
113
114
|
todo!()
|
@@ -124,14 +125,14 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
124
125
|
|
125
126
|
impl<'s> TryConvert for TextEncodeInput<'s> {
|
126
127
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
127
|
-
if let Ok(i) =
|
128
|
+
if let Ok(i) = TextInputSequence::try_convert(ob) {
|
128
129
|
return Ok(Self(i.into()));
|
129
130
|
}
|
130
|
-
if let Ok((i1, i2)) =
|
131
|
+
if let Ok((i1, i2)) = <(TextInputSequence, TextInputSequence)>::try_convert(ob) {
|
131
132
|
return Ok(Self((i1, i2).into()));
|
132
133
|
}
|
133
134
|
// TODO check if this branch is needed
|
134
|
-
if let Ok(arr) =
|
135
|
+
if let Ok(arr) = RArray::try_convert(ob) {
|
135
136
|
if arr.len() == 2 {
|
136
137
|
let first = arr.entry::<TextInputSequence>(0).unwrap();
|
137
138
|
let second = arr.entry::<TextInputSequence>(1).unwrap();
|
@@ -155,16 +156,16 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
|
|
155
156
|
|
156
157
|
impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
|
157
158
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
158
|
-
if let Ok(i) =
|
159
|
+
if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
|
159
160
|
return Ok(Self(i.into()));
|
160
161
|
}
|
161
162
|
if let Ok((i1, i2)) =
|
162
|
-
|
163
|
+
<(PreTokenizedInputSequence, PreTokenizedInputSequence)>::try_convert(ob)
|
163
164
|
{
|
164
165
|
return Ok(Self((i1, i2).into()));
|
165
166
|
}
|
166
167
|
// TODO check if this branch is needed
|
167
|
-
if let Ok(arr) =
|
168
|
+
if let Ok(arr) = RArray::try_convert(ob) {
|
168
169
|
if arr.len() == 2 {
|
169
170
|
let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
|
170
171
|
let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
|
@@ -251,16 +252,16 @@ impl RbTokenizer {
|
|
251
252
|
add_special_tokens: bool,
|
252
253
|
) -> RbResult<RbEncoding> {
|
253
254
|
let sequence: tk::InputSequence = if is_pretokenized {
|
254
|
-
|
255
|
+
PreTokenizedInputSequence::try_convert(sequence)?.into()
|
255
256
|
} else {
|
256
|
-
|
257
|
+
TextInputSequence::try_convert(sequence)?.into()
|
257
258
|
};
|
258
259
|
let input = match pair {
|
259
260
|
Some(pair) => {
|
260
261
|
let pair: tk::InputSequence = if is_pretokenized {
|
261
|
-
|
262
|
+
PreTokenizedInputSequence::try_convert(pair)?.into()
|
262
263
|
} else {
|
263
|
-
|
264
|
+
TextInputSequence::try_convert(pair)?.into()
|
264
265
|
};
|
265
266
|
tk::EncodeInput::Dual(sequence, pair)
|
266
267
|
}
|
@@ -284,9 +285,9 @@ impl RbTokenizer {
|
|
284
285
|
.each()
|
285
286
|
.map(|o| {
|
286
287
|
let input: tk::EncodeInput = if is_pretokenized {
|
287
|
-
|
288
|
+
PreTokenizedEncodeInput::try_convert(o?)?.into()
|
288
289
|
} else {
|
289
|
-
|
290
|
+
TextEncodeInput::try_convert(o?)?.into()
|
290
291
|
};
|
291
292
|
Ok(input)
|
292
293
|
})
|
@@ -306,14 +307,15 @@ impl RbTokenizer {
|
|
306
307
|
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
|
307
308
|
self.tokenizer
|
308
309
|
.borrow()
|
309
|
-
.decode(ids, skip_special_tokens)
|
310
|
+
.decode(&ids, skip_special_tokens)
|
310
311
|
.map_err(RbError::from)
|
311
312
|
}
|
312
313
|
|
313
314
|
pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
|
315
|
+
let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
|
314
316
|
self.tokenizer
|
315
317
|
.borrow()
|
316
|
-
.decode_batch(
|
318
|
+
.decode_batch(&slices, skip_special_tokens)
|
317
319
|
.map_err(RbError::from)
|
318
320
|
}
|
319
321
|
|
@@ -353,7 +355,7 @@ impl RbTokenizer {
|
|
353
355
|
|
354
356
|
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
355
357
|
if !value.is_nil() {
|
356
|
-
let dir_str
|
358
|
+
let dir_str = String::try_convert(value)?;
|
357
359
|
params.direction = match dir_str.as_str() {
|
358
360
|
"left" => PaddingDirection::Left,
|
359
361
|
"right" => PaddingDirection::Right,
|
@@ -363,29 +365,29 @@ impl RbTokenizer {
|
|
363
365
|
|
364
366
|
let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
|
365
367
|
if !value.is_nil() {
|
366
|
-
params.pad_to_multiple_of =
|
368
|
+
params.pad_to_multiple_of = TryConvert::try_convert(value)?;
|
367
369
|
}
|
368
370
|
|
369
371
|
let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
|
370
372
|
if !value.is_nil() {
|
371
|
-
params.pad_id =
|
373
|
+
params.pad_id = TryConvert::try_convert(value)?;
|
372
374
|
}
|
373
375
|
|
374
376
|
let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
|
375
377
|
if !value.is_nil() {
|
376
|
-
params.pad_type_id =
|
378
|
+
params.pad_type_id = TryConvert::try_convert(value)?;
|
377
379
|
}
|
378
380
|
|
379
381
|
let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
|
380
382
|
if !value.is_nil() {
|
381
|
-
params.pad_token =
|
383
|
+
params.pad_token = TryConvert::try_convert(value)?;
|
382
384
|
}
|
383
385
|
|
384
386
|
let value: Value = kwargs.delete(Symbol::new("length"))?;
|
385
387
|
if value.is_nil() {
|
386
388
|
params.strategy = PaddingStrategy::BatchLongest;
|
387
389
|
} else {
|
388
|
-
params.strategy = PaddingStrategy::Fixed(
|
390
|
+
params.strategy = PaddingStrategy::Fixed(TryConvert::try_convert(value)?);
|
389
391
|
}
|
390
392
|
|
391
393
|
if !kwargs.is_empty() {
|
@@ -431,12 +433,12 @@ impl RbTokenizer {
|
|
431
433
|
|
432
434
|
let value: Value = kwargs.delete(Symbol::new("stride"))?;
|
433
435
|
if !value.is_nil() {
|
434
|
-
params.stride =
|
436
|
+
params.stride = TryConvert::try_convert(value)?;
|
435
437
|
}
|
436
438
|
|
437
439
|
let value: Value = kwargs.delete(Symbol::new("strategy"))?;
|
438
440
|
if !value.is_nil() {
|
439
|
-
let strategy_str
|
441
|
+
let strategy_str = String::try_convert(value)?;
|
440
442
|
params.strategy = match strategy_str.as_str() {
|
441
443
|
"longest_first" => TruncationStrategy::LongestFirst,
|
442
444
|
"only_first" => TruncationStrategy::OnlyFirst,
|
@@ -447,7 +449,7 @@ impl RbTokenizer {
|
|
447
449
|
|
448
450
|
let value: Value = kwargs.delete(Symbol::new("direction"))?;
|
449
451
|
if !value.is_nil() {
|
450
|
-
let dir_str
|
452
|
+
let dir_str = String::try_convert(value)?;
|
451
453
|
params.direction = match dir_str.as_str() {
|
452
454
|
"left" => TruncationDirection::Left,
|
453
455
|
"right" => TruncationDirection::Right,
|
@@ -460,13 +462,18 @@ impl RbTokenizer {
|
|
460
462
|
return Err(Error::new(exception::arg_error(), "unknown keyword"));
|
461
463
|
}
|
462
464
|
|
463
|
-
self.tokenizer.borrow_mut().with_truncation(Some(params))
|
465
|
+
if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
|
466
|
+
return Err(Error::new(exception::arg_error(), error_message.to_string()));
|
467
|
+
}
|
464
468
|
|
465
469
|
Ok(())
|
466
470
|
}
|
467
471
|
|
468
472
|
pub fn no_truncation(&self) {
|
469
|
-
self.tokenizer
|
473
|
+
self.tokenizer
|
474
|
+
.borrow_mut()
|
475
|
+
.with_truncation(None)
|
476
|
+
.expect("Failed to set truncation to `None`! This should never happen");
|
470
477
|
}
|
471
478
|
|
472
479
|
pub fn truncation(&self) -> RbResult<Option<RHash>> {
|
@@ -3,16 +3,16 @@ use std::sync::{Arc, RwLock};
|
|
3
3
|
|
4
4
|
use crate::models::RbModel;
|
5
5
|
use crate::tokenizer::RbAddedToken;
|
6
|
-
use magnus::
|
6
|
+
use magnus::prelude::*;
|
7
7
|
use magnus::{
|
8
|
-
exception, function,
|
9
|
-
RArray, RClass, RHash, RModule, Symbol, TypedData, Value,
|
8
|
+
data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
|
9
|
+
RArray, RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
|
10
10
|
};
|
11
11
|
use serde::{Deserialize, Serialize};
|
12
12
|
use tk::models::TrainerWrapper;
|
13
13
|
use tk::Trainer;
|
14
14
|
|
15
|
-
use super::RbResult;
|
15
|
+
use super::{RbResult, TRAINERS};
|
16
16
|
|
17
17
|
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
18
18
|
pub struct RbTrainer {
|
@@ -112,7 +112,7 @@ impl RbTrainer {
|
|
112
112
|
special_tokens
|
113
113
|
.each()
|
114
114
|
.map(|token| {
|
115
|
-
if let Ok(content) =
|
115
|
+
if let Ok(content) = String::try_convert(token?) {
|
116
116
|
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
117
117
|
} else {
|
118
118
|
todo!()
|
@@ -144,7 +144,7 @@ impl RbTrainer {
|
|
144
144
|
self,
|
145
145
|
BpeTrainer,
|
146
146
|
initial_alphabet,
|
147
|
-
alphabet.into_iter().
|
147
|
+
alphabet.into_iter().collect()
|
148
148
|
);
|
149
149
|
}
|
150
150
|
|
@@ -199,7 +199,7 @@ impl RbTrainer {
|
|
199
199
|
special_tokens
|
200
200
|
.each()
|
201
201
|
.map(|token| {
|
202
|
-
if let Ok(content) =
|
202
|
+
if let Ok(content) = String::try_convert(token?) {
|
203
203
|
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
204
204
|
} else {
|
205
205
|
todo!()
|
@@ -223,7 +223,7 @@ impl RbTrainer {
|
|
223
223
|
self,
|
224
224
|
UnigramTrainer,
|
225
225
|
initial_alphabet,
|
226
|
-
alphabet.into_iter().
|
226
|
+
alphabet.into_iter().collect()
|
227
227
|
);
|
228
228
|
}
|
229
229
|
|
@@ -270,7 +270,7 @@ impl RbTrainer {
|
|
270
270
|
special_tokens
|
271
271
|
.each()
|
272
272
|
.map(|token| {
|
273
|
-
if let Ok(content) =
|
273
|
+
if let Ok(content) = String::try_convert(token?) {
|
274
274
|
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
275
275
|
} else {
|
276
276
|
todo!()
|
@@ -324,7 +324,7 @@ impl RbTrainer {
|
|
324
324
|
special_tokens
|
325
325
|
.each()
|
326
326
|
.map(|token| {
|
327
|
-
if let Ok(content) =
|
327
|
+
if let Ok(content) = String::try_convert(token?) {
|
328
328
|
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
329
329
|
} else {
|
330
330
|
todo!()
|
@@ -356,7 +356,7 @@ impl RbTrainer {
|
|
356
356
|
self,
|
357
357
|
WordPieceTrainer,
|
358
358
|
@set_initial_alphabet,
|
359
|
-
alphabet.into_iter().
|
359
|
+
alphabet.into_iter().collect()
|
360
360
|
);
|
361
361
|
}
|
362
362
|
|
@@ -397,11 +397,10 @@ impl RbBpeTrainer {
|
|
397
397
|
let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
|
398
398
|
if !value.is_nil() {
|
399
399
|
builder = builder.special_tokens(
|
400
|
-
value
|
401
|
-
.try_convert::<RArray>()?
|
400
|
+
RArray::try_convert(value)?
|
402
401
|
.each()
|
403
402
|
.map(|token| {
|
404
|
-
if let Ok(content) =
|
403
|
+
if let Ok(content) = String::try_convert(token?) {
|
405
404
|
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
406
405
|
} else {
|
407
406
|
todo!()
|
@@ -413,39 +412,39 @@ impl RbBpeTrainer {
|
|
413
412
|
|
414
413
|
let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
|
415
414
|
if !value.is_nil() {
|
416
|
-
let arr =
|
415
|
+
let arr = <Vec<char>>::try_convert(value)?;
|
417
416
|
let set: HashSet<char> = HashSet::from_iter(arr);
|
418
417
|
builder = builder.initial_alphabet(set);
|
419
418
|
}
|
420
419
|
|
421
420
|
let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
|
422
421
|
if !value.is_nil() {
|
423
|
-
builder = builder.vocab_size(
|
422
|
+
builder = builder.vocab_size(TryConvert::try_convert(value)?);
|
424
423
|
}
|
425
424
|
|
426
425
|
let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
|
427
426
|
if !value.is_nil() {
|
428
|
-
builder = builder.min_frequency(
|
427
|
+
builder = builder.min_frequency(TryConvert::try_convert(value)?);
|
429
428
|
}
|
430
429
|
|
431
430
|
let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
|
432
431
|
if !value.is_nil() {
|
433
|
-
builder = builder.show_progress(
|
432
|
+
builder = builder.show_progress(TryConvert::try_convert(value)?);
|
434
433
|
}
|
435
434
|
|
436
435
|
let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
|
437
436
|
if !value.is_nil() {
|
438
|
-
builder = builder.limit_alphabet(
|
437
|
+
builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
|
439
438
|
}
|
440
439
|
|
441
440
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
442
441
|
if !value.is_nil() {
|
443
|
-
builder = builder.continuing_subword_prefix(
|
442
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
444
443
|
}
|
445
444
|
|
446
445
|
let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
|
447
446
|
if !value.is_nil() {
|
448
|
-
builder = builder.end_of_word_suffix(
|
447
|
+
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
449
448
|
}
|
450
449
|
|
451
450
|
if !kwargs.is_empty() {
|
@@ -466,11 +465,10 @@ impl RbUnigramTrainer {
|
|
466
465
|
let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
|
467
466
|
if !value.is_nil() {
|
468
467
|
builder.special_tokens(
|
469
|
-
value
|
470
|
-
.try_convert::<RArray>()?
|
468
|
+
RArray::try_convert(value)?
|
471
469
|
.each()
|
472
470
|
.map(|token| {
|
473
|
-
if let Ok(content) =
|
471
|
+
if let Ok(content) = String::try_convert(token?) {
|
474
472
|
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
475
473
|
} else {
|
476
474
|
todo!()
|
@@ -482,44 +480,44 @@ impl RbUnigramTrainer {
|
|
482
480
|
|
483
481
|
let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
|
484
482
|
if !value.is_nil() {
|
485
|
-
let arr =
|
483
|
+
let arr = <Vec<char>>::try_convert(value)?;
|
486
484
|
let set: HashSet<char> = HashSet::from_iter(arr);
|
487
485
|
builder.initial_alphabet(set);
|
488
486
|
}
|
489
487
|
|
490
488
|
let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
|
491
489
|
if !value.is_nil() {
|
492
|
-
builder.vocab_size(
|
490
|
+
builder.vocab_size(TryConvert::try_convert(value)?);
|
493
491
|
}
|
494
492
|
|
495
493
|
let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
|
496
494
|
if !value.is_nil() {
|
497
|
-
builder.show_progress(
|
495
|
+
builder.show_progress(TryConvert::try_convert(value)?);
|
498
496
|
}
|
499
497
|
|
500
498
|
let value: Value = kwargs.delete(Symbol::new("n_sub_iterations"))?;
|
501
499
|
if !value.is_nil() {
|
502
|
-
builder.n_sub_iterations(
|
500
|
+
builder.n_sub_iterations(TryConvert::try_convert(value)?);
|
503
501
|
}
|
504
502
|
|
505
503
|
let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
|
506
504
|
if !value.is_nil() {
|
507
|
-
builder.unk_token(Some(
|
505
|
+
builder.unk_token(Some(TryConvert::try_convert(value)?));
|
508
506
|
}
|
509
507
|
|
510
508
|
let value: Value = kwargs.delete(Symbol::new("max_piece_length"))?;
|
511
509
|
if !value.is_nil() {
|
512
|
-
builder.max_piece_length(
|
510
|
+
builder.max_piece_length(TryConvert::try_convert(value)?);
|
513
511
|
}
|
514
512
|
|
515
513
|
let value: Value = kwargs.delete(Symbol::new("seed_size"))?;
|
516
514
|
if !value.is_nil() {
|
517
|
-
builder.seed_size(
|
515
|
+
builder.seed_size(TryConvert::try_convert(value)?);
|
518
516
|
}
|
519
517
|
|
520
518
|
let value: Value = kwargs.delete(Symbol::new("shrinking_factor"))?;
|
521
519
|
if !value.is_nil() {
|
522
|
-
builder.shrinking_factor(
|
520
|
+
builder.shrinking_factor(TryConvert::try_convert(value)?);
|
523
521
|
}
|
524
522
|
|
525
523
|
if !kwargs.is_empty() {
|
@@ -541,11 +539,10 @@ impl RbWordLevelTrainer {
|
|
541
539
|
let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
|
542
540
|
if !value.is_nil() {
|
543
541
|
builder.special_tokens(
|
544
|
-
value
|
545
|
-
.try_convert::<RArray>()?
|
542
|
+
RArray::try_convert(value)?
|
546
543
|
.each()
|
547
544
|
.map(|token| {
|
548
|
-
if let Ok(content) =
|
545
|
+
if let Ok(content) = String::try_convert(token?) {
|
549
546
|
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
550
547
|
} else {
|
551
548
|
todo!()
|
@@ -557,17 +554,17 @@ impl RbWordLevelTrainer {
|
|
557
554
|
|
558
555
|
let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
|
559
556
|
if !value.is_nil() {
|
560
|
-
builder.vocab_size(
|
557
|
+
builder.vocab_size(TryConvert::try_convert(value)?);
|
561
558
|
}
|
562
559
|
|
563
560
|
let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
|
564
561
|
if !value.is_nil() {
|
565
|
-
builder.min_frequency(
|
562
|
+
builder.min_frequency(TryConvert::try_convert(value)?);
|
566
563
|
}
|
567
564
|
|
568
565
|
let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
|
569
566
|
if !value.is_nil() {
|
570
|
-
builder.show_progress(
|
567
|
+
builder.show_progress(TryConvert::try_convert(value)?);
|
571
568
|
}
|
572
569
|
|
573
570
|
Ok(builder.build().expect("WordLevelTrainerBuilder cannot fail").into())
|
@@ -583,11 +580,10 @@ impl RbWordPieceTrainer {
|
|
583
580
|
let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
|
584
581
|
if !value.is_nil() {
|
585
582
|
builder = builder.special_tokens(
|
586
|
-
value
|
587
|
-
.try_convert::<RArray>()?
|
583
|
+
RArray::try_convert(value)?
|
588
584
|
.each()
|
589
585
|
.map(|token| {
|
590
|
-
if let Ok(content) =
|
586
|
+
if let Ok(content) = String::try_convert(token?) {
|
591
587
|
Ok(RbAddedToken::from(content, Some(true)).get_token())
|
592
588
|
} else {
|
593
589
|
todo!()
|
@@ -599,39 +595,39 @@ impl RbWordPieceTrainer {
|
|
599
595
|
|
600
596
|
let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
|
601
597
|
if !value.is_nil() {
|
602
|
-
let arr =
|
598
|
+
let arr = <Vec<char>>::try_convert(value)?;
|
603
599
|
let set: HashSet<char> = HashSet::from_iter(arr);
|
604
600
|
builder = builder.initial_alphabet(set);
|
605
601
|
}
|
606
602
|
|
607
603
|
let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
|
608
604
|
if !value.is_nil() {
|
609
|
-
builder = builder.vocab_size(
|
605
|
+
builder = builder.vocab_size(TryConvert::try_convert(value)?);
|
610
606
|
}
|
611
607
|
|
612
608
|
let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
|
613
609
|
if !value.is_nil() {
|
614
|
-
builder = builder.min_frequency(
|
610
|
+
builder = builder.min_frequency(TryConvert::try_convert(value)?);
|
615
611
|
}
|
616
612
|
|
617
613
|
let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
|
618
614
|
if !value.is_nil() {
|
619
|
-
builder = builder.show_progress(
|
615
|
+
builder = builder.show_progress(TryConvert::try_convert(value)?);
|
620
616
|
}
|
621
617
|
|
622
618
|
let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
|
623
619
|
if !value.is_nil() {
|
624
|
-
builder = builder.limit_alphabet(
|
620
|
+
builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
|
625
621
|
}
|
626
622
|
|
627
623
|
let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
|
628
624
|
if !value.is_nil() {
|
629
|
-
builder = builder.continuing_subword_prefix(
|
625
|
+
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
630
626
|
}
|
631
627
|
|
632
628
|
let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
|
633
629
|
if !value.is_nil() {
|
634
|
-
builder = builder.end_of_word_suffix(
|
630
|
+
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
635
631
|
}
|
636
632
|
|
637
633
|
if !kwargs.is_empty() {
|
@@ -644,46 +640,52 @@ impl RbWordPieceTrainer {
|
|
644
640
|
}
|
645
641
|
|
646
642
|
unsafe impl TypedData for RbTrainer {
|
647
|
-
fn class() -> RClass {
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
})
|
643
|
+
fn class(ruby: &Ruby) -> RClass {
|
644
|
+
static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
|
645
|
+
let class: RClass = ruby.get_inner(&TRAINERS).const_get("Trainer").unwrap();
|
646
|
+
class.undef_default_alloc_func();
|
647
|
+
class
|
648
|
+
});
|
649
|
+
ruby.get_inner(&CLASS)
|
653
650
|
}
|
654
651
|
|
655
652
|
fn data_type() -> &'static DataType {
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
653
|
+
static DATA_TYPE: DataType = data_type_builder!(RbTrainer, "Tokenizers::Trainers::Trainer").build();
|
654
|
+
&DATA_TYPE
|
655
|
+
}
|
656
|
+
|
657
|
+
fn class_for(ruby: &Ruby, value: &Self) -> RClass {
|
658
|
+
static BPE_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
|
659
|
+
let class: RClass = ruby.get_inner(&TRAINERS).const_get("BpeTrainer").unwrap();
|
660
|
+
class.undef_default_alloc_func();
|
661
|
+
class
|
662
|
+
});
|
663
|
+
static UNIGRAM_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
|
664
|
+
let class: RClass = ruby.get_inner(&TRAINERS).const_get("UnigramTrainer").unwrap();
|
665
|
+
class.undef_default_alloc_func();
|
666
|
+
class
|
667
|
+
});
|
668
|
+
static WORD_LEVEL_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
|
669
|
+
let class: RClass = ruby.get_inner(&TRAINERS).const_get("WordLevelTrainer").unwrap();
|
670
|
+
class.undef_default_alloc_func();
|
671
|
+
class
|
672
|
+
});
|
673
|
+
static WORD_PIECE_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
|
674
|
+
let class: RClass = ruby.get_inner(&TRAINERS).const_get("WordPieceTrainer").unwrap();
|
675
|
+
class.undef_default_alloc_func();
|
676
|
+
class
|
677
|
+
});
|
660
678
|
match *value.trainer.read().unwrap() {
|
661
|
-
TrainerWrapper::BpeTrainer(_) =>
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
}),
|
666
|
-
TrainerWrapper::UnigramTrainer(_) => *memoize!(RClass: {
|
667
|
-
let class: RClass = crate::trainers().const_get("UnigramTrainer").unwrap();
|
668
|
-
class.undef_alloc_func();
|
669
|
-
class
|
670
|
-
}),
|
671
|
-
TrainerWrapper::WordLevelTrainer(_) => *memoize!(RClass: {
|
672
|
-
let class: RClass = crate::trainers().const_get("WordLevelTrainer").unwrap();
|
673
|
-
class.undef_alloc_func();
|
674
|
-
class
|
675
|
-
}),
|
676
|
-
TrainerWrapper::WordPieceTrainer(_) => *memoize!(RClass: {
|
677
|
-
let class: RClass = crate::trainers().const_get("WordPieceTrainer").unwrap();
|
678
|
-
class.undef_alloc_func();
|
679
|
-
class
|
680
|
-
}),
|
679
|
+
TrainerWrapper::BpeTrainer(_) => ruby.get_inner(&BPE_TRAINER),
|
680
|
+
TrainerWrapper::UnigramTrainer(_) => ruby.get_inner(&UNIGRAM_TRAINER),
|
681
|
+
TrainerWrapper::WordLevelTrainer(_) => ruby.get_inner(&WORD_LEVEL_TRAINER),
|
682
|
+
TrainerWrapper::WordPieceTrainer(_) => ruby.get_inner(&WORD_PIECE_TRAINER),
|
681
683
|
}
|
682
684
|
}
|
683
685
|
}
|
684
686
|
|
685
|
-
pub fn
|
686
|
-
let trainer = module.define_class("Trainer",
|
687
|
+
pub fn init_trainers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
|
688
|
+
let trainer = module.define_class("Trainer", ruby.class_object())?;
|
687
689
|
|
688
690
|
let class = module.define_class("BpeTrainer", trainer)?;
|
689
691
|
class.define_singleton_method("_new", function!(RbBpeTrainer::new, 1))?;
|
@@ -1,5 +1,6 @@
|
|
1
1
|
use super::regex::{regex, RbRegex};
|
2
2
|
use crate::RbResult;
|
3
|
+
use magnus::prelude::*;
|
3
4
|
use magnus::{exception, Error, TryConvert, Value};
|
4
5
|
use tk::normalizer::SplitDelimiterBehavior;
|
5
6
|
use tk::pattern::Pattern;
|
@@ -13,9 +14,9 @@ pub enum RbPattern<'p> {
|
|
13
14
|
impl TryConvert for RbPattern<'_> {
|
14
15
|
fn try_convert(obj: Value) -> RbResult<Self> {
|
15
16
|
if obj.is_kind_of(regex()) {
|
16
|
-
Ok(RbPattern::Regex(
|
17
|
+
Ok(RbPattern::Regex(TryConvert::try_convert(obj)?))
|
17
18
|
} else {
|
18
|
-
Ok(RbPattern::Str(
|
19
|
+
Ok(RbPattern::Str(TryConvert::try_convert(obj)?))
|
19
20
|
}
|
20
21
|
}
|
21
22
|
}
|
@@ -61,7 +62,7 @@ pub struct RbSplitDelimiterBehavior(pub SplitDelimiterBehavior);
|
|
61
62
|
|
62
63
|
impl TryConvert for RbSplitDelimiterBehavior {
|
63
64
|
fn try_convert(obj: Value) -> RbResult<Self> {
|
64
|
-
let s =
|
65
|
+
let s = String::try_convert(obj)?;
|
65
66
|
|
66
67
|
Ok(Self(match s.as_str() {
|
67
68
|
"removed" => Ok(SplitDelimiterBehavior::Removed),
|
@@ -1,6 +1,6 @@
|
|
1
1
|
use onig::Regex;
|
2
|
-
use magnus::{exception,
|
3
|
-
use crate::{
|
2
|
+
use magnus::{exception, prelude::*, value::Lazy, Error, RClass, Ruby};
|
3
|
+
use crate::{RbResult, TOKENIZERS};
|
4
4
|
|
5
5
|
#[magnus::wrap(class = "Tokenizers::Regex")]
|
6
6
|
pub struct RbRegex {
|
@@ -17,6 +17,8 @@ impl RbRegex {
|
|
17
17
|
}
|
18
18
|
}
|
19
19
|
|
20
|
+
static REGEX: Lazy<RClass> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Regex").unwrap());
|
21
|
+
|
20
22
|
pub fn regex() -> RClass {
|
21
|
-
|
23
|
+
Ruby::get().unwrap().get_inner(&REGEX)
|
22
24
|
}
|