tokenizers 0.2.3 → 0.3.0

Files changed (46)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/Cargo.lock +32 -73
  4. data/README.md +4 -0
  5. data/ext/tokenizers/Cargo.toml +3 -1
  6. data/ext/tokenizers/src/decoders.rs +275 -6
  7. data/ext/tokenizers/src/encoding.rs +3 -2
  8. data/ext/tokenizers/src/error.rs +2 -2
  9. data/ext/tokenizers/src/lib.rs +64 -17
  10. data/ext/tokenizers/src/models.rs +372 -11
  11. data/ext/tokenizers/src/normalizers.rs +435 -7
  12. data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
  13. data/ext/tokenizers/src/processors.rs +210 -0
  14. data/ext/tokenizers/src/tokenizer.rs +437 -23
  15. data/ext/tokenizers/src/trainers.rs +749 -0
  16. data/ext/tokenizers/src/utils/mod.rs +5 -0
  17. data/ext/tokenizers/src/utils/normalization.rs +85 -0
  18. data/ext/tokenizers/src/utils/regex.rs +22 -0
  19. data/lib/tokenizers/char_bpe_tokenizer.rb +9 -6
  20. data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
  21. data/lib/tokenizers/decoders/ctc.rb +9 -0
  22. data/lib/tokenizers/decoders/metaspace.rb +9 -0
  23. data/lib/tokenizers/decoders/word_piece.rb +9 -0
  24. data/lib/tokenizers/from_pretrained.rb +1 -1
  25. data/lib/tokenizers/models/bpe.rb +9 -0
  26. data/lib/tokenizers/models/unigram.rb +9 -0
  27. data/lib/tokenizers/models/word_level.rb +13 -0
  28. data/lib/tokenizers/models/word_piece.rb +9 -0
  29. data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
  30. data/lib/tokenizers/normalizers/strip.rb +9 -0
  31. data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
  32. data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
  33. data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
  34. data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
  35. data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
  36. data/lib/tokenizers/processors/byte_level.rb +9 -0
  37. data/lib/tokenizers/processors/roberta_processing.rb +9 -0
  38. data/lib/tokenizers/processors/template_processing.rb +9 -0
  39. data/lib/tokenizers/tokenizer.rb +40 -7
  40. data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
  41. data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
  42. data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
  43. data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
  44. data/lib/tokenizers/version.rb +1 -1
  45. data/lib/tokenizers.rb +42 -2
  46. metadata +30 -3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: e4f3cb98cb867df67a1c8a00b56f9ec5f4c6fafa178d760655dafb6735160773
-  data.tar.gz: 88c420f7a42f56330ce091df7f131878efd552488232282388e69d7a3c4b4aa2
+  metadata.gz: 4ff4d1ad7b56010f603ead7a4794c003c5294e50f1b33de62c8089ddf150d5ad
+  data.tar.gz: 295aaabb720971f2ddcc832ab0d5deedf1e0ed8dab03aca96ac1d396b5723de7
 SHA512:
-  metadata.gz: 8e4746ccdf33dce78dc2b86d847f47f83576ca0d637671f825ad006a53b7ac3374654f7724f1e889618f322f9cfa5081e30083997ee9810eab282b9a8b99f807
-  data.tar.gz: 5dfe7b502d908f85ae16cfb28ebe1bd2ff51348c31151c7ee531504c00a0315dc22ea76fea963690de8c7390c7adb50d392e39de6db4a22101e91d31de1fa4e8
+  metadata.gz: e14207004cddeef40590229ea2c8a9bf54e5c5b75cdbcdd32cd6f23c24feb8544fcabe86fa9bced32cb41f2581ee0df4d36ed2b6a58ef2fc668aa33c270659df
+  data.tar.gz: b2bb202c8c37bdd0d14ca64be147e99b224128c1461f56761f1d58d9326b40768e7b903bbbb4c2a0363bd4b1c9ef5a66be53210ad801d5e45c7d86dd0945bd82
data/CHANGELOG.md CHANGED
@@ -1,3 +1,12 @@
+## 0.3.0 (2022-02-07)
+
+- Added support for training tokenizers
+- Added more methods to `Tokenizer`
+- Added `encode_batch` method to `Encoding`
+- Added `pair` argument to `encode` method
+- Changed `encode` method to include special tokens by default
+- Changed how offsets are calculated for strings with multibyte characters
+
 ## 0.2.3 (2022-01-22)
 
 - Added `add_special_tokens` option to `encode` method
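
Since the `encode` changes above alter default behavior, a short sketch of the new calls may help when upgrading. The signatures below are inferred from the changelog bullets (`pair` as a second positional argument, an `add_special_tokens:` keyword that now defaults to true, and `encode_batch` on `Tokenizer`) rather than taken from this diff, so treat them as assumptions.

```ruby
require "tokenizers"

tokenizer = Tokenizers.from_pretrained("bert-base-cased")

# 0.3.0: special tokens such as [CLS]/[SEP] are now included by default
encoding = tokenizer.encode("Hello there!")

# pass add_special_tokens: false to keep the old 0.2.x behavior
plain = tokenizer.encode("Hello there!", add_special_tokens: false)

# the new pair argument encodes two sequences together (assumed positional)
pair = tokenizer.encode("How are you?", "I'm fine, thanks.")

# encode several inputs at once (assumed to live on Tokenizer)
batch = tokenizer.encode_batch(["First sequence", "Second sequence"])
```
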
data/Cargo.lock CHANGED
@@ -50,9 +50,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
 
 [[package]]
 name = "cc"
-version = "1.0.78"
+version = "1.0.79"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a20104e2335ce8a659d6dd92a51a767a0c062599c73b343fd152cb401e828c3d"
+checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
 
 [[package]]
 name = "cexpr"
@@ -138,9 +138,9 @@ dependencies = [
 
 [[package]]
 name = "darling"
-version = "0.14.2"
+version = "0.14.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b0dd3cd20dc6b5a876612a6e5accfe7f3dd883db6d07acfbf14c128f61550dfa"
+checksum = "c0808e1bd8671fb44a113a14e13497557533369847788fa2ae912b6ebfce9fa8"
 dependencies = [
  "darling_core",
  "darling_macro",
@@ -148,9 +148,9 @@ dependencies = [
 
 [[package]]
 name = "darling_core"
-version = "0.14.2"
+version = "0.14.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a784d2ccaf7c98501746bf0be29b2022ba41fd62a2e622af997a03e9f972859f"
+checksum = "001d80444f28e193f30c2f293455da62dcf9a6b29918a4253152ae2b1de592cb"
 dependencies = [
  "fnv",
  "ident_case",
@@ -162,9 +162,9 @@ dependencies = [
 
 [[package]]
 name = "darling_macro"
-version = "0.14.2"
+version = "0.14.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7618812407e9402654622dd402b0a89dff9ba93badd6540781526117b92aab7e"
+checksum = "b36230598a2d5de7ec1c6f51f72d8a99a9208daff41de2084d06e3fd3ea56685"
 dependencies = [
  "darling_core",
  "quote",
@@ -202,31 +202,11 @@ dependencies = [
  "syn",
 ]
 
-[[package]]
-name = "dirs"
-version = "3.0.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309"
-dependencies = [
- "dirs-sys",
-]
-
-[[package]]
-name = "dirs-sys"
-version = "0.3.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6"
-dependencies = [
- "libc",
- "redox_users",
- "winapi",
-]
-
 [[package]]
 name = "either"
-version = "1.8.0"
+version = "1.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797"
+checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"
 
 [[package]]
 name = "encode_unicode"
@@ -372,9 +352,8 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
 
 [[package]]
 name = "magnus"
-version = "0.4.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fc87660cd7daa49fddbfd524c836de54d5c927d520cd163f43700c5087c57d6c"
+version = "0.5.0"
+source = "git+https://github.com/matsadler/magnus#eda735faa7e03da2443eaf2c4058a184917d6b87"
 dependencies = [
  "magnus-macros",
  "rb-sys",
@@ -384,8 +363,7 @@ dependencies = [
 [[package]]
 name = "magnus-macros"
 version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "206cb23bfeea05180c97522ef6a3e52a4eb17b0ed2f30ee3ca9c4f994d2378ae"
+source = "git+https://github.com/matsadler/magnus#eda735faa7e03da2443eaf2c4058a184917d6b87"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -415,9 +393,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
 
 [[package]]
 name = "nom"
-version = "7.1.2"
+version = "7.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e5507769c4919c998e69e49c839d9dc6e693ede4cc4290d6ad8b41d4f09c548c"
+checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
 dependencies = [
  "memchr",
  "minimal-lexical",
@@ -493,9 +471,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.49"
+version = "1.0.51"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5"
+checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
 dependencies = [
  "unicode-ident",
 ]
@@ -562,9 +540,9 @@ dependencies = [
 
 [[package]]
 name = "rayon-core"
-version = "1.10.1"
+version = "1.10.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cac410af5d00ab6884528b4ab69d1e8e146e8d471201800fa1b4524126de6ad3"
+checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b"
 dependencies = [
  "crossbeam-channel",
  "crossbeam-deque",
@@ -574,18 +552,18 @@ dependencies = [
 
 [[package]]
 name = "rb-sys"
-version = "0.9.56"
+version = "0.9.64"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ef82428221475c6f9e7893fe30b88d45ac86bdb12e58e7c92055ba4bceb78a69"
+checksum = "cc8945662df8083245deda89e236647173cc7ad750f481ddcd7bbfd3afe3fa5e"
 dependencies = [
  "rb-sys-build",
 ]
 
 [[package]]
 name = "rb-sys-build"
-version = "0.9.56"
+version = "0.9.64"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "950bfc239d2e7704576abe4d37b008876bbfd70a99196a188c5caeae2ba7344a"
+checksum = "ae8c3cdf9edc3908ee1555b7a1bca58ee1b499439b32cd1c1ec3e66736a8df48"
 dependencies = [
  "bindgen",
  "regex",
@@ -594,29 +572,9 @@ dependencies = [
 
 [[package]]
 name = "rb-sys-env"
-version = "0.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "74c38752410925faeb82c400c06ba2fd9ee6aa8f719dd33994c9e53f5242d25f"
-
-[[package]]
-name = "redox_syscall"
-version = "0.2.16"
+version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
-dependencies = [
- "bitflags",
-]
-
-[[package]]
-name = "redox_users"
-version = "0.4.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
-dependencies = [
- "getrandom",
- "redox_syscall",
- "thiserror",
-]
+checksum = "a35802679f07360454b418a5d1735c89716bde01d35b1560fc953c1415a0b3bb"
 
 [[package]]
 name = "regex"
@@ -675,9 +633,9 @@ dependencies = [
 
 [[package]]
 name = "serde_json"
-version = "1.0.91"
+version = "1.0.92"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883"
+checksum = "7434af0dc1cbd59268aa98b4c22c131c0584d2232f6fb166efb993e2832e896a"
 dependencies = [
  "itoa",
  "ryu",
@@ -756,17 +714,18 @@ name = "tokenizers"
 version = "0.2.3"
 dependencies = [
  "magnus",
+ "onig",
+ "serde",
  "tokenizers 0.13.2",
 ]
 
 [[package]]
 name = "tokenizers"
 version = "0.13.2"
-source = "git+https://github.com/huggingface/tokenizers#fe4ae7dc38be11a5c93ae703816c869f993c21ab"
+source = "git+https://github.com/huggingface/tokenizers#fa66caf0abff16bae2213658ffa3e969c5445750"
 dependencies = [
  "aho-corasick",
  "derive_builder",
- "dirs",
  "esaxx-rs",
  "getrandom",
  "indicatif",
@@ -807,9 +766,9 @@ dependencies = [
 
 [[package]]
 name = "unicode-segmentation"
-version = "1.10.0"
+version = "1.10.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fdbf052a0783de01e944a6ce7a8cb939e295b1e7be835a1112c3b9a7f047a5a"
+checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
 
 [[package]]
 name = "unicode-width"
data/README.md CHANGED
@@ -40,6 +40,10 @@ Load a tokenizer from files
 tokenizer = Tokenizers::CharBPETokenizer.new("vocab.json", "merges.txt")
 ```
 
+## Training
+
+Check out the [Quicktour](https://huggingface.co/docs/tokenizers/quicktour) and equivalent [Ruby code](https://github.com/ankane/tokenizers-ruby/blob/master/test/quicktour_test.rb#L8)
+
 ## History
 
 View the [changelog](https://github.com/ankane/tokenizers-ruby/blob/master/CHANGELOG.md)
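
The new Training section only links out to the Quicktour, so here is a rough sketch of what training looks like with the trainer, model, and pre-tokenizer classes added in this release. The class and keyword names follow the file list above and the Hugging Face Quicktour; they are assumptions, not verbatim from this diff.

```ruby
require "tokenizers"

# build a tokenizer around a trainable BPE model (names assumed from the file list)
tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::BPE.new(unk_token: "[UNK]"))
tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::Whitespace.new

# reserve the special tokens the downstream model expects
trainer = Tokenizers::Trainers::BpeTrainer.new(
  special_tokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)

# train from plain-text files and save the result as a single JSON file
tokenizer.train(["data/wiki.train.raw"], trainer)
tokenizer.save("tokenizer-wiki.json")
```
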
data/ext/tokenizers/Cargo.toml CHANGED
@@ -10,7 +10,9 @@ publish = false
 crate-type = ["cdylib"]
 
 [dependencies]
-magnus = "0.4"
+magnus = { git = "https://github.com/matsadler/magnus" }
+onig = { version = "6.0", default-features = false }
+serde = { version = "1.0", features = ["rc", "derive"] }
 
 [dependencies.tokenizers]
 version = "0.13.2" # also update in from_pretrained.rb
data/ext/tokenizers/src/decoders.rs CHANGED
@@ -1,14 +1,283 @@
+use std::sync::{Arc, RwLock};
+
+use magnus::typed_data::DataTypeBuilder;
+use magnus::{
+    function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
+    TypedData,
+};
+use serde::{Deserialize, Serialize};
 use tk::decoders::bpe::BPEDecoder;
+use tk::decoders::byte_level::ByteLevel;
+use tk::decoders::ctc::CTC;
+use tk::decoders::metaspace::Metaspace;
+use tk::decoders::wordpiece::WordPiece;
+use tk::decoders::DecoderWrapper;
+use tk::Decoder;
+
+use super::RbResult;
+
+#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
+pub struct RbDecoder {
+    #[serde(flatten)]
+    pub(crate) decoder: RbDecoderWrapper,
+}
+
+impl Decoder for RbDecoder {
+    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+        self.decoder.decode_chain(tokens)
+    }
+}
+
+macro_rules! getter {
+    ($self: ident, $variant: ident, $($name: tt)+) => {{
+        let decoder = &$self.decoder;
+        let RbDecoderWrapper::Wrapped(ref wrap) = decoder;
+        if let DecoderWrapper::$variant(ref dec) = *wrap.read().unwrap() {
+            dec.$($name)+
+        } else {
+            unreachable!()
+        }
+    }};
+}
 
-#[magnus::wrap(class = "Tokenizers::BPEDecoder")]
-pub struct RbBPEDecoder {
-    pub decoder: BPEDecoder,
+macro_rules! setter {
+    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
+        let decoder = &$self.decoder;
+        let RbDecoderWrapper::Wrapped(ref wrap) = decoder;
+        if let DecoderWrapper::$variant(ref mut dec) = *wrap.write().unwrap() {
+            dec.$name = $value;
+        }
+    }};
+    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
+        let decoder = &$self.decoder;
+        let RbDecoderWrapper::Wrapped(ref wrap) = decoder;
+        if let DecoderWrapper::$variant(ref mut dec) = *wrap.write().unwrap() {
+            dec.$name($value);
+        }
+    }};
 }
+impl RbDecoder {
+    pub fn bpe_suffix(&self) -> String {
+        getter!(self, BPE, suffix.clone())
+    }
+
+    pub fn bpe_set_suffix(&self, suffix: String) {
+        setter!(self, BPE, suffix, suffix);
+    }
+
+    pub fn ctc_cleanup(&self) -> bool {
+        getter!(self, CTC, cleanup)
+    }
+
+    pub fn ctc_set_cleanup(&self, cleanup: bool) {
+        setter!(self, CTC, cleanup, cleanup);
+    }
+
+    pub fn ctc_pad_token(&self) -> String {
+        getter!(self, CTC, pad_token.clone())
+    }
+
+    pub fn ctc_set_pad_token(&self, pad_token: String) {
+        setter!(self, CTC, pad_token, pad_token);
+    }
+
+    pub fn ctc_word_delimiter_token(&self) -> String {
+        getter!(self, CTC, word_delimiter_token.clone())
+    }
+
+    pub fn ctc_set_word_delimiter_token(&self, word_delimiter_token: String) {
+        setter!(self, CTC, word_delimiter_token, word_delimiter_token);
+    }
+
+    pub fn metaspace_replacement(&self) -> char {
+        getter!(self, Metaspace, get_replacement().clone())
+    }
+
+    pub fn metaspace_set_replacement(&self, replacement: char) {
+        setter!(self, Metaspace, @set_replacement, replacement);
+    }
+
+    pub fn metaspace_add_prefix_space(&self) -> bool {
+        getter!(self, Metaspace, add_prefix_space)
+    }
+
+    pub fn metaspace_set_add_prefix_space(&self, add_prefix_space: bool) {
+        setter!(self, Metaspace, add_prefix_space, add_prefix_space);
+    }
+
+    pub fn word_piece_cleanup(&self) -> bool {
+        getter!(self, WordPiece, cleanup)
+    }
+
+    pub fn word_piece_set_cleanup(&self, cleanup: bool) {
+        setter!(self, WordPiece, cleanup, cleanup);
+    }
+
+    pub fn word_piece_prefix(&self) -> String {
+        getter!(self, WordPiece, prefix.clone())
+    }
+
+    pub fn word_piece_set_prefix(&self, prefix: String) {
+        setter!(self, WordPiece, prefix, prefix);
+    }
+}
+
+pub struct RbBPEDecoder {}
 
 impl RbBPEDecoder {
-    pub fn new() -> Self {
-        RbBPEDecoder {
-            decoder: BPEDecoder::default(),
+    pub fn new(suffix: String) -> RbDecoder {
+        BPEDecoder::new(suffix).into()
+    }
+}
+
+pub struct RbByteLevelDecoder {}
+
+impl RbByteLevelDecoder {
+    pub fn new() -> RbDecoder {
+        ByteLevel::default().into()
+    }
+}
+
+pub struct RbCTC {}
+
+impl RbCTC {
+    pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> RbDecoder {
+        CTC::new(pad_token, word_delimiter_token, cleanup).into()
+    }
+}
+
+pub struct RbMetaspaceDecoder {}
+
+impl RbMetaspaceDecoder {
+    pub fn new(replacement: char, add_prefix_space: bool) -> RbDecoder {
+        Metaspace::new(replacement, add_prefix_space).into()
+    }
+}
+
+pub struct RbWordPieceDecoder {}
+
+impl RbWordPieceDecoder {
+    pub fn new(prefix: String, cleanup: bool) -> RbDecoder {
+        WordPiece::new(prefix, cleanup).into()
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+#[serde(untagged)]
+pub(crate) enum RbDecoderWrapper {
+    // Custom(Arc<RwLock<CustomDecoder>>),
+    Wrapped(Arc<RwLock<DecoderWrapper>>),
+}
+
+impl<I> From<I> for RbDecoderWrapper
+where
+    I: Into<DecoderWrapper>,
+{
+    fn from(norm: I) -> Self {
+        RbDecoderWrapper::Wrapped(Arc::new(RwLock::new(norm.into())))
+    }
+}
+
+impl<I> From<I> for RbDecoder
+where
+    I: Into<DecoderWrapper>,
+{
+    fn from(dec: I) -> Self {
+        RbDecoder {
+            decoder: dec.into().into(),
+        }
+    }
+}
+
+impl Decoder for RbDecoderWrapper {
+    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+        match self {
+            RbDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens),
+            // RbDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens),
+        }
+    }
+}
+
+unsafe impl TypedData for RbDecoder {
+    fn class() -> RClass {
+        *memoize!(RClass: {
+            let class: RClass = crate::decoders().const_get("Decoder").unwrap();
+            class.undef_alloc_func();
+            class
+        })
+    }
+
+    fn data_type() -> &'static DataType {
+        memoize!(DataType: DataTypeBuilder::<RbDecoder>::new("Tokenizers::Decoders::Decoder").build())
+    }
+
+    fn class_for(value: &Self) -> RClass {
+        match &value.decoder {
+            RbDecoderWrapper::Wrapped(inner) => match *inner.read().unwrap() {
+                DecoderWrapper::BPE(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("BPEDecoder").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                DecoderWrapper::ByteLevel(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("ByteLevel").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                DecoderWrapper::CTC(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("CTC").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                DecoderWrapper::Metaspace(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("Metaspace").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                DecoderWrapper::WordPiece(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("WordPiece").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                _ => todo!(),
+            },
         }
     }
 }
+
+pub fn decoders(module: &RModule) -> RbResult<()> {
+    let decoder = module.define_class("Decoder", Default::default())?;
+
+    let class = module.define_class("BPEDecoder", decoder)?;
+    class.define_singleton_method("_new", function!(RbBPEDecoder::new, 1))?;
+    class.define_method("suffix", method!(RbDecoder::bpe_suffix, 0))?;
+    class.define_method("suffix=", method!(RbDecoder::bpe_set_suffix, 1))?;
+
+    let class = module.define_class("ByteLevel", decoder)?;
+    class.define_singleton_method("new", function!(RbByteLevelDecoder::new, 0))?;
+
+    let class = module.define_class("CTC", decoder)?;
+    class.define_singleton_method("_new", function!(RbCTC::new, 3))?;
+    class.define_method("cleanup", method!(RbDecoder::ctc_cleanup, 0))?;
+    class.define_method("cleanup=", method!(RbDecoder::ctc_set_cleanup, 1))?;
+    class.define_method("pad_token", method!(RbDecoder::ctc_pad_token, 0))?;
+    class.define_method("pad_token=", method!(RbDecoder::ctc_set_pad_token, 1))?;
+    class.define_method("word_delimiter_token", method!(RbDecoder::ctc_word_delimiter_token, 0))?;
+    class.define_method("word_delimiter_token=", method!(RbDecoder::ctc_set_word_delimiter_token, 1))?;
+
+    let class = module.define_class("Metaspace", decoder)?;
+    class.define_singleton_method("_new", function!(RbMetaspaceDecoder::new, 2))?;
+    class.define_method("add_prefix_space", method!(RbDecoder::metaspace_add_prefix_space, 0))?;
+    class.define_method("add_prefix_space=", method!(RbDecoder::metaspace_set_add_prefix_space, 1))?;
+    class.define_method("replacement", method!(RbDecoder::metaspace_replacement, 0))?;
+    class.define_method("replacement=", method!(RbDecoder::metaspace_set_replacement, 1))?;
+
+    let class = module.define_class("WordPiece", decoder)?;
+    class.define_singleton_method("_new", function!(RbWordPieceDecoder::new, 2))?;
+    class.define_method("cleanup", method!(RbDecoder::word_piece_cleanup, 0))?;
+    class.define_method("cleanup=", method!(RbDecoder::word_piece_set_cleanup, 1))?;
+    class.define_method("prefix", method!(RbDecoder::word_piece_prefix, 0))?;
+    class.define_method("prefix=", method!(RbDecoder::word_piece_set_prefix, 1))?;
+
+    Ok(())
+}
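
The `decoders` function above registers these classes under `Tokenizers::Decoders`, and the small Ruby files in `lib/tokenizers/decoders/` wrap `_new` with keyword arguments. A hypothetical usage sketch follows; the keyword names are inferred from the Rust constructor parameters and may not match the Ruby wrappers exactly.

```ruby
require "tokenizers"

# WordPiece decoder: prefix/cleanup correspond to the word_piece_* accessors above
decoder = Tokenizers::Decoders::WordPiece.new(prefix: "##", cleanup: true)
decoder.prefix          # => "##"
decoder.cleanup = false

# CTC decoder exposes pad_token, word_delimiter_token and cleanup
ctc = Tokenizers::Decoders::CTC.new(
  pad_token: "<pad>",
  word_delimiter_token: "|",
  cleanup: true
)
ctc.word_delimiter_token # => "|"

# attaching a decoder assumes Tokenizer gained a decoder= setter in 0.3.0
tokenizer = Tokenizers.from_pretrained("bert-base-cased")
tokenizer.decoder = decoder
```
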
data/ext/tokenizers/src/encoding.rs CHANGED
@@ -1,3 +1,4 @@
+use magnus::RArray;
 use tk::{Encoding, Offsets};
 
 #[magnus::wrap(class = "Tokenizers::Encoding")]
@@ -49,12 +50,12 @@ impl RbEncoding {
         self.encoding.get_attention_mask().to_vec()
     }
 
-    pub fn overflowing(&self) -> Vec<Self> {
+    pub fn overflowing(&self) -> RArray {
         self.encoding
             .get_overflowing()
             .clone()
             .into_iter()
-            .map(|e| e.into())
+            .map(Into::<RbEncoding>::into)
             .collect()
     }
 
data/ext/tokenizers/src/error.rs CHANGED
@@ -1,4 +1,4 @@
-use magnus::{exception, memoize, Error, ExceptionClass, Module};
+use magnus::{memoize, Error, ExceptionClass, Module};
 
 use super::module;
 
@@ -12,5 +12,5 @@ impl RbError {
     }
 
     fn error() -> ExceptionClass {
-        *memoize!(ExceptionClass: module().define_error("Error", exception::standard_error()).unwrap())
+        *memoize!(ExceptionClass: module().const_get("Error").unwrap())
     }