tokenizers 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/Cargo.lock +33 -74
  4. data/README.md +4 -0
  5. data/ext/tokenizers/Cargo.toml +4 -2
  6. data/ext/tokenizers/src/decoders.rs +275 -6
  7. data/ext/tokenizers/src/encoding.rs +78 -3
  8. data/ext/tokenizers/src/error.rs +2 -2
  9. data/ext/tokenizers/src/lib.rs +88 -17
  10. data/ext/tokenizers/src/models.rs +372 -11
  11. data/ext/tokenizers/src/normalizers.rs +435 -7
  12. data/ext/tokenizers/src/pre_tokenizers.rs +470 -6
  13. data/ext/tokenizers/src/processors.rs +210 -0
  14. data/ext/tokenizers/src/tokenizer.rs +448 -20
  15. data/ext/tokenizers/src/trainers.rs +749 -0
  16. data/ext/tokenizers/src/utils/mod.rs +5 -0
  17. data/ext/tokenizers/src/utils/normalization.rs +85 -0
  18. data/ext/tokenizers/src/utils/regex.rs +22 -0
  19. data/lib/tokenizers/char_bpe_tokenizer.rb +11 -8
  20. data/lib/tokenizers/decoders/bpe_decoder.rb +9 -0
  21. data/lib/tokenizers/decoders/ctc.rb +9 -0
  22. data/lib/tokenizers/decoders/metaspace.rb +9 -0
  23. data/lib/tokenizers/decoders/word_piece.rb +9 -0
  24. data/lib/tokenizers/encoding.rb +19 -0
  25. data/lib/tokenizers/from_pretrained.rb +1 -1
  26. data/lib/tokenizers/models/bpe.rb +9 -0
  27. data/lib/tokenizers/models/unigram.rb +9 -0
  28. data/lib/tokenizers/models/word_level.rb +13 -0
  29. data/lib/tokenizers/models/word_piece.rb +9 -0
  30. data/lib/tokenizers/normalizers/bert_normalizer.rb +9 -0
  31. data/lib/tokenizers/normalizers/strip.rb +9 -0
  32. data/lib/tokenizers/pre_tokenizers/byte_level.rb +9 -0
  33. data/lib/tokenizers/pre_tokenizers/digits.rb +9 -0
  34. data/lib/tokenizers/pre_tokenizers/metaspace.rb +9 -0
  35. data/lib/tokenizers/pre_tokenizers/punctuation.rb +9 -0
  36. data/lib/tokenizers/pre_tokenizers/split.rb +9 -0
  37. data/lib/tokenizers/processors/byte_level.rb +9 -0
  38. data/lib/tokenizers/processors/roberta_processing.rb +9 -0
  39. data/lib/tokenizers/processors/template_processing.rb +9 -0
  40. data/lib/tokenizers/tokenizer.rb +45 -0
  41. data/lib/tokenizers/trainers/bpe_trainer.rb +9 -0
  42. data/lib/tokenizers/trainers/unigram_trainer.rb +26 -0
  43. data/lib/tokenizers/trainers/word_level_trainer.rb +9 -0
  44. data/lib/tokenizers/trainers/word_piece_trainer.rb +26 -0
  45. data/lib/tokenizers/version.rb +1 -1
  46. data/lib/tokenizers.rb +49 -7
  47. metadata +32 -3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 197131371ec438d82623bc0aacb8fe82ba255904e847eeb9259358f38a7063f0
-  data.tar.gz: 42ef490120e56fbb79d847ec1eb2b0a6b0ca7aa8f2ad90c09d2053d167491350
+  metadata.gz: 4ff4d1ad7b56010f603ead7a4794c003c5294e50f1b33de62c8089ddf150d5ad
+  data.tar.gz: 295aaabb720971f2ddcc832ab0d5deedf1e0ed8dab03aca96ac1d396b5723de7
 SHA512:
-  metadata.gz: 0a21b4811cc9e31565209eb514e55d6b22302c350371a76205aeb3b67cf94ea6dabf85074cebd48c65f9eca56e8e750b83a1df841807e53afb1275961bca50ce
-  data.tar.gz: 222bb9d759e3a2cc00ad7a4950c821fdbad1bbf6d4413f237bcf9cdc0698c2011022890b3f306be6df3d70b05abd446ad43066851ffa6c27387ddf3191f7557d
+  metadata.gz: e14207004cddeef40590229ea2c8a9bf54e5c5b75cdbcdd32cd6f23c24feb8544fcabe86fa9bced32cb41f2581ee0df4d36ed2b6a58ef2fc668aa33c270659df
+  data.tar.gz: b2bb202c8c37bdd0d14ca64be147e99b224128c1461f56761f1d58d9326b40768e7b903bbbb4c2a0363bd4b1c9ef5a66be53210ad801d5e45c7d86dd0945bd82
data/CHANGELOG.md CHANGED
@@ -1,3 +1,19 @@
+## 0.3.0 (2023-02-07)
+
+- Added support for training tokenizers
+- Added more methods to `Tokenizer`
+- Added `encode_batch` method to `Tokenizer`
+- Added `pair` argument to `encode` method
+- Changed `encode` method to include special tokens by default
+- Changed how offsets are calculated for strings with multibyte characters
+
+## 0.2.3 (2023-01-22)
+
+- Added `add_special_tokens` option to `encode` method
+- Added warning about `encode` method including special tokens by default in 0.3.0
+- Added more methods to `Encoding`
+- Fixed error with precompiled gem on Mac ARM
+
 ## 0.2.2 (2023-01-15)
 
 - Added precompiled gem for Linux ARM
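
To make the 0.3.0 behavior changes concrete, here is a minimal Ruby sketch; the model name and token output are illustrative, not from this diff:

```ruby
require "tokenizers"

# Illustrative model; any tokenizer with special tokens behaves the same way.
tokenizer = Tokenizers.from_pretrained("bert-base-uncased")

# 0.3.0 includes special tokens by default; the option added in 0.2.3 opts out.
tokenizer.encode("Can you parse this?").tokens
# => ["[CLS]", "can", "you", "parse", "this", "?", "[SEP]"]
tokenizer.encode("Can you parse this?", add_special_tokens: false).tokens
# => ["can", "you", "parse", "this", "?"]

# New pair argument and batch encoding on Tokenizer.
tokenizer.encode("A question?", "A context sentence")
tokenizer.encode_batch(["First text", "Second text"])
```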
data/Cargo.lock CHANGED
@@ -50,9 +50,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
 
 [[package]]
 name = "cc"
-version = "1.0.78"
+version = "1.0.79"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a20104e2335ce8a659d6dd92a51a767a0c062599c73b343fd152cb401e828c3d"
+checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
 
 [[package]]
 name = "cexpr"
@@ -138,9 +138,9 @@ dependencies = [
 
 [[package]]
 name = "darling"
-version = "0.14.2"
+version = "0.14.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b0dd3cd20dc6b5a876612a6e5accfe7f3dd883db6d07acfbf14c128f61550dfa"
+checksum = "c0808e1bd8671fb44a113a14e13497557533369847788fa2ae912b6ebfce9fa8"
 dependencies = [
  "darling_core",
  "darling_macro",
@@ -148,9 +148,9 @@ dependencies = [
 
 [[package]]
 name = "darling_core"
-version = "0.14.2"
+version = "0.14.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a784d2ccaf7c98501746bf0be29b2022ba41fd62a2e622af997a03e9f972859f"
+checksum = "001d80444f28e193f30c2f293455da62dcf9a6b29918a4253152ae2b1de592cb"
 dependencies = [
  "fnv",
  "ident_case",
@@ -162,9 +162,9 @@ dependencies = [
 
 [[package]]
 name = "darling_macro"
-version = "0.14.2"
+version = "0.14.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7618812407e9402654622dd402b0a89dff9ba93badd6540781526117b92aab7e"
+checksum = "b36230598a2d5de7ec1c6f51f72d8a99a9208daff41de2084d06e3fd3ea56685"
 dependencies = [
  "darling_core",
  "quote",
@@ -202,31 +202,11 @@ dependencies = [
  "syn",
 ]
 
-[[package]]
-name = "dirs"
-version = "3.0.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309"
-dependencies = [
- "dirs-sys",
-]
-
-[[package]]
-name = "dirs-sys"
-version = "0.3.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6"
-dependencies = [
- "libc",
- "redox_users",
- "winapi",
-]
-
 [[package]]
 name = "either"
-version = "1.8.0"
+version = "1.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797"
+checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"
 
 [[package]]
 name = "encode_unicode"
@@ -372,9 +352,8 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
 
 [[package]]
 name = "magnus"
-version = "0.4.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fc87660cd7daa49fddbfd524c836de54d5c927d520cd163f43700c5087c57d6c"
+version = "0.5.0"
+source = "git+https://github.com/matsadler/magnus#eda735faa7e03da2443eaf2c4058a184917d6b87"
 dependencies = [
  "magnus-macros",
  "rb-sys",
@@ -384,8 +363,7 @@ dependencies = [
 [[package]]
 name = "magnus-macros"
 version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "206cb23bfeea05180c97522ef6a3e52a4eb17b0ed2f30ee3ca9c4f994d2378ae"
+source = "git+https://github.com/matsadler/magnus#eda735faa7e03da2443eaf2c4058a184917d6b87"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -415,9 +393,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
 
 [[package]]
 name = "nom"
-version = "7.1.2"
+version = "7.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e5507769c4919c998e69e49c839d9dc6e693ede4cc4290d6ad8b41d4f09c548c"
+checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
 dependencies = [
  "memchr",
  "minimal-lexical",
@@ -493,9 +471,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.49"
+version = "1.0.51"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5"
+checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
 dependencies = [
  "unicode-ident",
 ]
@@ -562,9 +540,9 @@ dependencies = [
 
 [[package]]
 name = "rayon-core"
-version = "1.10.1"
+version = "1.10.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cac410af5d00ab6884528b4ab69d1e8e146e8d471201800fa1b4524126de6ad3"
+checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b"
 dependencies = [
  "crossbeam-channel",
  "crossbeam-deque",
@@ -574,18 +552,18 @@ dependencies = [
 
 [[package]]
 name = "rb-sys"
-version = "0.9.56"
+version = "0.9.64"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ef82428221475c6f9e7893fe30b88d45ac86bdb12e58e7c92055ba4bceb78a69"
+checksum = "cc8945662df8083245deda89e236647173cc7ad750f481ddcd7bbfd3afe3fa5e"
 dependencies = [
  "rb-sys-build",
 ]
 
 [[package]]
 name = "rb-sys-build"
-version = "0.9.56"
+version = "0.9.64"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "950bfc239d2e7704576abe4d37b008876bbfd70a99196a188c5caeae2ba7344a"
+checksum = "ae8c3cdf9edc3908ee1555b7a1bca58ee1b499439b32cd1c1ec3e66736a8df48"
 dependencies = [
  "bindgen",
  "regex",
@@ -594,29 +572,9 @@ dependencies = [
 
 [[package]]
 name = "rb-sys-env"
-version = "0.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "74c38752410925faeb82c400c06ba2fd9ee6aa8f719dd33994c9e53f5242d25f"
-
-[[package]]
-name = "redox_syscall"
-version = "0.2.16"
+version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
-dependencies = [
- "bitflags",
-]
-
-[[package]]
-name = "redox_users"
-version = "0.4.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
-dependencies = [
- "getrandom",
- "redox_syscall",
- "thiserror",
-]
+checksum = "a35802679f07360454b418a5d1735c89716bde01d35b1560fc953c1415a0b3bb"
 
 [[package]]
 name = "regex"
@@ -675,9 +633,9 @@ dependencies = [
 
 [[package]]
 name = "serde_json"
-version = "1.0.91"
+version = "1.0.92"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883"
+checksum = "7434af0dc1cbd59268aa98b4c22c131c0584d2232f6fb166efb993e2832e896a"
 dependencies = [
  "itoa",
  "ryu",
@@ -753,20 +711,21 @@ dependencies = [
 
 [[package]]
 name = "tokenizers"
-version = "0.2.2"
+version = "0.2.3"
 dependencies = [
  "magnus",
+ "onig",
+ "serde",
  "tokenizers 0.13.2",
 ]
 
 [[package]]
 name = "tokenizers"
 version = "0.13.2"
-source = "git+https://github.com/huggingface/tokenizers#fe4ae7dc38be11a5c93ae703816c869f993c21ab"
+source = "git+https://github.com/huggingface/tokenizers#fa66caf0abff16bae2213658ffa3e969c5445750"
 dependencies = [
  "aho-corasick",
  "derive_builder",
- "dirs",
  "esaxx-rs",
  "getrandom",
  "indicatif",
@@ -807,9 +766,9 @@ dependencies = [
 
 [[package]]
 name = "unicode-segmentation"
-version = "1.10.0"
+version = "1.10.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fdbf052a0783de01e944a6ce7a8cb939e295b1e7be835a1112c3b9a7f047a5a"
+checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
 
 [[package]]
 name = "unicode-width"
data/README.md CHANGED
@@ -40,6 +40,10 @@ Load a tokenizer from files
 tokenizer = Tokenizers::CharBPETokenizer.new("vocab.json", "merges.txt")
 ```
 
+## Training
+
+Check out the [Quicktour](https://huggingface.co/docs/tokenizers/quicktour) and equivalent [Ruby code](https://github.com/ankane/tokenizers-ruby/blob/master/test/quicktour_test.rb#L8)
+
 ## History
 
 View the [changelog](https://github.com/ankane/tokenizers-ruby/blob/master/CHANGELOG.md)
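
As a companion to the new Training section, here is a hedged sketch of the training API added in this release; the class names follow the file list above and the Python quicktour, and the corpus file names are placeholders:

```ruby
require "tokenizers"

# Assumed API, mirroring the Python quicktour; verify against quicktour_test.rb.
tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::BPE.new(unk_token: "[UNK]"))
trainer = Tokenizers::Trainers::BpeTrainer.new(
  special_tokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
files = ["wiki.train.raw", "wiki.valid.raw", "wiki.test.raw"] # placeholder corpus
tokenizer.train(files, trainer)
tokenizer.save("tokenizer-wiki.json")
```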
data/ext/tokenizers/Cargo.toml CHANGED
@@ -1,6 +1,6 @@
 [package]
 name = "tokenizers"
-version = "0.2.2"
+version = "0.2.3"
 license = "Apache-2.0"
 authors = ["Andrew Kane <andrew@ankane.org>"]
 edition = "2021"
@@ -10,7 +10,9 @@ publish = false
 crate-type = ["cdylib"]
 
 [dependencies]
-magnus = "0.4"
+magnus = { git = "https://github.com/matsadler/magnus" }
+onig = { version = "6.0", default-features = false }
+serde = { version = "1.0", features = ["rc", "derive"] }
 
 [dependencies.tokenizers]
 version = "0.13.2" # also update in from_pretrained.rb
data/ext/tokenizers/src/decoders.rs CHANGED
@@ -1,14 +1,283 @@
+use std::sync::{Arc, RwLock};
+
+use magnus::typed_data::DataTypeBuilder;
+use magnus::{
+    function, memoize, method, Class, DataType, DataTypeFunctions, Module, Object, RClass, RModule,
+    TypedData,
+};
+use serde::{Deserialize, Serialize};
 use tk::decoders::bpe::BPEDecoder;
+use tk::decoders::byte_level::ByteLevel;
+use tk::decoders::ctc::CTC;
+use tk::decoders::metaspace::Metaspace;
+use tk::decoders::wordpiece::WordPiece;
+use tk::decoders::DecoderWrapper;
+use tk::Decoder;
+
+use super::RbResult;
+
+#[derive(DataTypeFunctions, Clone, Deserialize, Serialize)]
+pub struct RbDecoder {
+    #[serde(flatten)]
+    pub(crate) decoder: RbDecoderWrapper,
+}
+
+impl Decoder for RbDecoder {
+    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+        self.decoder.decode_chain(tokens)
+    }
+}
+
+macro_rules! getter {
+    ($self: ident, $variant: ident, $($name: tt)+) => {{
+        let decoder = &$self.decoder;
+        let RbDecoderWrapper::Wrapped(ref wrap) = decoder;
+        if let DecoderWrapper::$variant(ref dec) = *wrap.read().unwrap() {
+            dec.$($name)+
+        } else {
+            unreachable!()
+        }
+    }};
+}
 
-#[magnus::wrap(class = "Tokenizers::BPEDecoder")]
-pub struct RbBPEDecoder {
-    pub decoder: BPEDecoder,
+macro_rules! setter {
+    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
+        let decoder = &$self.decoder;
+        let RbDecoderWrapper::Wrapped(ref wrap) = decoder;
+        if let DecoderWrapper::$variant(ref mut dec) = *wrap.write().unwrap() {
+            dec.$name = $value;
+        }
+    }};
+    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
+        let decoder = &$self.decoder;
+        let RbDecoderWrapper::Wrapped(ref wrap) = decoder;
+        if let DecoderWrapper::$variant(ref mut dec) = *wrap.write().unwrap() {
+            dec.$name($value);
+        }
+    }};
 }
+impl RbDecoder {
+    pub fn bpe_suffix(&self) -> String {
+        getter!(self, BPE, suffix.clone())
+    }
+
+    pub fn bpe_set_suffix(&self, suffix: String) {
+        setter!(self, BPE, suffix, suffix);
+    }
+
+    pub fn ctc_cleanup(&self) -> bool {
+        getter!(self, CTC, cleanup)
+    }
+
+    pub fn ctc_set_cleanup(&self, cleanup: bool) {
+        setter!(self, CTC, cleanup, cleanup);
+    }
+
+    pub fn ctc_pad_token(&self) -> String {
+        getter!(self, CTC, pad_token.clone())
+    }
+
+    pub fn ctc_set_pad_token(&self, pad_token: String) {
+        setter!(self, CTC, pad_token, pad_token);
+    }
+
+    pub fn ctc_word_delimiter_token(&self) -> String {
+        getter!(self, CTC, word_delimiter_token.clone())
+    }
+
+    pub fn ctc_set_word_delimiter_token(&self, word_delimiter_token: String) {
+        setter!(self, CTC, word_delimiter_token, word_delimiter_token);
+    }
+
+    pub fn metaspace_replacement(&self) -> char {
+        getter!(self, Metaspace, get_replacement().clone())
+    }
+
+    pub fn metaspace_set_replacement(&self, replacement: char) {
+        setter!(self, Metaspace, @set_replacement, replacement);
+    }
+
+    pub fn metaspace_add_prefix_space(&self) -> bool {
+        getter!(self, Metaspace, add_prefix_space)
+    }
+
+    pub fn metaspace_set_add_prefix_space(&self, add_prefix_space: bool) {
+        setter!(self, Metaspace, add_prefix_space, add_prefix_space);
+    }
+
+    pub fn word_piece_cleanup(&self) -> bool {
+        getter!(self, WordPiece, cleanup)
+    }
+
+    pub fn word_piece_set_cleanup(&self, cleanup: bool) {
+        setter!(self, WordPiece, cleanup, cleanup);
+    }
+
+    pub fn word_piece_prefix(&self) -> String {
+        getter!(self, WordPiece, prefix.clone())
+    }
+
+    pub fn word_piece_set_prefix(&self, prefix: String) {
+        setter!(self, WordPiece, prefix, prefix);
+    }
+}
+
+pub struct RbBPEDecoder {}
 
 impl RbBPEDecoder {
-    pub fn new() -> Self {
-        RbBPEDecoder {
-            decoder: BPEDecoder::default(),
+    pub fn new(suffix: String) -> RbDecoder {
+        BPEDecoder::new(suffix).into()
+    }
+}
+
+pub struct RbByteLevelDecoder {}
+
+impl RbByteLevelDecoder {
+    pub fn new() -> RbDecoder {
+        ByteLevel::default().into()
+    }
+}
+
+pub struct RbCTC {}
+
+impl RbCTC {
+    pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> RbDecoder {
+        CTC::new(pad_token, word_delimiter_token, cleanup).into()
+    }
+}
+
+pub struct RbMetaspaceDecoder {}
+
+impl RbMetaspaceDecoder {
+    pub fn new(replacement: char, add_prefix_space: bool) -> RbDecoder {
+        Metaspace::new(replacement, add_prefix_space).into()
+    }
+}
+
+pub struct RbWordPieceDecoder {}
+
+impl RbWordPieceDecoder {
+    pub fn new(prefix: String, cleanup: bool) -> RbDecoder {
+        WordPiece::new(prefix, cleanup).into()
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+#[serde(untagged)]
+pub(crate) enum RbDecoderWrapper {
+    // Custom(Arc<RwLock<CustomDecoder>>),
+    Wrapped(Arc<RwLock<DecoderWrapper>>),
+}
+
+impl<I> From<I> for RbDecoderWrapper
+where
+    I: Into<DecoderWrapper>,
+{
+    fn from(norm: I) -> Self {
+        RbDecoderWrapper::Wrapped(Arc::new(RwLock::new(norm.into())))
+    }
+}
+
+impl<I> From<I> for RbDecoder
+where
+    I: Into<DecoderWrapper>,
+{
+    fn from(dec: I) -> Self {
+        RbDecoder {
+            decoder: dec.into().into(),
+        }
+    }
+}
+
+impl Decoder for RbDecoderWrapper {
+    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+        match self {
+            RbDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens),
+            // RbDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens),
+        }
+    }
+}
+
+unsafe impl TypedData for RbDecoder {
+    fn class() -> RClass {
+        *memoize!(RClass: {
+            let class: RClass = crate::decoders().const_get("Decoder").unwrap();
+            class.undef_alloc_func();
+            class
+        })
+    }
+
+    fn data_type() -> &'static DataType {
+        memoize!(DataType: DataTypeBuilder::<RbDecoder>::new("Tokenizers::Decoders::Decoder").build())
+    }
+
+    fn class_for(value: &Self) -> RClass {
+        match &value.decoder {
+            RbDecoderWrapper::Wrapped(inner) => match *inner.read().unwrap() {
+                DecoderWrapper::BPE(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("BPEDecoder").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                DecoderWrapper::ByteLevel(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("ByteLevel").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                DecoderWrapper::CTC(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("CTC").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                DecoderWrapper::Metaspace(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("Metaspace").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                DecoderWrapper::WordPiece(_) => *memoize!(RClass: {
+                    let class: RClass = crate::decoders().const_get("WordPiece").unwrap();
+                    class.undef_alloc_func();
+                    class
+                }),
+                _ => todo!(),
+            },
         }
     }
 }
+
+pub fn decoders(module: &RModule) -> RbResult<()> {
+    let decoder = module.define_class("Decoder", Default::default())?;
+
+    let class = module.define_class("BPEDecoder", decoder)?;
+    class.define_singleton_method("_new", function!(RbBPEDecoder::new, 1))?;
+    class.define_method("suffix", method!(RbDecoder::bpe_suffix, 0))?;
+    class.define_method("suffix=", method!(RbDecoder::bpe_set_suffix, 1))?;
+
+    let class = module.define_class("ByteLevel", decoder)?;
+    class.define_singleton_method("new", function!(RbByteLevelDecoder::new, 0))?;
+
+    let class = module.define_class("CTC", decoder)?;
+    class.define_singleton_method("_new", function!(RbCTC::new, 3))?;
+    class.define_method("cleanup", method!(RbDecoder::ctc_cleanup, 0))?;
+    class.define_method("cleanup=", method!(RbDecoder::ctc_set_cleanup, 1))?;
+    class.define_method("pad_token", method!(RbDecoder::ctc_pad_token, 0))?;
+    class.define_method("pad_token=", method!(RbDecoder::ctc_set_pad_token, 1))?;
+    class.define_method("word_delimiter_token", method!(RbDecoder::ctc_word_delimiter_token, 0))?;
+    class.define_method("word_delimiter_token=", method!(RbDecoder::ctc_set_word_delimiter_token, 1))?;
+
+    let class = module.define_class("Metaspace", decoder)?;
+    class.define_singleton_method("_new", function!(RbMetaspaceDecoder::new, 2))?;
+    class.define_method("add_prefix_space", method!(RbDecoder::metaspace_add_prefix_space, 0))?;
+    class.define_method("add_prefix_space=", method!(RbDecoder::metaspace_set_add_prefix_space, 1))?;
+    class.define_method("replacement", method!(RbDecoder::metaspace_replacement, 0))?;
+    class.define_method("replacement=", method!(RbDecoder::metaspace_set_replacement, 1))?;
+
+    let class = module.define_class("WordPiece", decoder)?;
+    class.define_singleton_method("_new", function!(RbWordPieceDecoder::new, 2))?;
+    class.define_method("cleanup", method!(RbDecoder::word_piece_cleanup, 0))?;
+    class.define_method("cleanup=", method!(RbDecoder::word_piece_set_cleanup, 1))?;
+    class.define_method("prefix", method!(RbDecoder::word_piece_prefix, 0))?;
+    class.define_method("prefix=", method!(RbDecoder::word_piece_set_prefix, 1))?;
+
+    Ok(())
+}
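
The `getter!`/`setter!` macros expose the wrapped decoder's fields as Ruby attributes, and `class_for` maps each `DecoderWrapper` variant back to its Ruby subclass, so a decoder round-trips as the right class. A hedged Ruby-side sketch of the resulting API (default argument values come from the Ruby wrappers under `data/lib/tokenizers/decoders/`, which are assumptions here):

```ruby
# Assumed defaults; the Rust _new entry points above take explicit arguments.
decoder = Tokenizers::Decoders::Metaspace.new
decoder.add_prefix_space        # => true
decoder.replacement = "_"       # dispatches to Metaspace#set_replacement in Rust

wp = Tokenizers::Decoders::WordPiece.new
wp.prefix                       # => "##"
wp.cleanup = false
```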
data/ext/tokenizers/src/encoding.rs CHANGED
@@ -1,16 +1,91 @@
-use tk::Encoding;
+use magnus::RArray;
+use tk::{Encoding, Offsets};
 
 #[magnus::wrap(class = "Tokenizers::Encoding")]
+#[repr(transparent)]
 pub struct RbEncoding {
     pub encoding: Encoding,
 }
 
+impl From<Encoding> for RbEncoding {
+    fn from(v: Encoding) -> Self {
+        Self { encoding: v }
+    }
+}
+
 impl RbEncoding {
+    pub fn n_sequences(&self) -> usize {
+        self.encoding.n_sequences()
+    }
+
     pub fn ids(&self) -> Vec<u32> {
-        self.encoding.get_ids().into()
+        self.encoding.get_ids().to_vec()
     }
 
     pub fn tokens(&self) -> Vec<String> {
-        self.encoding.get_tokens().into()
+        self.encoding.get_tokens().to_vec()
+    }
+
+    pub fn word_ids(&self) -> Vec<Option<u32>> {
+        self.encoding.get_word_ids().to_vec()
+    }
+
+    pub fn sequence_ids(&self) -> Vec<Option<usize>> {
+        self.encoding.get_sequence_ids()
+    }
+
+    pub fn type_ids(&self) -> Vec<u32> {
+        self.encoding.get_type_ids().to_vec()
+    }
+
+    pub fn offsets(&self) -> Vec<(usize, usize)> {
+        self.encoding.get_offsets().to_vec()
+    }
+
+    pub fn special_tokens_mask(&self) -> Vec<u32> {
+        self.encoding.get_special_tokens_mask().to_vec()
+    }
+
+    pub fn attention_mask(&self) -> Vec<u32> {
+        self.encoding.get_attention_mask().to_vec()
+    }
+
+    pub fn overflowing(&self) -> RArray {
+        self.encoding
+            .get_overflowing()
+            .clone()
+            .into_iter()
+            .map(Into::<RbEncoding>::into)
+            .collect()
+    }
+
+    pub fn word_to_tokens(&self, word_index: u32, sequence_index: usize) -> Option<(usize, usize)> {
+        self.encoding.word_to_tokens(word_index, sequence_index)
+    }
+
+    pub fn word_to_chars(&self, word_index: u32, sequence_index: usize) -> Option<Offsets> {
+        self.encoding.word_to_chars(word_index, sequence_index)
+    }
+
+    pub fn token_to_sequence(&self, token_index: usize) -> Option<usize> {
+        self.encoding.token_to_sequence(token_index)
+    }
+
+    pub fn token_to_chars(&self, token_index: usize) -> Option<Offsets> {
+        let (_, offsets) = self.encoding.token_to_chars(token_index)?;
+        Some(offsets)
+    }
+
+    pub fn token_to_word(&self, token_index: usize) -> Option<u32> {
+        let (_, word_idx) = self.encoding.token_to_word(token_index)?;
+        Some(word_idx)
+    }
+
+    pub fn char_to_token(&self, char_pos: usize, sequence_index: usize) -> Option<usize> {
+        self.encoding.char_to_token(char_pos, sequence_index)
+    }
+
+    pub fn char_to_word(&self, char_pos: usize, sequence_index: usize) -> Option<u32> {
+        self.encoding.char_to_word(char_pos, sequence_index)
     }
 }
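
These methods surface directly on `Tokenizers::Encoding`. A short hedged sketch of how the new accessors fit together; the input and return values are illustrative:

```ruby
encoding = tokenizer.encode("hello world")

encoding.n_sequences            # => 1
encoding.word_ids               # word index per token, nil for special tokens
encoding.offsets                # [start, end] character offsets per token

# The sequence index is explicit because the Rust methods above require it.
encoding.word_to_tokens(0, 0)   # token range covering word 0 of sequence 0
encoding.token_to_chars(1)      # character span of token 1
```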