tokenizers 0.2.2 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/Cargo.lock +33 -74
- data/README.md +4 -0
- data/ext/tokenizers/Cargo.toml +4 -2
- data/ext/tokenizers/src/decoders.rs +275 -6
- data/ext/tokenizers/src/encoding.rs +78 -3
- data/ext/tokenizers/src/error.rs +2 -2
- data/ext/tokenizers/src/lib.rs +88 -17
- data/ext/tokenizers/src/models.rs +372 -11
- data/ext/tokenizers/src/normalizers.rs +435 -7
- data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
- data/ext/tokenizers/src/processors.rs +210 -0
- data/ext/tokenizers/src/tokenizer.rs +448 -20
- data/ext/tokenizers/src/trainers.rs +749 -0
- data/ext/tokenizers/src/utils/mod.rs +5 -0
- data/ext/tokenizers/src/utils/normalization.rs +85 -0
- data/ext/tokenizers/src/utils/regex.rs +22 -0
- data/lib/tokenizers/char_bpe_tokenizer.rb +11 -8
- data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
- data/lib/tokenizers/decoders/ctc.rb +9 -0
- data/lib/tokenizers/decoders/metaspace.rb +9 -0
- data/lib/tokenizers/decoders/word_piece.rb +9 -0
- data/lib/tokenizers/encoding.rb +19 -0
- data/lib/tokenizers/from_pretrained.rb +1 -1
- data/lib/tokenizers/models/bpe.rb +9 -0
- data/lib/tokenizers/models/unigram.rb +9 -0
- data/lib/tokenizers/models/word_level.rb +13 -0
- data/lib/tokenizers/models/word_piece.rb +9 -0
- data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
- data/lib/tokenizers/normalizers/strip.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
- data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
- data/lib/tokenizers/processors/byte_level.rb +9 -0
- data/lib/tokenizers/processors/roberta_processing.rb +9 -0
- data/lib/tokenizers/processors/template_processing.rb +9 -0
- data/lib/tokenizers/tokenizer.rb +45 -0
- data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
- data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
- data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
- data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
- data/lib/tokenizers/version.rb +1 -1
- data/lib/tokenizers.rb +49 -7
- metadata +32 -3
@@ -1,14 +1,478 @@
|
|
1
|
+
use std::sync::{Arc, RwLock};
|
2
|
+
|
3
|
+
use magnus::typed_data::DataTypeBuilder;
|
4
|
+
use magnus::{
|
5
|
+
function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object,
|
6
|
+
RArray, RClass, RModule, TypedData,
|
7
|
+
};
|
8
|
+
|
9
|
+
use serde::ser::SerializeStruct;
|
10
|
+
use serde::{Deserialize, Serialize, Serializer};
|
11
|
+
|
1
12
|
use tk::pre_tokenizers::bert::BertPreTokenizer;
|
13
|
+
use tk::pre_tokenizers::byte_level::ByteLevel;
|
14
|
+
use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
|
15
|
+
use tk::pre_tokenizers::digits::Digits;
|
16
|
+
use tk::pre_tokenizers::metaspace::Metaspace;
|
17
|
+
use tk::pre_tokenizers::punctuation::Punctuation;
|
18
|
+
use tk::pre_tokenizers::split::Split;
|
19
|
+
use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
|
20
|
+
use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
|
21
|
+
use tk::pre_tokenizers::PreTokenizerWrapper;
|
22
|
+
use tk::tokenizer::Offsets;
|
23
|
+
use tk::{PreTokenizedString, PreTokenizer};
|
24
|
+
|
25
|
+
use super::utils::*;
|
26
|
+
use super::{RbError, RbResult};
|
27
|
+
|
28
|
+
#[derive(DataTypeFunctions, Clone, Serialize, Deserialize)]
|
29
|
+
pub struct RbPreTokenizer {
|
30
|
+
#[serde(flatten)]
|
31
|
+
pub(crate) pretok: RbPreTokenizerTypeWrapper,
|
32
|
+
}
|
33
|
+
|
34
|
+
impl RbPreTokenizer {
|
35
|
+
fn pre_tokenize_str(&self, s: String) -> RbResult<Vec<(String, Offsets)>> {
|
36
|
+
let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
|
37
|
+
|
38
|
+
self.pretok.pre_tokenize(&mut pretokenized).map_err(RbError::from)?;
|
39
|
+
|
40
|
+
Ok(pretokenized
|
41
|
+
.get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
|
42
|
+
.into_iter()
|
43
|
+
.map(|(s, o, _)| (s.to_owned(), o))
|
44
|
+
.collect())
|
45
|
+
}
|
46
|
+
}
|
2
47
|
|
3
|
-
|
4
|
-
|
5
|
-
|
48
|
+
/// Reads `$name` (a field or method-call chain) from the `$variant` pre-tokenizer
/// wrapped by `$self`, holding the read lock only for the duration of the access.
///
/// Panics via `unreachable!()` if `$self` is a sequence or wraps another variant.
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        match &$self.pretok {
            RbPreTokenizerTypeWrapper::Single(single) => match *single.read().unwrap() {
                RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref pretok)) => {
                    pretok.$($name)+
                }
                _ => unreachable!(),
            },
            RbPreTokenizerTypeWrapper::Sequence(_) => unreachable!(),
        }
    }};
}
|
7
62
|
|
63
|
+
/// Mutates the `$variant` pre-tokenizer wrapped by `$self` under the write lock.
///
/// - `setter!(self, Variant, field, value)` assigns `field = value`.
/// - `setter!(self, Variant, @method, value)` calls `method(value)`.
///
/// Does nothing when `$self` is a sequence or wraps a different variant.
macro_rules! setter {
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        if let RbPreTokenizerTypeWrapper::Single(single) = &$self.pretok {
            let mut guard = single.write().unwrap();
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(pretok)) =
                &mut *guard
            {
                pretok.$name = $value;
            }
        }
    }};
    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
        if let RbPreTokenizerTypeWrapper::Single(single) = &$self.pretok {
            let mut guard = single.write().unwrap();
            if let RbPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(pretok)) =
                &mut *guard
            {
                pretok.$name($value);
            }
        }
    }};
}
|
83
|
+
|
84
|
+
impl RbPreTokenizer {
|
85
|
+
#[allow(dead_code)]
|
86
|
+
pub(crate) fn new(pretok: RbPreTokenizerTypeWrapper) -> Self {
|
87
|
+
RbPreTokenizer { pretok }
|
88
|
+
}
|
89
|
+
|
90
|
+
fn byte_level_add_prefix_space(&self) -> bool {
|
91
|
+
getter!(self, ByteLevel, add_prefix_space)
|
92
|
+
}
|
93
|
+
|
94
|
+
fn byte_level_set_add_prefix_space(&self, add_prefix_space: bool) {
|
95
|
+
setter!(self, ByteLevel, add_prefix_space, add_prefix_space);
|
96
|
+
}
|
97
|
+
|
98
|
+
fn byte_level_use_regex(&self) -> bool {
|
99
|
+
getter!(self, ByteLevel, use_regex)
|
100
|
+
}
|
101
|
+
|
102
|
+
fn byte_level_set_use_regex(&self, use_regex: bool) {
|
103
|
+
setter!(self, ByteLevel, use_regex, use_regex);
|
104
|
+
}
|
105
|
+
|
106
|
+
fn char_delimiter_split_delimiter(&self) -> String {
|
107
|
+
getter!(self, Delimiter, delimiter.to_string())
|
108
|
+
}
|
109
|
+
|
110
|
+
fn char_delimiter_split_set_delimiter(&self, delimiter: char) {
|
111
|
+
setter!(self, Delimiter, delimiter, delimiter);
|
112
|
+
}
|
113
|
+
|
114
|
+
fn digits_individual_digits(&self) -> bool {
|
115
|
+
getter!(self, Digits, individual_digits)
|
116
|
+
}
|
117
|
+
|
118
|
+
fn digits_set_individual_digits(&self, individual_digits: bool) {
|
119
|
+
setter!(self, Digits, individual_digits, individual_digits);
|
120
|
+
}
|
121
|
+
|
122
|
+
fn metaspace_add_prefix_space(&self) -> bool {
|
123
|
+
getter!(self, Metaspace, add_prefix_space)
|
124
|
+
}
|
125
|
+
|
126
|
+
fn metaspace_set_add_prefix_space(&self, add_prefix_space: bool) {
|
127
|
+
setter!(self, Metaspace, add_prefix_space, add_prefix_space);
|
128
|
+
}
|
129
|
+
|
130
|
+
fn metaspace_replacement(&self) -> String {
|
131
|
+
getter!(self, Metaspace, get_replacement().to_string())
|
132
|
+
}
|
133
|
+
|
134
|
+
fn metaspace_set_replacement(&self, replacement: char) {
|
135
|
+
setter!(self, Metaspace, @set_replacement, replacement);
|
136
|
+
}
|
137
|
+
}
|
138
|
+
|
139
|
+
impl PreTokenizer for RbPreTokenizer {
|
140
|
+
fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
|
141
|
+
self.pretok.pre_tokenize(normalized)
|
142
|
+
}
|
143
|
+
}
|
144
|
+
|
145
|
+
pub struct RbByteLevel {}
|
146
|
+
|
147
|
+
impl RbByteLevel {
|
148
|
+
pub fn new(add_prefix_space: bool, use_regex: bool) -> RbPreTokenizer {
|
149
|
+
ByteLevel::default()
|
150
|
+
.add_prefix_space(add_prefix_space)
|
151
|
+
.use_regex(use_regex)
|
152
|
+
.into()
|
153
|
+
}
|
154
|
+
|
155
|
+
fn alphabet() -> Vec<String> {
|
156
|
+
ByteLevel::alphabet()
|
157
|
+
.into_iter()
|
158
|
+
.map(|c| c.to_string())
|
159
|
+
.collect()
|
160
|
+
}
|
161
|
+
}
|
162
|
+
|
163
|
+
pub struct RbCharDelimiterSplit {}
|
164
|
+
|
165
|
+
impl RbCharDelimiterSplit {
|
166
|
+
pub fn new(delimiter: char) -> RbPreTokenizer {
|
167
|
+
CharDelimiterSplit::new(delimiter).into()
|
168
|
+
}
|
169
|
+
}
|
170
|
+
|
171
|
+
pub struct RbDigits {}
|
172
|
+
|
173
|
+
impl RbDigits {
|
174
|
+
fn new(individual_digits: bool) -> RbPreTokenizer {
|
175
|
+
Digits::new(individual_digits).into()
|
176
|
+
}
|
177
|
+
}
|
178
|
+
|
179
|
+
pub struct RbMetaspace {}
|
180
|
+
|
181
|
+
impl RbMetaspace {
|
182
|
+
fn new(
|
183
|
+
replacement: char,
|
184
|
+
add_prefix_space: bool,
|
185
|
+
) -> RbPreTokenizer {
|
186
|
+
Metaspace::new(replacement, add_prefix_space).into()
|
187
|
+
}
|
188
|
+
}
|
189
|
+
|
190
|
+
pub struct RbPunctuation {}
|
191
|
+
|
192
|
+
impl RbPunctuation {
|
193
|
+
pub fn new(behavior: RbSplitDelimiterBehavior) -> RbResult<RbPreTokenizer> {
|
194
|
+
Ok(Punctuation::new(behavior.into()).into())
|
195
|
+
}
|
196
|
+
}
|
197
|
+
|
198
|
+
pub struct RbSplit {}
|
199
|
+
|
200
|
+
impl RbSplit {
|
201
|
+
pub fn new(pattern: RbPattern, behavior: RbSplitDelimiterBehavior, invert: bool) -> RbResult<RbPreTokenizer> {
|
202
|
+
Split::new(pattern, behavior.into(), invert).map(|v| v.into()).map_err(RbError::from)
|
203
|
+
}
|
204
|
+
}
|
205
|
+
|
206
|
+
pub struct RbUnicodeScripts {}
|
207
|
+
|
208
|
+
impl RbUnicodeScripts {
|
209
|
+
pub fn new() -> RbPreTokenizer {
|
210
|
+
UnicodeScripts::new().into()
|
211
|
+
}
|
212
|
+
}
|
213
|
+
|
214
|
+
pub struct RbWhitespace {}
|
215
|
+
|
216
|
+
impl RbWhitespace {
|
217
|
+
pub fn new() -> RbPreTokenizer {
|
218
|
+
Whitespace::default().into()
|
219
|
+
}
|
220
|
+
}
|
221
|
+
|
222
|
+
pub struct RbWhitespaceSplit {}
|
223
|
+
|
224
|
+
impl RbWhitespaceSplit {
|
225
|
+
pub fn new() -> RbPreTokenizer {
|
226
|
+
WhitespaceSplit.into()
|
227
|
+
}
|
228
|
+
}
|
229
|
+
|
230
|
+
pub struct RbBertPreTokenizer {}
|
231
|
+
|
8
232
|
impl RbBertPreTokenizer {
|
9
|
-
pub fn new() ->
|
10
|
-
|
11
|
-
|
233
|
+
pub fn new() -> RbPreTokenizer {
|
234
|
+
BertPreTokenizer.into()
|
235
|
+
}
|
236
|
+
}
|
237
|
+
|
238
|
+
pub struct RbSequence {}
|
239
|
+
|
240
|
+
impl RbSequence {
|
241
|
+
fn new(pre_tokenizers: RArray) -> RbResult<RbPreTokenizer> {
|
242
|
+
let mut sequence = Vec::with_capacity(pre_tokenizers.len());
|
243
|
+
for n in pre_tokenizers.each() {
|
244
|
+
let pretokenizer: &RbPreTokenizer = n?.try_convert()?;
|
245
|
+
match &pretokenizer.pretok {
|
246
|
+
RbPreTokenizerTypeWrapper::Sequence(inner) => {
|
247
|
+
sequence.extend(inner.iter().cloned())
|
248
|
+
}
|
249
|
+
RbPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
|
250
|
+
}
|
251
|
+
}
|
252
|
+
Ok(RbPreTokenizer::new(RbPreTokenizerTypeWrapper::Sequence(sequence)))
|
253
|
+
}
|
254
|
+
}
|
255
|
+
|
256
|
+
#[derive(Clone, Deserialize)]
|
257
|
+
#[serde(untagged)]
|
258
|
+
pub(crate) enum RbPreTokenizerWrapper {
|
259
|
+
// Custom(CustomPreTokenizer),
|
260
|
+
Wrapped(PreTokenizerWrapper),
|
261
|
+
}
|
262
|
+
|
263
|
+
impl Serialize for RbPreTokenizerWrapper {
|
264
|
+
fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
|
265
|
+
where
|
266
|
+
S: Serializer,
|
267
|
+
{
|
268
|
+
match self {
|
269
|
+
RbPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
|
270
|
+
// RbPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
|
271
|
+
}
|
272
|
+
}
|
273
|
+
}
|
274
|
+
|
275
|
+
#[derive(Clone, Deserialize)]
|
276
|
+
#[serde(untagged)]
|
277
|
+
pub(crate) enum RbPreTokenizerTypeWrapper {
|
278
|
+
Sequence(Vec<Arc<RwLock<RbPreTokenizerWrapper>>>),
|
279
|
+
Single(Arc<RwLock<RbPreTokenizerWrapper>>),
|
280
|
+
}
|
281
|
+
|
282
|
+
impl Serialize for RbPreTokenizerTypeWrapper {
|
283
|
+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
284
|
+
where
|
285
|
+
S: Serializer,
|
286
|
+
{
|
287
|
+
match self {
|
288
|
+
RbPreTokenizerTypeWrapper::Sequence(seq) => {
|
289
|
+
let mut ser = serializer.serialize_struct("Sequence", 2)?;
|
290
|
+
ser.serialize_field("type", "Sequence")?;
|
291
|
+
ser.serialize_field("pretokenizers", seq)?;
|
292
|
+
ser.end()
|
293
|
+
}
|
294
|
+
RbPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer),
|
295
|
+
}
|
296
|
+
}
|
297
|
+
}
|
298
|
+
|
299
|
+
impl<I> From<I> for RbPreTokenizerWrapper
|
300
|
+
where
|
301
|
+
I: Into<PreTokenizerWrapper>,
|
302
|
+
{
|
303
|
+
fn from(pretok: I) -> Self {
|
304
|
+
RbPreTokenizerWrapper::Wrapped(pretok.into())
|
305
|
+
}
|
306
|
+
}
|
307
|
+
|
308
|
+
impl<I> From<I> for RbPreTokenizerTypeWrapper
|
309
|
+
where
|
310
|
+
I: Into<RbPreTokenizerWrapper>,
|
311
|
+
{
|
312
|
+
fn from(pretok: I) -> Self {
|
313
|
+
RbPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into())))
|
314
|
+
}
|
315
|
+
}
|
316
|
+
|
317
|
+
impl<I> From<I> for RbPreTokenizer
|
318
|
+
where
|
319
|
+
I: Into<PreTokenizerWrapper>,
|
320
|
+
{
|
321
|
+
fn from(pretok: I) -> Self {
|
322
|
+
RbPreTokenizer {
|
323
|
+
pretok: pretok.into().into(),
|
324
|
+
}
|
325
|
+
}
|
326
|
+
}
|
327
|
+
|
328
|
+
impl PreTokenizer for RbPreTokenizerTypeWrapper {
|
329
|
+
fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
|
330
|
+
match self {
|
331
|
+
RbPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok),
|
332
|
+
RbPreTokenizerTypeWrapper::Sequence(inner) => inner
|
333
|
+
.iter()
|
334
|
+
.try_for_each(|n| n.read().unwrap().pre_tokenize(pretok)),
|
12
335
|
}
|
13
336
|
}
|
14
337
|
}
|
338
|
+
|
339
|
+
impl PreTokenizer for RbPreTokenizerWrapper {
|
340
|
+
fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
|
341
|
+
match self {
|
342
|
+
RbPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok),
|
343
|
+
// RbPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok),
|
344
|
+
}
|
345
|
+
}
|
346
|
+
}
|
347
|
+
|
348
|
+
unsafe impl TypedData for RbPreTokenizer {
|
349
|
+
fn class() -> RClass {
|
350
|
+
*memoize!(RClass: {
|
351
|
+
let class: RClass = crate::pre_tokenizers().const_get("PreTokenizer").unwrap();
|
352
|
+
class.undef_alloc_func();
|
353
|
+
class
|
354
|
+
})
|
355
|
+
}
|
356
|
+
|
357
|
+
fn data_type() -> &'static DataType {
|
358
|
+
memoize!(DataType: DataTypeBuilder::<RbPreTokenizer>::new("Tokenizers::PreTokenizers::PreTokenizer").build())
|
359
|
+
}
|
360
|
+
|
361
|
+
fn class_for(value: &Self) -> RClass {
|
362
|
+
match &value.pretok {
|
363
|
+
RbPreTokenizerTypeWrapper::Sequence(_seq) => *memoize!(RClass: {
|
364
|
+
let class: RClass = crate::pre_tokenizers().const_get("Sequence").unwrap();
|
365
|
+
class.undef_alloc_func();
|
366
|
+
class
|
367
|
+
}),
|
368
|
+
RbPreTokenizerTypeWrapper::Single(inner) => match &*inner.read().unwrap() {
|
369
|
+
RbPreTokenizerWrapper::Wrapped(wrapped) => match &wrapped {
|
370
|
+
PreTokenizerWrapper::BertPreTokenizer(_) => *memoize!(RClass: {
|
371
|
+
let class: RClass = crate::pre_tokenizers().const_get("BertPreTokenizer").unwrap();
|
372
|
+
class.undef_alloc_func();
|
373
|
+
class
|
374
|
+
}),
|
375
|
+
PreTokenizerWrapper::ByteLevel(_) => *memoize!(RClass: {
|
376
|
+
let class: RClass = crate::pre_tokenizers().const_get("ByteLevel").unwrap();
|
377
|
+
class.undef_alloc_func();
|
378
|
+
class
|
379
|
+
}),
|
380
|
+
PreTokenizerWrapper::Delimiter(_) => *memoize!(RClass: {
|
381
|
+
let class: RClass = crate::pre_tokenizers().const_get("CharDelimiterSplit").unwrap();
|
382
|
+
class.undef_alloc_func();
|
383
|
+
class
|
384
|
+
}),
|
385
|
+
PreTokenizerWrapper::Digits(_) => *memoize!(RClass: {
|
386
|
+
let class: RClass = crate::pre_tokenizers().const_get("Digits").unwrap();
|
387
|
+
class.undef_alloc_func();
|
388
|
+
class
|
389
|
+
}),
|
390
|
+
PreTokenizerWrapper::Metaspace(_) => *memoize!(RClass: {
|
391
|
+
let class: RClass = crate::pre_tokenizers().const_get("Metaspace").unwrap();
|
392
|
+
class.undef_alloc_func();
|
393
|
+
class
|
394
|
+
}),
|
395
|
+
PreTokenizerWrapper::Punctuation(_) => *memoize!(RClass: {
|
396
|
+
let class: RClass = crate::pre_tokenizers().const_get("Punctuation").unwrap();
|
397
|
+
class.undef_alloc_func();
|
398
|
+
class
|
399
|
+
}),
|
400
|
+
PreTokenizerWrapper::Split(_) => *memoize!(RClass: {
|
401
|
+
let class: RClass = crate::pre_tokenizers().const_get("Split").unwrap();
|
402
|
+
class.undef_alloc_func();
|
403
|
+
class
|
404
|
+
}),
|
405
|
+
PreTokenizerWrapper::UnicodeScripts(_) => *memoize!(RClass: {
|
406
|
+
let class: RClass = crate::pre_tokenizers().const_get("UnicodeScripts").unwrap();
|
407
|
+
class.undef_alloc_func();
|
408
|
+
class
|
409
|
+
}),
|
410
|
+
PreTokenizerWrapper::Whitespace(_) => *memoize!(RClass: {
|
411
|
+
let class: RClass = crate::pre_tokenizers().const_get("Whitespace").unwrap();
|
412
|
+
class.undef_alloc_func();
|
413
|
+
class
|
414
|
+
}),
|
415
|
+
PreTokenizerWrapper::WhitespaceSplit(_) => *memoize!(RClass: {
|
416
|
+
let class: RClass = crate::pre_tokenizers().const_get("WhitespaceSplit").unwrap();
|
417
|
+
class.undef_alloc_func();
|
418
|
+
class
|
419
|
+
}),
|
420
|
+
_ => todo!(),
|
421
|
+
},
|
422
|
+
},
|
423
|
+
}
|
424
|
+
}
|
425
|
+
}
|
426
|
+
|
427
|
+
pub fn pre_tokenizers(module: &RModule) -> RbResult<()> {
|
428
|
+
let pre_tokenizer = module.define_class("PreTokenizer", Default::default())?;
|
429
|
+
pre_tokenizer.define_method("pre_tokenize_str", method!(RbPreTokenizer::pre_tokenize_str, 1))?;
|
430
|
+
|
431
|
+
let class = module.define_class("Sequence", pre_tokenizer)?;
|
432
|
+
class.define_singleton_method("new", function!(RbSequence::new, 1))?;
|
433
|
+
|
434
|
+
let class = module.define_class("BertPreTokenizer", pre_tokenizer)?;
|
435
|
+
class.define_singleton_method("new", function!(RbBertPreTokenizer::new, 0))?;
|
436
|
+
|
437
|
+
let class = module.define_class("ByteLevel", pre_tokenizer)?;
|
438
|
+
class.define_singleton_method("_new", function!(RbByteLevel::new, 2))?;
|
439
|
+
class.define_singleton_method("alphabet", function!(RbByteLevel::alphabet, 0))?;
|
440
|
+
class.define_method("add_prefix_space", method!(RbPreTokenizer::byte_level_add_prefix_space, 0))?;
|
441
|
+
class.define_method("add_prefix_space=", method!(RbPreTokenizer::byte_level_set_add_prefix_space, 1))?;
|
442
|
+
class.define_method("use_regex", method!(RbPreTokenizer::byte_level_use_regex, 0))?;
|
443
|
+
class.define_method("use_regex=", method!(RbPreTokenizer::byte_level_set_use_regex, 1))?;
|
444
|
+
|
445
|
+
let class = module.define_class("CharDelimiterSplit", pre_tokenizer)?;
|
446
|
+
class.define_singleton_method("new", function!(RbCharDelimiterSplit::new, 1))?;
|
447
|
+
class.define_method("delimiter", method!(RbPreTokenizer::char_delimiter_split_delimiter, 0))?;
|
448
|
+
class.define_method("delimiter=", method!(RbPreTokenizer::char_delimiter_split_set_delimiter, 1))?;
|
449
|
+
|
450
|
+
let class = module.define_class("Digits", pre_tokenizer)?;
|
451
|
+
class.define_singleton_method("_new", function!(RbDigits::new, 1))?;
|
452
|
+
class.define_method("individual_digits", method!(RbPreTokenizer::digits_individual_digits, 0))?;
|
453
|
+
class.define_method("individual_digits=", method!(RbPreTokenizer::digits_set_individual_digits, 1))?;
|
454
|
+
|
455
|
+
let class = module.define_class("Metaspace", pre_tokenizer)?;
|
456
|
+
class.define_singleton_method("_new", function!(RbMetaspace::new, 2))?;
|
457
|
+
class.define_method("add_prefix_space", method!(RbPreTokenizer::metaspace_add_prefix_space, 0))?;
|
458
|
+
class.define_method("add_prefix_space=", method!(RbPreTokenizer::metaspace_set_add_prefix_space, 1))?;
|
459
|
+
class.define_method("replacement", method!(RbPreTokenizer::metaspace_replacement, 0))?;
|
460
|
+
class.define_method("replacement=", method!(RbPreTokenizer::metaspace_set_replacement, 1))?;
|
461
|
+
|
462
|
+
let class = module.define_class("Punctuation", pre_tokenizer)?;
|
463
|
+
class.define_singleton_method("_new", function!(RbPunctuation::new, 1))?;
|
464
|
+
|
465
|
+
let class = module.define_class("Split", pre_tokenizer)?;
|
466
|
+
class.define_singleton_method("_new", function!(RbSplit::new, 3))?;
|
467
|
+
|
468
|
+
let class = module.define_class("UnicodeScripts", pre_tokenizer)?;
|
469
|
+
class.define_singleton_method("new", function!(RbUnicodeScripts::new, 0))?;
|
470
|
+
|
471
|
+
let class = module.define_class("Whitespace", pre_tokenizer)?;
|
472
|
+
class.define_singleton_method("new", function!(RbWhitespace::new, 0))?;
|
473
|
+
|
474
|
+
let class = module.define_class("WhitespaceSplit", pre_tokenizer)?;
|
475
|
+
class.define_singleton_method("new", function!(RbWhitespaceSplit::new, 0))?;
|
476
|
+
|
477
|
+
Ok(())
|
478
|
+
}
|
@@ -0,0 +1,210 @@
|
|
1
|
+
use std::sync::Arc;
|
2
|
+
|
3
|
+
use magnus::typed_data::DataTypeBuilder;
|
4
|
+
use magnus::{
|
5
|
+
function, memoize, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
|
6
|
+
TryConvert, TypedData, Value,
|
7
|
+
};
|
8
|
+
use serde::{Deserialize, Serialize};
|
9
|
+
use tk::processors::bert::BertProcessing;
|
10
|
+
use tk::processors::byte_level::ByteLevel;
|
11
|
+
use tk::processors::roberta::RobertaProcessing;
|
12
|
+
use tk::processors::template::{SpecialToken, Template};
|
13
|
+
use tk::processors::PostProcessorWrapper;
|
14
|
+
use tk::{Encoding, PostProcessor};
|
15
|
+
|
16
|
+
use super::RbResult;
|
17
|
+
|
18
|
+
#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
|
19
|
+
pub struct RbPostProcessor {
|
20
|
+
#[serde(flatten)]
|
21
|
+
pub processor: Arc<PostProcessorWrapper>,
|
22
|
+
}
|
23
|
+
|
24
|
+
impl RbPostProcessor {
|
25
|
+
pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
|
26
|
+
RbPostProcessor { processor }
|
27
|
+
}
|
28
|
+
}
|
29
|
+
|
30
|
+
impl PostProcessor for RbPostProcessor {
|
31
|
+
fn added_tokens(&self, is_pair: bool) -> usize {
|
32
|
+
self.processor.added_tokens(is_pair)
|
33
|
+
}
|
34
|
+
|
35
|
+
fn process_encodings(
|
36
|
+
&self,
|
37
|
+
encodings: Vec<Encoding>,
|
38
|
+
add_special_tokens: bool,
|
39
|
+
) -> tk::Result<Vec<Encoding>> {
|
40
|
+
self.processor
|
41
|
+
.process_encodings(encodings, add_special_tokens)
|
42
|
+
}
|
43
|
+
}
|
44
|
+
|
45
|
+
#[derive(Clone, Debug)]
|
46
|
+
pub struct RbSpecialToken(SpecialToken);
|
47
|
+
|
48
|
+
impl From<RbSpecialToken> for SpecialToken {
|
49
|
+
fn from(v: RbSpecialToken) -> Self {
|
50
|
+
v.0
|
51
|
+
}
|
52
|
+
}
|
53
|
+
|
54
|
+
impl TryConvert for RbSpecialToken {
|
55
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
56
|
+
if let Ok(v) = ob.try_convert::<(String, u32)>() {
|
57
|
+
Ok(Self(v.into()))
|
58
|
+
} else if let Ok(v) = ob.try_convert::<(u32, String)>() {
|
59
|
+
Ok(Self(v.into()))
|
60
|
+
} else {
|
61
|
+
todo!()
|
62
|
+
}
|
63
|
+
}
|
64
|
+
}
|
65
|
+
|
66
|
+
#[derive(Clone, Debug)]
|
67
|
+
pub struct RbTemplate(Template);
|
68
|
+
|
69
|
+
impl From<RbTemplate> for Template {
|
70
|
+
fn from(v: RbTemplate) -> Self {
|
71
|
+
v.0
|
72
|
+
}
|
73
|
+
}
|
74
|
+
|
75
|
+
impl TryConvert for RbTemplate {
|
76
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
77
|
+
if let Ok(s) = ob.try_convert::<String>() {
|
78
|
+
Ok(Self(
|
79
|
+
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
80
|
+
))
|
81
|
+
} else if let Ok(s) = ob.try_convert::<Vec<String>>() {
|
82
|
+
Ok(Self(
|
83
|
+
s.try_into().unwrap(), //.map_err(RbError::from)?,
|
84
|
+
))
|
85
|
+
} else {
|
86
|
+
todo!()
|
87
|
+
}
|
88
|
+
}
|
89
|
+
}
|
90
|
+
|
91
|
+
pub struct RbBertProcessing {}
|
92
|
+
|
93
|
+
impl RbBertProcessing {
|
94
|
+
pub fn new(sep: (String, u32), cls: (String, u32)) -> RbPostProcessor {
|
95
|
+
RbPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into()))
|
96
|
+
}
|
97
|
+
}
|
98
|
+
|
99
|
+
pub struct RbByteLevel {}
|
100
|
+
|
101
|
+
impl RbByteLevel {
|
102
|
+
pub fn new(trim_offsets: Option<bool>) -> RbPostProcessor {
|
103
|
+
let mut byte_level = ByteLevel::default();
|
104
|
+
|
105
|
+
if let Some(to) = trim_offsets {
|
106
|
+
byte_level = byte_level.trim_offsets(to);
|
107
|
+
}
|
108
|
+
RbPostProcessor::new(Arc::new(byte_level.into()))
|
109
|
+
}
|
110
|
+
|
111
|
+
}
|
112
|
+
|
113
|
+
pub struct RbRobertaProcessing {}
|
114
|
+
|
115
|
+
impl RbRobertaProcessing {
|
116
|
+
fn new(
|
117
|
+
sep: (String, u32),
|
118
|
+
cls: (String, u32),
|
119
|
+
trim_offsets: bool,
|
120
|
+
add_prefix_space: bool,
|
121
|
+
) -> RbPostProcessor {
|
122
|
+
let proc = RobertaProcessing::new(sep, cls)
|
123
|
+
.trim_offsets(trim_offsets)
|
124
|
+
.add_prefix_space(add_prefix_space);
|
125
|
+
RbPostProcessor::new(Arc::new(proc.into()))
|
126
|
+
}
|
127
|
+
}
|
128
|
+
|
129
|
+
pub struct RbTemplateProcessing {}
|
130
|
+
|
131
|
+
impl RbTemplateProcessing {
|
132
|
+
pub fn new(
|
133
|
+
single: Option<RbTemplate>,
|
134
|
+
pair: Option<RbTemplate>,
|
135
|
+
special_tokens: Option<Vec<(String, u32)>>,
|
136
|
+
) -> RbResult<RbPostProcessor> {
|
137
|
+
let mut builder = tk::processors::template::TemplateProcessing::builder();
|
138
|
+
|
139
|
+
if let Some(seq) = single {
|
140
|
+
builder.single(seq.into());
|
141
|
+
}
|
142
|
+
if let Some(seq) = pair {
|
143
|
+
builder.pair(seq.into());
|
144
|
+
}
|
145
|
+
if let Some(sp) = special_tokens {
|
146
|
+
builder.special_tokens(sp);
|
147
|
+
}
|
148
|
+
let processor = builder.build().unwrap(); //.map_err(RbError::from)?;
|
149
|
+
|
150
|
+
Ok(RbPostProcessor::new(Arc::new(processor.into())))
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
unsafe impl TypedData for RbPostProcessor {
|
155
|
+
fn class() -> RClass {
|
156
|
+
*memoize!(RClass: {
|
157
|
+
let class: RClass = crate::processors().const_get("PostProcessor").unwrap();
|
158
|
+
class.undef_alloc_func();
|
159
|
+
class
|
160
|
+
})
|
161
|
+
}
|
162
|
+
|
163
|
+
fn data_type() -> &'static DataType {
|
164
|
+
memoize!(DataType: DataTypeBuilder::<RbPostProcessor>::new("Tokenizers::Processors::PostProcessor").build())
|
165
|
+
}
|
166
|
+
|
167
|
+
fn class_for(value: &Self) -> RClass {
|
168
|
+
match *value.processor {
|
169
|
+
PostProcessorWrapper::Bert(_) => *memoize!(RClass: {
|
170
|
+
let class: RClass = crate::processors().const_get("BertProcessing").unwrap();
|
171
|
+
class.undef_alloc_func();
|
172
|
+
class
|
173
|
+
}),
|
174
|
+
PostProcessorWrapper::ByteLevel(_) => *memoize!(RClass: {
|
175
|
+
let class: RClass = crate::processors().const_get("ByteLevel").unwrap();
|
176
|
+
class.undef_alloc_func();
|
177
|
+
class
|
178
|
+
}),
|
179
|
+
PostProcessorWrapper::Roberta(_) => *memoize!(RClass: {
|
180
|
+
let class: RClass = crate::processors().const_get("RobertaProcessing").unwrap();
|
181
|
+
class.undef_alloc_func();
|
182
|
+
class
|
183
|
+
}),
|
184
|
+
PostProcessorWrapper::Template(_) => *memoize!(RClass: {
|
185
|
+
let class: RClass = crate::processors().const_get("TemplateProcessing").unwrap();
|
186
|
+
class.undef_alloc_func();
|
187
|
+
class
|
188
|
+
}),
|
189
|
+
_ => todo!(),
|
190
|
+
}
|
191
|
+
}
|
192
|
+
}
|
193
|
+
|
194
|
+
pub fn processors(module: &RModule) -> RbResult<()> {
|
195
|
+
let post_processor = module.define_class("PostProcessor", Default::default())?;
|
196
|
+
|
197
|
+
let class = module.define_class("BertProcessing", post_processor)?;
|
198
|
+
class.define_singleton_method("new", function!(RbBertProcessing::new, 2))?;
|
199
|
+
|
200
|
+
let class = module.define_class("ByteLevel", post_processor)?;
|
201
|
+
class.define_singleton_method("_new", function!(RbByteLevel::new, 1))?;
|
202
|
+
|
203
|
+
let class = module.define_class("RobertaProcessing", post_processor)?;
|
204
|
+
class.define_singleton_method("_new", function!(RbRobertaProcessing::new, 4))?;
|
205
|
+
|
206
|
+
let class = module.define_class("TemplateProcessing", post_processor)?;
|
207
|
+
class.define_singleton_method("_new", function!(RbTemplateProcessing::new, 3))?;
|
208
|
+
|
209
|
+
Ok(())
|
210
|
+
}
|