tokenizers 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
data/src/lib.rs DELETED
@@ -1,290 +0,0 @@
1
- #[macro_use]
2
- extern crate rutie;
3
-
4
- use rutie::{AnyException, AnyObject, Array, Integer, Module, Object, RString, VerifiedObject, VM};
5
- use tokenizers::decoders::bpe::BPEDecoder;
6
- use tokenizers::models::bpe::BPE;
7
- use tokenizers::normalizers::BertNormalizer;
8
- use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
9
- use tokenizers::tokenizer::Tokenizer;
10
- use tokenizers::{decoders, AddedToken, Encoding};
11
-
12
/// Extension version string, captured from Cargo.toml at compile time
/// (sent as part of the Hub user agent in `tokenizers_from_pretrained`).
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
13
-
14
// Generate rutie wrapper types so each Rust struct can live inside a Ruby
// object. The third argument is the global data-type marker later passed to
// `wrap_data`/`get_data` to store/retrieve the Rust value.
wrappable_struct!(Tokenizer, TokenizerWrapper, TOKENIZER_WRAPPER);
wrappable_struct!(BPE, BPEWrapper, BPE_WRAPPER);
wrappable_struct!(Encoding, EncodingWrapper, ENCODING_WRAPPER);
wrappable_struct!(BPEDecoder, BPEDecoderWrapper, BPE_DECODER_WRAPPER);
wrappable_struct!(BertPreTokenizer, BertPreTokenizerWrapper, BERT_PRE_TOKENIZER_WRAPPER);
wrappable_struct!(BertNormalizer, BertNormalizerWrapper, BERT_NORMALIZER_WRAPPER);
20
-
21
// Handles for the Ruby module and classes that Init_ext defines below.
// The `rb` prefix marks Ruby-side binding types (intentionally not
// UpperCamelCase; these names are part of the rutie macro interface).
module!(rbTokenizers);

