tokenizers 0.3.2 → 0.4.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/Cargo.lock +160 -96
- data/ext/tokenizers/Cargo.toml +6 -6
- data/ext/tokenizers/src/decoders.rs +149 -39
- data/ext/tokenizers/src/error.rs +5 -3
- data/ext/tokenizers/src/lib.rs +21 -33
- data/ext/tokenizers/src/models.rs +71 -50
- data/ext/tokenizers/src/normalizers.rs +113 -74
- data/ext/tokenizers/src/pre_tokenizers.rs +85 -73
- data/ext/tokenizers/src/processors.rs +43 -38
- data/ext/tokenizers/src/tokenizer.rs +35 -28
- data/ext/tokenizers/src/trainers.rs +82 -80
- data/ext/tokenizers/src/utils/normalization.rs +4 -3
- data/ext/tokenizers/src/utils/regex.rs +5 -3
- data/lib/tokenizers/decoders/strip.rb +9 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/unigram.rb +2 -2
- data/lib/tokenizers/normalizers/prepend.rb +9 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +4 -2
- metadata +6 -4
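Most of the Rust churn in the hunks below is one mechanical migration: conversions previously written as a method call on a magnus `Value` (the `value.try_convert::<RArray>()?` form still visible in the removed trainer lines) are rewritten in the trait-based style (`T::try_convert(value)` or `TryConvert::try_convert(value)`), and `use magnus::prelude::*;` is added where trait methods such as `is_nil` are needed. A minimal sketch of the new-style call, assuming a magnus 0.6-level API; `pad_token_from` is a hypothetical helper, not code from the gem:

```rust
use magnus::{Error, TryConvert, Value};

// New-style conversion: the target type drives the call instead of a
// turbofish on Value. `pad_token_from` is illustrative only.
fn pad_token_from(value: Value) -> Result<String, Error> {
    // Also written in the diff as `TryConvert::try_convert(value)?` when the
    // target type can be inferred from the assignment.
    String::try_convert(value)
}
```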
data/ext/tokenizers/src/tokenizer.rs

@@ -2,6 +2,7 @@ use std::cell::RefCell;
 use std::collections::HashMap;
 use std::path::PathBuf;

+use magnus::prelude::*;
 use magnus::{exception, Error, RArray, RHash, Symbol, TryConvert, Value};
 use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy,
@@ -78,7 +79,7 @@ struct TextInputSequence<'s>(tk::InputSequence<'s>);

 impl<'s> TryConvert for TextInputSequence<'s> {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        Ok(Self(
+        Ok(Self(String::try_convert(ob)?.into()))
     }
 }

@@ -92,7 +93,7 @@ struct RbArrayStr(Vec<String>);

 impl TryConvert for RbArrayStr {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        let seq =
+        let seq = <Vec<String>>::try_convert(ob)?;
         Ok(Self(seq))
     }
 }
@@ -107,7 +108,7 @@ struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);

 impl<'s> TryConvert for PreTokenizedInputSequence<'s> {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        if let Ok(seq) =
+        if let Ok(seq) = RbArrayStr::try_convert(ob) {
             return Ok(Self(seq.into()));
         }
         todo!()
@@ -124,14 +125,14 @@ struct TextEncodeInput<'s>(tk::EncodeInput<'s>);

 impl<'s> TryConvert for TextEncodeInput<'s> {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        if let Ok(i) =
+        if let Ok(i) = TextInputSequence::try_convert(ob) {
             return Ok(Self(i.into()));
         }
-        if let Ok((i1, i2)) =
+        if let Ok((i1, i2)) = <(TextInputSequence, TextInputSequence)>::try_convert(ob) {
             return Ok(Self((i1, i2).into()));
         }
         // TODO check if this branch is needed
-        if let Ok(arr) =
+        if let Ok(arr) = RArray::try_convert(ob) {
             if arr.len() == 2 {
                 let first = arr.entry::<TextInputSequence>(0).unwrap();
                 let second = arr.entry::<TextInputSequence>(1).unwrap();
@@ -155,16 +156,16 @@ struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);

 impl<'s> TryConvert for PreTokenizedEncodeInput<'s> {
     fn try_convert(ob: Value) -> RbResult<Self> {
-        if let Ok(i) =
+        if let Ok(i) = PreTokenizedInputSequence::try_convert(ob) {
             return Ok(Self(i.into()));
         }
         if let Ok((i1, i2)) =
-
+            <(PreTokenizedInputSequence, PreTokenizedInputSequence)>::try_convert(ob)
         {
             return Ok(Self((i1, i2).into()));
         }
         // TODO check if this branch is needed
-        if let Ok(arr) =
+        if let Ok(arr) = RArray::try_convert(ob) {
             if arr.len() == 2 {
                 let first = arr.entry::<PreTokenizedInputSequence>(0).unwrap();
                 let second = arr.entry::<PreTokenizedInputSequence>(1).unwrap();
@@ -251,16 +252,16 @@ impl RbTokenizer {
         add_special_tokens: bool,
     ) -> RbResult<RbEncoding> {
         let sequence: tk::InputSequence = if is_pretokenized {
-
+            PreTokenizedInputSequence::try_convert(sequence)?.into()
         } else {
-
+            TextInputSequence::try_convert(sequence)?.into()
         };
         let input = match pair {
             Some(pair) => {
                 let pair: tk::InputSequence = if is_pretokenized {
-
+                    PreTokenizedInputSequence::try_convert(pair)?.into()
                 } else {
-
+                    TextInputSequence::try_convert(pair)?.into()
                 };
                 tk::EncodeInput::Dual(sequence, pair)
             }
@@ -284,9 +285,9 @@ impl RbTokenizer {
             .each()
             .map(|o| {
                 let input: tk::EncodeInput = if is_pretokenized {
-
+                    PreTokenizedEncodeInput::try_convert(o?)?.into()
                 } else {
-
+                    TextEncodeInput::try_convert(o?)?.into()
                 };
                 Ok(input)
             })
@@ -306,14 +307,15 @@ impl RbTokenizer {
     pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> RbResult<String> {
         self.tokenizer
             .borrow()
-            .decode(ids, skip_special_tokens)
+            .decode(&ids, skip_special_tokens)
             .map_err(RbError::from)
     }

     pub fn decode_batch(&self, sequences: Vec<Vec<u32>>, skip_special_tokens: bool) -> RbResult<Vec<String>> {
+        let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
         self.tokenizer
             .borrow()
-            .decode_batch(
+            .decode_batch(&slices, skip_special_tokens)
             .map_err(RbError::from)
     }

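The `decode`/`decode_batch` hunk above tracks an upstream signature change: the underlying `tk` methods now take borrowed id slices (`&ids`, and a `&[&[u32]]` view built from `slices`) rather than owned vectors, so the owned `Vec<Vec<u32>>` coming from Ruby is reborrowed first. A plain-Rust sketch of that reborrow, with made-up token ids and no tokenizers dependency:

```rust
// Reborrow owned id buffers as the slice-of-slices shape decode_batch expects.
fn as_slices(sequences: &[Vec<u32>]) -> Vec<&[u32]> {
    sequences.iter().map(|v| v.as_slice()).collect()
}

fn main() {
    // Hypothetical token ids; nothing here touches the tokenizers crate.
    let sequences: Vec<Vec<u32>> = vec![vec![101, 2023, 102], vec![101, 102]];
    let slices = as_slices(&sequences);
    // The borrowed view points into `sequences`; nothing is copied.
    assert_eq!(slices.len(), 2);
    assert_eq!(slices[0], [101, 2023, 102]);
    println!("prepared {} borrowed sequences", slices.len());
}
```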
@@ -353,7 +355,7 @@ impl RbTokenizer {

         let value: Value = kwargs.delete(Symbol::new("direction"))?;
         if !value.is_nil() {
-            let dir_str
+            let dir_str = String::try_convert(value)?;
             params.direction = match dir_str.as_str() {
                 "left" => PaddingDirection::Left,
                 "right" => PaddingDirection::Right,
@@ -363,29 +365,29 @@ impl RbTokenizer {

         let value: Value = kwargs.delete(Symbol::new("pad_to_multiple_of"))?;
         if !value.is_nil() {
-            params.pad_to_multiple_of =
+            params.pad_to_multiple_of = TryConvert::try_convert(value)?;
         }

         let value: Value = kwargs.delete(Symbol::new("pad_id"))?;
         if !value.is_nil() {
-            params.pad_id =
+            params.pad_id = TryConvert::try_convert(value)?;
         }

         let value: Value = kwargs.delete(Symbol::new("pad_type_id"))?;
         if !value.is_nil() {
-            params.pad_type_id =
+            params.pad_type_id = TryConvert::try_convert(value)?;
         }

         let value: Value = kwargs.delete(Symbol::new("pad_token"))?;
         if !value.is_nil() {
-            params.pad_token =
+            params.pad_token = TryConvert::try_convert(value)?;
         }

         let value: Value = kwargs.delete(Symbol::new("length"))?;
         if value.is_nil() {
             params.strategy = PaddingStrategy::BatchLongest;
         } else {
-            params.strategy = PaddingStrategy::Fixed(
+            params.strategy = PaddingStrategy::Fixed(TryConvert::try_convert(value)?);
         }

         if !kwargs.is_empty() {
@@ -431,12 +433,12 @@ impl RbTokenizer {

         let value: Value = kwargs.delete(Symbol::new("stride"))?;
         if !value.is_nil() {
-            params.stride =
+            params.stride = TryConvert::try_convert(value)?;
         }

         let value: Value = kwargs.delete(Symbol::new("strategy"))?;
         if !value.is_nil() {
-            let strategy_str
+            let strategy_str = String::try_convert(value)?;
             params.strategy = match strategy_str.as_str() {
                 "longest_first" => TruncationStrategy::LongestFirst,
                 "only_first" => TruncationStrategy::OnlyFirst,
@@ -447,7 +449,7 @@ impl RbTokenizer {

         let value: Value = kwargs.delete(Symbol::new("direction"))?;
         if !value.is_nil() {
-            let dir_str
+            let dir_str = String::try_convert(value)?;
             params.direction = match dir_str.as_str() {
                 "left" => TruncationDirection::Left,
                 "right" => TruncationDirection::Right,
@@ -460,13 +462,18 @@ impl RbTokenizer {
             return Err(Error::new(exception::arg_error(), "unknown keyword"));
         }

-        self.tokenizer.borrow_mut().with_truncation(Some(params))
+        if let Err(error_message) = self.tokenizer.borrow_mut().with_truncation(Some(params)) {
+            return Err(Error::new(exception::arg_error(), error_message.to_string()));
+        }

         Ok(())
     }

     pub fn no_truncation(&self) {
-        self.tokenizer
+        self.tokenizer
+            .borrow_mut()
+            .with_truncation(None)
+            .expect("Failed to set truncation to `None`! This should never happen");
     }

     pub fn truncation(&self) -> RbResult<Option<RHash>> {
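In the truncation hunk above, `with_truncation` has become fallible upstream, so its error is translated into a Ruby `ArgumentError` instead of being returned raw, and `no_truncation` asserts that clearing truncation cannot fail. A small sketch of that error-translation shape, assuming a magnus 0.6-level API; `set_limit` and `apply_limit` are hypothetical stand-ins, not gem code:

```rust
use magnus::{exception, Error};

// Hypothetical fallible setter standing in for the upstream with_truncation.
fn set_limit(limit: usize) -> Result<(), String> {
    if limit == 0 {
        return Err("limit must be positive".to_string());
    }
    Ok(())
}

// Translate the library error into a Ruby ArgumentError, mirroring the
// pattern used by the rewritten truncation handling.
fn apply_limit(limit: usize) -> Result<(), Error> {
    if let Err(message) = set_limit(limit) {
        return Err(Error::new(exception::arg_error(), message));
    }
    Ok(())
}
```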
data/ext/tokenizers/src/trainers.rs

@@ -3,16 +3,16 @@ use std::sync::{Arc, RwLock};

 use crate::models::RbModel;
 use crate::tokenizer::RbAddedToken;
-use magnus::
+use magnus::prelude::*;
 use magnus::{
-    exception, function,
-    RArray, RClass, RHash, RModule, Symbol, TypedData, Value,
+    data_type_builder, exception, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error, Module, Object,
+    RArray, RClass, RHash, RModule, Ruby, Symbol, TryConvert, TypedData, Value,
 };
 use serde::{Deserialize, Serialize};
 use tk::models::TrainerWrapper;
 use tk::Trainer;

-use super::RbResult;
+use super::{RbResult, TRAINERS};

 #[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
 pub struct RbTrainer {
@@ -112,7 +112,7 @@ impl RbTrainer {
                 special_tokens
                     .each()
                     .map(|token| {
-                        if let Ok(content) =
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -144,7 +144,7 @@ impl RbTrainer {
             self,
             BpeTrainer,
             initial_alphabet,
-            alphabet.into_iter().
+            alphabet.into_iter().collect()
         );
     }

@@ -199,7 +199,7 @@ impl RbTrainer {
                 special_tokens
                     .each()
                     .map(|token| {
-                        if let Ok(content) =
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -223,7 +223,7 @@ impl RbTrainer {
             self,
             UnigramTrainer,
             initial_alphabet,
-            alphabet.into_iter().
+            alphabet.into_iter().collect()
         );
     }

@@ -270,7 +270,7 @@ impl RbTrainer {
                 special_tokens
                     .each()
                     .map(|token| {
-                        if let Ok(content) =
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -324,7 +324,7 @@ impl RbTrainer {
                 special_tokens
                     .each()
                     .map(|token| {
-                        if let Ok(content) =
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -356,7 +356,7 @@ impl RbTrainer {
             self,
             WordPieceTrainer,
             @set_initial_alphabet,
-            alphabet.into_iter().
+            alphabet.into_iter().collect()
         );
     }

@@ -397,11 +397,10 @@ impl RbBpeTrainer {
         let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
         if !value.is_nil() {
             builder = builder.special_tokens(
-                value
-                    .try_convert::<RArray>()?
+                RArray::try_convert(value)?
                     .each()
                     .map(|token| {
-                        if let Ok(content) =
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -413,39 +412,39 @@ impl RbBpeTrainer {

         let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
         if !value.is_nil() {
-            let arr =
+            let arr = <Vec<char>>::try_convert(value)?;
             let set: HashSet<char> = HashSet::from_iter(arr);
             builder = builder.initial_alphabet(set);
         }

         let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
         if !value.is_nil() {
-            builder = builder.vocab_size(
+            builder = builder.vocab_size(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
         if !value.is_nil() {
-            builder = builder.min_frequency(
+            builder = builder.min_frequency(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
         if !value.is_nil() {
-            builder = builder.show_progress(
+            builder = builder.show_progress(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
         if !value.is_nil() {
-            builder = builder.limit_alphabet(
+            builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
         if !value.is_nil() {
-            builder = builder.continuing_subword_prefix(
+            builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
         if !value.is_nil() {
-            builder = builder.end_of_word_suffix(
+            builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
         }

         if !kwargs.is_empty() {
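Every keyword option in these trainer constructors follows the same shape: delete the key from the kwargs hash, treat nil as "not given", and otherwise convert it with the trait-style call. A standalone sketch of that shape, assuming a magnus 0.6-level API; `optional_vocab_size` is illustrative, not part of the gem:

```rust
use magnus::{prelude::*, Error, RHash, Symbol, TryConvert, Value};

// Read an optional keyword argument from a kwargs hash.
fn optional_vocab_size(kwargs: RHash) -> Result<Option<u32>, Error> {
    // delete() removes the key so leftover (unknown) keywords can be detected later.
    let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
    if value.is_nil() {
        Ok(None)
    } else {
        Ok(Some(u32::try_convert(value)?))
    }
}
```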
@@ -466,11 +465,10 @@ impl RbUnigramTrainer {
         let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
         if !value.is_nil() {
             builder.special_tokens(
-                value
-                    .try_convert::<RArray>()?
+                RArray::try_convert(value)?
                     .each()
                     .map(|token| {
-                        if let Ok(content) =
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -482,44 +480,44 @@ impl RbUnigramTrainer {

         let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
         if !value.is_nil() {
-            let arr =
+            let arr = <Vec<char>>::try_convert(value)?;
             let set: HashSet<char> = HashSet::from_iter(arr);
             builder.initial_alphabet(set);
         }

         let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
         if !value.is_nil() {
-            builder.vocab_size(
+            builder.vocab_size(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
         if !value.is_nil() {
-            builder.show_progress(
+            builder.show_progress(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("n_sub_iterations"))?;
         if !value.is_nil() {
-            builder.n_sub_iterations(
+            builder.n_sub_iterations(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("unk_token"))?;
         if !value.is_nil() {
-            builder.unk_token(Some(
+            builder.unk_token(Some(TryConvert::try_convert(value)?));
         }

         let value: Value = kwargs.delete(Symbol::new("max_piece_length"))?;
         if !value.is_nil() {
-            builder.max_piece_length(
+            builder.max_piece_length(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("seed_size"))?;
         if !value.is_nil() {
-            builder.seed_size(
+            builder.seed_size(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("shrinking_factor"))?;
         if !value.is_nil() {
-            builder.shrinking_factor(
+            builder.shrinking_factor(TryConvert::try_convert(value)?);
         }

         if !kwargs.is_empty() {
@@ -541,11 +539,10 @@ impl RbWordLevelTrainer {
         let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
         if !value.is_nil() {
             builder.special_tokens(
-                value
-                    .try_convert::<RArray>()?
+                RArray::try_convert(value)?
                     .each()
                     .map(|token| {
-                        if let Ok(content) =
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -557,17 +554,17 @@ impl RbWordLevelTrainer {

         let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
         if !value.is_nil() {
-            builder.vocab_size(
+            builder.vocab_size(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
         if !value.is_nil() {
-            builder.min_frequency(
+            builder.min_frequency(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
         if !value.is_nil() {
-            builder.show_progress(
+            builder.show_progress(TryConvert::try_convert(value)?);
         }

         Ok(builder.build().expect("WordLevelTrainerBuilder cannot fail").into())
@@ -583,11 +580,10 @@ impl RbWordPieceTrainer {
         let value: Value = kwargs.delete(Symbol::new("special_tokens"))?;
         if !value.is_nil() {
             builder = builder.special_tokens(
-                value
-                    .try_convert::<RArray>()?
+                RArray::try_convert(value)?
                     .each()
                     .map(|token| {
-                        if let Ok(content) =
+                        if let Ok(content) = String::try_convert(token?) {
                             Ok(RbAddedToken::from(content, Some(true)).get_token())
                         } else {
                             todo!()
@@ -599,39 +595,39 @@ impl RbWordPieceTrainer {

         let value: Value = kwargs.delete(Symbol::new("initial_alphabet"))?;
         if !value.is_nil() {
-            let arr =
+            let arr = <Vec<char>>::try_convert(value)?;
             let set: HashSet<char> = HashSet::from_iter(arr);
             builder = builder.initial_alphabet(set);
         }

         let value: Value = kwargs.delete(Symbol::new("vocab_size"))?;
         if !value.is_nil() {
-            builder = builder.vocab_size(
+            builder = builder.vocab_size(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("min_frequency"))?;
         if !value.is_nil() {
-            builder = builder.min_frequency(
+            builder = builder.min_frequency(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("show_progress"))?;
         if !value.is_nil() {
-            builder = builder.show_progress(
+            builder = builder.show_progress(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("limit_alphabet"))?;
         if !value.is_nil() {
-            builder = builder.limit_alphabet(
+            builder = builder.limit_alphabet(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("continuing_subword_prefix"))?;
         if !value.is_nil() {
-            builder = builder.continuing_subword_prefix(
+            builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
         }

         let value: Value = kwargs.delete(Symbol::new("end_of_word_suffix"))?;
         if !value.is_nil() {
-            builder = builder.end_of_word_suffix(
+            builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
         }

         if !kwargs.is_empty() {
@@ -644,46 +640,52 @@ impl RbWordPieceTrainer {
 }

 unsafe impl TypedData for RbTrainer {
-    fn class() -> RClass {
-
-
-
-
-        })
+    fn class(ruby: &Ruby) -> RClass {
+        static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("Trainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
+        ruby.get_inner(&CLASS)
     }

     fn data_type() -> &'static DataType {
-
-
-
-
+        static DATA_TYPE: DataType = data_type_builder!(RbTrainer, "Tokenizers::Trainers::Trainer").build();
+        &DATA_TYPE
+    }
+
+    fn class_for(ruby: &Ruby, value: &Self) -> RClass {
+        static BPE_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("BpeTrainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
+        static UNIGRAM_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("UnigramTrainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
+        static WORD_LEVEL_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("WordLevelTrainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
+        static WORD_PIECE_TRAINER: Lazy<RClass> = Lazy::new(|ruby| {
+            let class: RClass = ruby.get_inner(&TRAINERS).const_get("WordPieceTrainer").unwrap();
+            class.undef_default_alloc_func();
+            class
+        });
         match *value.trainer.read().unwrap() {
-            TrainerWrapper::BpeTrainer(_) =>
-
-
-
-            }),
-            TrainerWrapper::UnigramTrainer(_) => *memoize!(RClass: {
-                let class: RClass = crate::trainers().const_get("UnigramTrainer").unwrap();
-                class.undef_alloc_func();
-                class
-            }),
-            TrainerWrapper::WordLevelTrainer(_) => *memoize!(RClass: {
-                let class: RClass = crate::trainers().const_get("WordLevelTrainer").unwrap();
-                class.undef_alloc_func();
-                class
-            }),
-            TrainerWrapper::WordPieceTrainer(_) => *memoize!(RClass: {
-                let class: RClass = crate::trainers().const_get("WordPieceTrainer").unwrap();
-                class.undef_alloc_func();
-                class
-            }),
+            TrainerWrapper::BpeTrainer(_) => ruby.get_inner(&BPE_TRAINER),
+            TrainerWrapper::UnigramTrainer(_) => ruby.get_inner(&UNIGRAM_TRAINER),
+            TrainerWrapper::WordLevelTrainer(_) => ruby.get_inner(&WORD_LEVEL_TRAINER),
+            TrainerWrapper::WordPieceTrainer(_) => ruby.get_inner(&WORD_PIECE_TRAINER),
         }
     }
 }

-pub fn
-    let trainer = module.define_class("Trainer",
+pub fn init_trainers(ruby: &Ruby, module: &RModule) -> RbResult<()> {
+    let trainer = module.define_class("Trainer", ruby.class_object())?;

     let class = module.define_class("BpeTrainer", trainer)?;
     class.define_singleton_method("_new", function!(RbBpeTrainer::new, 1))?;
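The rewritten `TypedData` implementation above replaces the old `memoize!`-based class lookups with `magnus::value::Lazy` statics resolved through a `&Ruby` handle: the constant is fetched once, on first access, and the cached `RClass` is reused afterwards. A stripped-down sketch of that pattern, assuming a magnus 0.6-level API; the `Widgets`/`Widget` names are invented for illustration and stand in for the gem's `TRAINERS` module and trainer classes:

```rust
use magnus::{prelude::*, value::Lazy, RClass, RModule, Ruby};

// Hypothetical cached module handle, playing the role of TRAINERS in the gem.
static WIDGETS: Lazy<RModule> = Lazy::new(|ruby| ruby.define_module("Widgets").unwrap());

// The class constant is looked up lazily and cached for later calls.
static WIDGET: Lazy<RClass> = Lazy::new(|ruby| {
    let class: RClass = ruby.get_inner(&WIDGETS).const_get("Widget").unwrap();
    class
});

// Shaped like what the new TypedData::class implementation returns.
fn widget_class(ruby: &Ruby) -> RClass {
    ruby.get_inner(&WIDGET)
}
```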
data/ext/tokenizers/src/utils/normalization.rs

@@ -1,5 +1,6 @@
 use super::regex::{regex, RbRegex};
 use crate::RbResult;
+use magnus::prelude::*;
 use magnus::{exception, Error, TryConvert, Value};
 use tk::normalizer::SplitDelimiterBehavior;
 use tk::pattern::Pattern;
@@ -13,9 +14,9 @@ pub enum RbPattern<'p> {
 impl TryConvert for RbPattern<'_> {
     fn try_convert(obj: Value) -> RbResult<Self> {
         if obj.is_kind_of(regex()) {
-            Ok(RbPattern::Regex(
+            Ok(RbPattern::Regex(TryConvert::try_convert(obj)?))
         } else {
-            Ok(RbPattern::Str(
+            Ok(RbPattern::Str(TryConvert::try_convert(obj)?))
         }
     }
 }
@@ -61,7 +62,7 @@ pub struct RbSplitDelimiterBehavior(pub SplitDelimiterBehavior);

 impl TryConvert for RbSplitDelimiterBehavior {
     fn try_convert(obj: Value) -> RbResult<Self> {
-        let s =
+        let s = String::try_convert(obj)?;

         Ok(Self(match s.as_str() {
             "removed" => Ok(SplitDelimiterBehavior::Removed),

data/ext/tokenizers/src/utils/regex.rs

@@ -1,6 +1,6 @@
 use onig::Regex;
-use magnus::{exception,
-use crate::{
+use magnus::{exception, prelude::*, value::Lazy, Error, RClass, Ruby};
+use crate::{RbResult, TOKENIZERS};

 #[magnus::wrap(class = "Tokenizers::Regex")]
 pub struct RbRegex {
@@ -17,6 +17,8 @@ impl RbRegex {
     }
 }

+static REGEX: Lazy<RClass> = Lazy::new(|ruby| ruby.get_inner(&TOKENIZERS).const_get("Regex").unwrap());
+
 pub fn regex() -> RClass {
-
+    Ruby::get().unwrap().get_inner(&REGEX)
 }