class!(rbBPE);
class!(rbTokenizer);
class!(rbEncoding);
class!(rbBPEDecoder);
class!(rbBertPreTokenizer);
class!(rbBertNormalizer);
29
-
30
/// Unwrap a rutie argument/conversion result, raising the contained Ruby
/// exception on failure. NOTE(review): this relies on `VM::raise_ex`
/// transferring control back to the Ruby VM so the trailing `unwrap` is not
/// reached on the error path — confirm against rutie's raise semantics.
fn unwrap_object<T>(res: Result<T, AnyException>) -> T {
    res.map_err(VM::raise_ex).unwrap()
}
33
-
34
- fn unwrap_optional<T>(res: Result<AnyObject, AnyException>) -> Option<T>
35
- where
36
- T: VerifiedObject,
37
- {
38
- let x = unwrap_object(res);
39
- if x.is_nil() {
40
- None
41
- } else {
42
- Some(unwrap_object(x.try_convert_to::<T>()))
43
- }
44
- }
45
-
46
- fn handle_error<T>(res: Result<T, Box<dyn std::error::Error + Send + Sync>>) -> T {
47
- match res {
48
- Ok(x) => x,
49
- Err(e) => {
50
- VM::raise(
51
- Module::from_existing("Tokenizers").get_nested_class("Error"),
52
- &e.to_string(),
53
- );
54
- unreachable!()
55
- }
56
- }
57
- }
58
-
59
- methods!(
60
- rbTokenizers,
61
- _rtself,
62
-
63
- fn tokenizers_from_pretrained(identifier: RString, revision: RString, auth_token: AnyObject) -> AnyObject {
64
- let identifier = unwrap_object(identifier);
65
- let revision = unwrap_object(revision);
66
- let auth_token: Option<RString> = unwrap_optional(auth_token);
67
-
68
- let params = tokenizers::FromPretrainedParameters {
69
- revision: revision.to_string(),
70
- auth_token: auth_token.map(|x| x.to_string()),
71
- user_agent: [("bindings", "Ruby"), ("version", VERSION)]
72
- .iter()
73
- .map(|(k, v)| (k.to_string(), v.to_string()))
74
- .collect(),
75
- };
76
-
77
- let tokenizer = handle_error(Tokenizer::from_pretrained(identifier.to_string(), Some(params)));
78
- Module::from_existing("Tokenizers")
79
- .get_nested_class("Tokenizer")
80
- .wrap_data(tokenizer, &*TOKENIZER_WRAPPER)
81
- }
82
- );
83
-
84
- methods!(
85
- rbBPE,
86
- _rtself,
87
-
88
- fn bpe_new(vocab: RString, merges: RString) -> AnyObject {
89
- let vocab = unwrap_object(vocab);
90
- let merges = unwrap_object(merges);
91
-
92
- let bpe = handle_error(BPE::from_file(&vocab.to_string(), &merges.to_string())
93
- .unk_token("<unk>".into())
94
- .end_of_word_suffix("</w>".into())
95
- .build());
96
-
97
- Module::from_existing("Tokenizers")
98
- .get_nested_class("BPE")
99
- .wrap_data(bpe, &*BPE_WRAPPER)
100
- }
101
- );
102
-
103
- methods!(
104
- rbTokenizer,
105
- _rtself,
106
-
107
- fn tokenizer_new(model: AnyObject) -> AnyObject {
108
- let model = unwrap_object(model);
109
-
110
- // TODO support any model
111
- let model = model.get_data(&*BPE_WRAPPER).clone();
112
-
113
- let mut tokenizer = Tokenizer::new(model);
114
-
115
- Module::from_existing("Tokenizers")
116
- .get_nested_class("Tokenizer")
117
- .wrap_data(tokenizer, &*TOKENIZER_WRAPPER)
118
- }
119
- );
120
-
121
- methods!(
122
- rbTokenizer,
123
- rtself,
124
-
125
- fn tokenizer_add_special_tokens(tokens: Array) -> rbTokenizer {
126
- let tokenizer = rtself.get_data_mut(&*TOKENIZER_WRAPPER);
127
- let tokens = unwrap_object(tokens);
128
-
129
- let mut vec = Vec::new();
130
- for token in tokens.into_iter() {
131
- vec.push(AddedToken::from(unwrap_object(token.try_convert_to::<RString>()).to_string(), true));
132
- }
133
- tokenizer.add_special_tokens(&vec);
134
- rtself
135
- }
136
-
137
- fn tokenizer_encode(text: RString) -> AnyObject {
138
- let tokenizer = rtself.get_data(&*TOKENIZER_WRAPPER);
139
- let text = unwrap_object(text);
140
-
141
- let encoding = handle_error(tokenizer.encode(text.to_string(), false));
142
- Module::from_existing("Tokenizers")
143
- .get_nested_class("Encoding")
144
- .wrap_data(encoding, &*ENCODING_WRAPPER)
145
- }
146
-
147
- fn tokenizer_decode(ids: Array) -> RString {
148
- let tokenizer = rtself.get_data(&*TOKENIZER_WRAPPER);
149
- let ids = unwrap_object(ids);
150
-
151
- let mut vec = Vec::new();
152
- for item in ids.into_iter() {
153
- vec.push(unwrap_object(item.try_convert_to::<Integer>()).into());
154
- }
155
- let s = handle_error(tokenizer.decode(vec, true));
156
- RString::new_utf8(&s)
157
- }
158
-
159
- fn tokenizer_decoder_set(decoder: AnyObject) -> AnyObject {
160
- let tokenizer = rtself.get_data_mut(&*TOKENIZER_WRAPPER);
161
- let decoder = unwrap_object(decoder);
162
-
163
- tokenizer.with_decoder(decoder.get_data(&*BPE_DECODER_WRAPPER).clone());
164
- decoder
165
- }
166
-
167
- fn tokenizer_pre_tokenizer_set(pre_tokenizer: AnyObject) -> AnyObject {
168
- let tokenizer = rtself.get_data_mut(&*TOKENIZER_WRAPPER);
169
- let pre_tokenizer = unwrap_object(pre_tokenizer);
170
-
171
- tokenizer.with_pre_tokenizer(*pre_tokenizer.get_data(&*BERT_PRE_TOKENIZER_WRAPPER));
172
- pre_tokenizer
173
- }
174
-
175
- fn tokenizer_normalizer_set(normalizer: AnyObject) -> AnyObject {
176
- let tokenizer = rtself.get_data_mut(&*TOKENIZER_WRAPPER);
177
- let normalizer = unwrap_object(normalizer);
178
-
179
- tokenizer.with_normalizer(*normalizer.get_data(&*BERT_NORMALIZER_WRAPPER));
180
- normalizer
181
- }
182
- );
183
-
184
- methods!(
185
- rbEncoding,
186
- rtself,
187
-
188
- fn encoding_ids() -> Array {
189
- let encoding = rtself.get_data(&*ENCODING_WRAPPER);
190
-
191
- let mut array = Array::new();
192
- for x in encoding.get_ids() {
193
- array.push(Integer::from(*x));
194
- }
195
- array
196
- }
197
-
198
- fn encoding_tokens() -> Array {
199
- let encoding = rtself.get_data(&*ENCODING_WRAPPER);
200
-
201
- let mut array = Array::new();
202
- for x in encoding.get_tokens() {
203
- array.push(RString::new_utf8(x));
204
- }
205
- array
206
- }
207
- );
208
-
209
- methods!(
210
- rbBPEDecoder,
211
- _rtself,
212
-
213
- fn bpe_decoder_new() -> AnyObject {
214
- let decoder = decoders::bpe::BPEDecoder::default();
215
- Module::from_existing("Tokenizers")
216
- .get_nested_class("BPEDecoder")
217
- .wrap_data(decoder, &*BPE_DECODER_WRAPPER)
218
- }
219
- );
220
-
221
- methods!(
222
- rbBertPreTokenizer,
223
- _rtself,
224
-
225
- fn bert_pre_tokenizer_new() -> AnyObject {
226
- let pre_tokenizer = BertPreTokenizer;
227
- Module::from_existing("Tokenizers")
228
- .get_nested_class("BertPreTokenizer")
229
- .wrap_data(pre_tokenizer, &*BERT_PRE_TOKENIZER_WRAPPER)
230
- }
231
- );
232
-
233
- methods!(
234
- rbBertNormalizer,
235
- _rtself,
236
-
237
- fn bert_normalizer_new() -> AnyObject {
238
- let normalizer = BertNormalizer::default();
239
- Module::from_existing("Tokenizers")
240
- .get_nested_class("BertNormalizer")
241
- .wrap_data(normalizer, &*BERT_NORMALIZER_WRAPPER)
242
- }
243
- );
244
-
245
- #[allow(non_snake_case)]
246
- #[no_mangle]
247
- pub extern "C" fn Init_ext() {
248
- let mut m = Module::new("Tokenizers");
249
-
250
- m.define(|klass| {
251
- klass.def_self("_from_pretrained", tokenizers_from_pretrained);
252
- klass.define_nested_class("BPE", None);
253
- klass.define_nested_class("Tokenizer", None);
254
- klass.define_nested_class("Encoding", None);
255
- klass.define_nested_class("BPEDecoder", None);
256
- klass.define_nested_class("BertPreTokenizer", None);
257
- klass.define_nested_class("BertNormalizer", None);
258
- });
259
-
260
- m.get_nested_class("BPE").define(|klass| {
261
- klass.def_self("new", bpe_new);
262
- });
263
-
264
- m.get_nested_class("Tokenizer").define(|klass| {
265
- klass.def_self("new", tokenizer_new);
266
- klass.def("add_special_tokens", tokenizer_add_special_tokens);
267
- klass.def("encode", tokenizer_encode);
268
- klass.def("decode", tokenizer_decode);
269
- klass.def("decoder=", tokenizer_decoder_set);
270
- klass.def("pre_tokenizer=", tokenizer_pre_tokenizer_set);
271
- klass.def("normalizer=", tokenizer_normalizer_set);
272
- });
273
-
274
- m.get_nested_class("Encoding").define(|klass| {
275
- klass.def("ids", encoding_ids);
276
- klass.def("tokens", encoding_tokens);
277
- });
278
-
279
- m.get_nested_class("BPEDecoder").define(|klass| {
280
- klass.def_self("new", bpe_decoder_new);
281
- });
282
-
283
- m.get_nested_class("BertPreTokenizer").define(|klass| {
284
- klass.def_self("new", bert_pre_tokenizer_new);
285
- });
286
-
287
- m.get_nested_class("BertNormalizer").define(|klass| {
288
- klass.def_self("new", bert_normalizer_new);
289
- });
290
- }