tokenizers 0.5.5 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Cargo.lock +124 -53
- data/ext/tokenizers/Cargo.toml +4 -3
- data/ext/tokenizers/src/encoding.rs +10 -8
- data/ext/tokenizers/src/models.rs +37 -24
- data/ext/tokenizers/src/normalizers.rs +1 -2
- data/ext/tokenizers/src/pre_tokenizers.rs +5 -5
- data/ext/tokenizers/src/tokenizer.rs +61 -49
- data/ext/tokenizers/src/trainers.rs +60 -50
- data/ext/tokenizers/src/utils/normalization.rs +3 -2
- data/ext/tokenizers/src/utils/regex.rs +5 -4
- data/lib/tokenizers/from_pretrained.rb +2 -2
- data/lib/tokenizers/trainers/unigram_trainer.rb +10 -9
- data/lib/tokenizers/trainers/word_piece_trainer.rb +10 -9
- data/lib/tokenizers/version.rb +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 0f9e999132cb832793cc26a1cc80462332e71c5c3fc06a90bc652d88f5d850c0
|
4
|
+
data.tar.gz: 0705c9463baed06e0c2fc887514fa983e635f8bb99094bc6b5859ad6adbeaccb
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: eb18eb98388d80ad6043d3dac967e40f0c9d7dfaa9edd615dca49f1b42c8ca1acffe5096b3250c5a94f230c17bd660ba1c4c55a16e65294354734e20331017b6
|
7
|
+
data.tar.gz: 8ac036c36f7f43eb050d65bab50b03cb434bd7d02f7abb00b0056b5555b226a9db82bbde56bd3b8d7826491e53343342b273e0ebfa131575fc529d90ed163978
|
data/CHANGELOG.md
CHANGED
data/Cargo.lock
CHANGED
@@ -2,6 +2,20 @@
|
|
2
2
|
# It is not intended for manual editing.
|
3
3
|
version = 3
|
4
4
|
|
5
|
+
[[package]]
|
6
|
+
name = "ahash"
|
7
|
+
version = "0.8.12"
|
8
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
9
|
+
checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
|
10
|
+
dependencies = [
|
11
|
+
"cfg-if",
|
12
|
+
"getrandom",
|
13
|
+
"once_cell",
|
14
|
+
"serde",
|
15
|
+
"version_check",
|
16
|
+
"zerocopy",
|
17
|
+
]
|
18
|
+
|
5
19
|
[[package]]
|
6
20
|
name = "aho-corasick"
|
7
21
|
version = "1.1.3"
|
@@ -23,7 +37,7 @@ version = "0.69.5"
|
|
23
37
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
24
38
|
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
|
25
39
|
dependencies = [
|
26
|
-
"bitflags
|
40
|
+
"bitflags",
|
27
41
|
"cexpr",
|
28
42
|
"clang-sys",
|
29
43
|
"itertools 0.12.1",
|
@@ -37,12 +51,6 @@ dependencies = [
|
|
37
51
|
"syn",
|
38
52
|
]
|
39
53
|
|
40
|
-
[[package]]
|
41
|
-
name = "bitflags"
|
42
|
-
version = "1.3.2"
|
43
|
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
44
|
-
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
45
|
-
|
46
54
|
[[package]]
|
47
55
|
name = "bitflags"
|
48
56
|
version = "2.9.0"
|
@@ -55,6 +63,15 @@ version = "3.17.0"
|
|
55
63
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
56
64
|
checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
|
57
65
|
|
66
|
+
[[package]]
|
67
|
+
name = "castaway"
|
68
|
+
version = "0.2.3"
|
69
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
70
|
+
checksum = "0abae9be0aaf9ea96a3b1b8b1b55c602ca751eba1b1500220cea4ecbafe7c0d5"
|
71
|
+
dependencies = [
|
72
|
+
"rustversion",
|
73
|
+
]
|
74
|
+
|
58
75
|
[[package]]
|
59
76
|
name = "cc"
|
60
77
|
version = "1.2.21"
|
@@ -90,6 +107,21 @@ dependencies = [
|
|
90
107
|
"libloading",
|
91
108
|
]
|
92
109
|
|
110
|
+
[[package]]
|
111
|
+
name = "compact_str"
|
112
|
+
version = "0.9.0"
|
113
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
114
|
+
checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
|
115
|
+
dependencies = [
|
116
|
+
"castaway",
|
117
|
+
"cfg-if",
|
118
|
+
"itoa",
|
119
|
+
"rustversion",
|
120
|
+
"ryu",
|
121
|
+
"serde",
|
122
|
+
"static_assertions",
|
123
|
+
]
|
124
|
+
|
93
125
|
[[package]]
|
94
126
|
name = "console"
|
95
127
|
version = "0.15.11"
|
@@ -163,6 +195,15 @@ dependencies = [
|
|
163
195
|
"syn",
|
164
196
|
]
|
165
197
|
|
198
|
+
[[package]]
|
199
|
+
name = "dary_heap"
|
200
|
+
version = "0.3.7"
|
201
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
202
|
+
checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728"
|
203
|
+
dependencies = [
|
204
|
+
"serde",
|
205
|
+
]
|
206
|
+
|
166
207
|
[[package]]
|
167
208
|
name = "derive_builder"
|
168
209
|
version = "0.20.2"
|
@@ -223,12 +264,13 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
|
223
264
|
|
224
265
|
[[package]]
|
225
266
|
name = "getrandom"
|
226
|
-
version = "0.
|
267
|
+
version = "0.3.3"
|
227
268
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
228
|
-
checksum = "
|
269
|
+
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
|
229
270
|
dependencies = [
|
230
271
|
"cfg-if",
|
231
272
|
"libc",
|
273
|
+
"r-efi",
|
232
274
|
"wasi",
|
233
275
|
]
|
234
276
|
|
@@ -257,15 +299,6 @@ dependencies = [
|
|
257
299
|
"web-time",
|
258
300
|
]
|
259
301
|
|
260
|
-
[[package]]
|
261
|
-
name = "itertools"
|
262
|
-
version = "0.11.0"
|
263
|
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
264
|
-
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
|
265
|
-
dependencies = [
|
266
|
-
"either",
|
267
|
-
]
|
268
|
-
|
269
302
|
[[package]]
|
270
303
|
name = "itertools"
|
271
304
|
version = "0.12.1"
|
@@ -277,9 +310,9 @@ dependencies = [
|
|
277
310
|
|
278
311
|
[[package]]
|
279
312
|
name = "itertools"
|
280
|
-
version = "0.
|
313
|
+
version = "0.14.0"
|
281
314
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
282
|
-
checksum = "
|
315
|
+
checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
|
283
316
|
dependencies = [
|
284
317
|
"either",
|
285
318
|
]
|
@@ -352,9 +385,9 @@ checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568"
|
|
352
385
|
|
353
386
|
[[package]]
|
354
387
|
name = "magnus"
|
355
|
-
version = "0.
|
388
|
+
version = "0.8.0"
|
356
389
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
357
|
-
checksum = "
|
390
|
+
checksum = "3f14d3cc31b2dc4fce6cd447a83c7a7ca2ab8a9f1e535dcb2f796ff972b0e68b"
|
358
391
|
dependencies = [
|
359
392
|
"magnus-macros",
|
360
393
|
"rb-sys",
|
@@ -364,9 +397,9 @@ dependencies = [
|
|
364
397
|
|
365
398
|
[[package]]
|
366
399
|
name = "magnus-macros"
|
367
|
-
version = "0.
|
400
|
+
version = "0.8.0"
|
368
401
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
369
|
-
checksum = "
|
402
|
+
checksum = "47607461fd8e1513cb4f2076c197d8092d921a1ea75bd08af97398f593751892"
|
370
403
|
dependencies = [
|
371
404
|
"proc-macro2",
|
372
405
|
"quote",
|
@@ -430,11 +463,11 @@ checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
|
430
463
|
|
431
464
|
[[package]]
|
432
465
|
name = "onig"
|
433
|
-
version = "6.
|
466
|
+
version = "6.5.1"
|
434
467
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
435
|
-
checksum = "
|
468
|
+
checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0"
|
436
469
|
dependencies = [
|
437
|
-
"bitflags
|
470
|
+
"bitflags",
|
438
471
|
"libc",
|
439
472
|
"once_cell",
|
440
473
|
"onig_sys",
|
@@ -442,9 +475,9 @@ dependencies = [
|
|
442
475
|
|
443
476
|
[[package]]
|
444
477
|
name = "onig_sys"
|
445
|
-
version = "69.
|
478
|
+
version = "69.9.1"
|
446
479
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
447
|
-
checksum = "
|
480
|
+
checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc"
|
448
481
|
dependencies = [
|
449
482
|
"cc",
|
450
483
|
"pkg-config",
|
@@ -495,22 +528,27 @@ dependencies = [
|
|
495
528
|
"proc-macro2",
|
496
529
|
]
|
497
530
|
|
531
|
+
[[package]]
|
532
|
+
name = "r-efi"
|
533
|
+
version = "5.3.0"
|
534
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
535
|
+
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
536
|
+
|
498
537
|
[[package]]
|
499
538
|
name = "rand"
|
500
|
-
version = "0.
|
539
|
+
version = "0.9.1"
|
501
540
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
502
|
-
checksum = "
|
541
|
+
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
|
503
542
|
dependencies = [
|
504
|
-
"libc",
|
505
543
|
"rand_chacha",
|
506
544
|
"rand_core",
|
507
545
|
]
|
508
546
|
|
509
547
|
[[package]]
|
510
548
|
name = "rand_chacha"
|
511
|
-
version = "0.
|
549
|
+
version = "0.9.0"
|
512
550
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
513
|
-
checksum = "
|
551
|
+
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
514
552
|
dependencies = [
|
515
553
|
"ppv-lite86",
|
516
554
|
"rand_core",
|
@@ -518,9 +556,9 @@ dependencies = [
|
|
518
556
|
|
519
557
|
[[package]]
|
520
558
|
name = "rand_core"
|
521
|
-
version = "0.
|
559
|
+
version = "0.9.3"
|
522
560
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
523
|
-
checksum = "
|
561
|
+
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
|
524
562
|
dependencies = [
|
525
563
|
"getrandom",
|
526
564
|
]
|
@@ -537,12 +575,12 @@ dependencies = [
|
|
537
575
|
|
538
576
|
[[package]]
|
539
577
|
name = "rayon-cond"
|
540
|
-
version = "0.
|
578
|
+
version = "0.4.0"
|
541
579
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
542
|
-
checksum = "
|
580
|
+
checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f"
|
543
581
|
dependencies = [
|
544
582
|
"either",
|
545
|
-
"itertools 0.
|
583
|
+
"itertools 0.14.0",
|
546
584
|
"rayon",
|
547
585
|
]
|
548
586
|
|
@@ -558,18 +596,18 @@ dependencies = [
|
|
558
596
|
|
559
597
|
[[package]]
|
560
598
|
name = "rb-sys"
|
561
|
-
version = "0.9.
|
599
|
+
version = "0.9.117"
|
562
600
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
563
|
-
checksum = "
|
601
|
+
checksum = "f900d1ce4629a2ebffaf5de74bd8f9c1188d4c5ed406df02f97e22f77a006f44"
|
564
602
|
dependencies = [
|
565
603
|
"rb-sys-build",
|
566
604
|
]
|
567
605
|
|
568
606
|
[[package]]
|
569
607
|
name = "rb-sys-build"
|
570
|
-
version = "0.9.
|
608
|
+
version = "0.9.117"
|
571
609
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
572
|
-
checksum = "
|
610
|
+
checksum = "ef1e9c857028f631056bcd6d88cec390c751e343ce2223ddb26d23eb4a151d59"
|
573
611
|
dependencies = [
|
574
612
|
"bindgen",
|
575
613
|
"lazy_static",
|
@@ -582,9 +620,9 @@ dependencies = [
|
|
582
620
|
|
583
621
|
[[package]]
|
584
622
|
name = "rb-sys-env"
|
585
|
-
version = "0.
|
623
|
+
version = "0.2.2"
|
586
624
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
587
|
-
checksum = "
|
625
|
+
checksum = "08f8d2924cf136a1315e2b4c7460a39f62ef11ee5d522df9b2750fab55b868b6"
|
588
626
|
|
589
627
|
[[package]]
|
590
628
|
name = "regex"
|
@@ -621,6 +659,12 @@ version = "1.1.0"
|
|
621
659
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
622
660
|
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
623
661
|
|
662
|
+
[[package]]
|
663
|
+
name = "rustversion"
|
664
|
+
version = "1.0.21"
|
665
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
666
|
+
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
|
667
|
+
|
624
668
|
[[package]]
|
625
669
|
name = "ryu"
|
626
670
|
version = "1.0.20"
|
@@ -695,6 +739,12 @@ dependencies = [
|
|
695
739
|
"unicode-segmentation",
|
696
740
|
]
|
697
741
|
|
742
|
+
[[package]]
|
743
|
+
name = "static_assertions"
|
744
|
+
version = "1.1.0"
|
745
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
746
|
+
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
747
|
+
|
698
748
|
[[package]]
|
699
749
|
name = "strsim"
|
700
750
|
version = "0.11.1"
|
@@ -734,27 +784,30 @@ dependencies = [
|
|
734
784
|
|
735
785
|
[[package]]
|
736
786
|
name = "tokenizers"
|
737
|
-
version = "0.
|
787
|
+
version = "0.6.0"
|
738
788
|
dependencies = [
|
789
|
+
"ahash",
|
739
790
|
"magnus",
|
740
791
|
"onig",
|
741
792
|
"serde",
|
742
|
-
"tokenizers 0.
|
793
|
+
"tokenizers 0.22.0",
|
743
794
|
]
|
744
795
|
|
745
796
|
[[package]]
|
746
797
|
name = "tokenizers"
|
747
|
-
version = "0.
|
798
|
+
version = "0.22.0"
|
748
799
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
749
|
-
checksum = "
|
800
|
+
checksum = "af10f51be57162b69d90a15cb226eef12c9e4faecbd5e3ea98a86bfb920b3d71"
|
750
801
|
dependencies = [
|
802
|
+
"ahash",
|
751
803
|
"aho-corasick",
|
804
|
+
"compact_str",
|
805
|
+
"dary_heap",
|
752
806
|
"derive_builder",
|
753
807
|
"esaxx-rs",
|
754
808
|
"getrandom",
|
755
809
|
"indicatif",
|
756
|
-
"itertools 0.
|
757
|
-
"lazy_static",
|
810
|
+
"itertools 0.14.0",
|
758
811
|
"log",
|
759
812
|
"macro_rules_attribute",
|
760
813
|
"monostate",
|
@@ -807,11 +860,20 @@ version = "0.1.1"
|
|
807
860
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
808
861
|
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
809
862
|
|
863
|
+
[[package]]
|
864
|
+
name = "version_check"
|
865
|
+
version = "0.9.5"
|
866
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
867
|
+
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
868
|
+
|
810
869
|
[[package]]
|
811
870
|
name = "wasi"
|
812
|
-
version = "0.
|
871
|
+
version = "0.14.2+wasi-0.2.4"
|
813
872
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
814
|
-
checksum = "
|
873
|
+
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
|
874
|
+
dependencies = [
|
875
|
+
"wit-bindgen-rt",
|
876
|
+
]
|
815
877
|
|
816
878
|
[[package]]
|
817
879
|
name = "wasm-bindgen"
|
@@ -953,6 +1015,15 @@ version = "0.52.6"
|
|
953
1015
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
954
1016
|
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
955
1017
|
|
1018
|
+
[[package]]
|
1019
|
+
name = "wit-bindgen-rt"
|
1020
|
+
version = "0.39.0"
|
1021
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
1022
|
+
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
|
1023
|
+
dependencies = [
|
1024
|
+
"bitflags",
|
1025
|
+
]
|
1026
|
+
|
956
1027
|
[[package]]
|
957
1028
|
name = "zerocopy"
|
958
1029
|
version = "0.8.25"
|
data/ext/tokenizers/Cargo.toml
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
[package]
|
2
2
|
name = "tokenizers"
|
3
|
-
version = "0.
|
3
|
+
version = "0.6.0"
|
4
4
|
license = "Apache-2.0"
|
5
5
|
authors = ["Andrew Kane <andrew@ankane.org>"]
|
6
6
|
edition = "2021"
|
@@ -11,11 +11,12 @@ publish = false
|
|
11
11
|
crate-type = ["cdylib"]
|
12
12
|
|
13
13
|
[dependencies]
|
14
|
-
|
14
|
+
ahash = { version = "0.8.11", features = ["serde"] }
|
15
|
+
magnus = "0.8"
|
15
16
|
onig = { version = "6", default-features = false }
|
16
17
|
serde = { version = "1", features = ["rc", "derive"] }
|
17
18
|
|
18
19
|
[dependencies.tokenizers]
|
19
|
-
version = "=0.
|
20
|
+
version = "=0.22.0" # also update in from_pretrained.rb
|
20
21
|
default-features = false
|
21
22
|
features = ["progressbar", "onig", "esaxx_fast"]
|
data/ext/tokenizers/src/encoding.rs
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
use magnus::RArray;
|
1
|
+
use magnus::{RArray, Ruby};
|
2
2
|
use tk::{Encoding, Offsets};
|
3
3
|
|
4
4
|
#[magnus::wrap(class = "Tokenizers::Encoding")]
|
@@ -50,13 +50,15 @@ impl RbEncoding {
|
|
50
50
|
self.encoding.get_attention_mask().to_vec()
|
51
51
|
}
|
52
52
|
|
53
|
-
pub fn overflowing(&
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
53
|
+
pub fn overflowing(ruby: &Ruby, rb_self: &Self) -> RArray {
|
54
|
+
ruby.ary_from_iter(
|
55
|
+
rb_self
|
56
|
+
.encoding
|
57
|
+
.get_overflowing()
|
58
|
+
.clone()
|
59
|
+
.into_iter()
|
60
|
+
.map(Into::<RbEncoding>::into),
|
61
|
+
)
|
60
62
|
}
|
61
63
|
|
62
64
|
pub fn word_to_tokens(&self, word_index: u32, sequence_index: usize) -> Option<(usize, usize)> {
|
data/ext/tokenizers/src/models.rs
CHANGED
@@ -3,14 +3,14 @@ use std::path::{Path, PathBuf};
|
|
3
3
|
use std::sync::{Arc, RwLock};
|
4
4
|
|
5
5
|
use crate::trainers::RbTrainer;
|
6
|
+
use ahash::AHashMap;
|
6
7
|
use magnus::prelude::*;
|
7
8
|
use magnus::{
|
8
|
-
data_type_builder,
|
9
|
-
|
10
|
-
TypedData, Value,
|
9
|
+
data_type_builder, function, method, value::Lazy, Class, DataType, DataTypeFunctions, Error,
|
10
|
+
Module, Object, RClass, RHash, RModule, Ruby, TryConvert, TypedData, Value,
|
11
11
|
};
|
12
12
|
use serde::{Deserialize, Serialize};
|
13
|
-
use tk::models::bpe::{BpeBuilder, Merges,
|
13
|
+
use tk::models::bpe::{BpeBuilder, Merges, BPE};
|
14
14
|
use tk::models::unigram::Unigram;
|
15
15
|
use tk::models::wordlevel::WordLevel;
|
16
16
|
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
|
@@ -72,52 +72,59 @@ pub struct RbBPE {}
|
|
72
72
|
|
73
73
|
impl RbBPE {
|
74
74
|
fn with_builder(mut builder: BpeBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
75
|
-
let
|
75
|
+
let ruby = Ruby::get().unwrap();
|
76
|
+
|
77
|
+
let value: Value = kwargs.delete(ruby.to_symbol("cache_capacity"))?;
|
76
78
|
if !value.is_nil() {
|
77
79
|
builder = builder.cache_capacity(TryConvert::try_convert(value)?);
|
78
80
|
}
|
79
81
|
|
80
|
-
let value: Value = kwargs.delete(
|
82
|
+
let value: Value = kwargs.delete(ruby.to_symbol("dropout"))?;
|
81
83
|
if !value.is_nil() {
|
82
84
|
builder = builder.dropout(TryConvert::try_convert(value)?);
|
83
85
|
}
|
84
86
|
|
85
|
-
let value: Value = kwargs.delete(
|
87
|
+
let value: Value = kwargs.delete(ruby.to_symbol("unk_token"))?;
|
86
88
|
if !value.is_nil() {
|
87
89
|
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
88
90
|
}
|
89
91
|
|
90
|
-
let value: Value = kwargs.delete(
|
92
|
+
let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
|
91
93
|
if !value.is_nil() {
|
92
94
|
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
93
95
|
}
|
94
96
|
|
95
|
-
let value: Value = kwargs.delete(
|
97
|
+
let value: Value = kwargs.delete(ruby.to_symbol("end_of_word_suffix"))?;
|
96
98
|
if !value.is_nil() {
|
97
99
|
builder = builder.end_of_word_suffix(TryConvert::try_convert(value)?);
|
98
100
|
}
|
99
101
|
|
100
|
-
let value: Value = kwargs.delete(
|
102
|
+
let value: Value = kwargs.delete(ruby.to_symbol("fuse_unk"))?;
|
101
103
|
if !value.is_nil() {
|
102
104
|
builder = builder.fuse_unk(TryConvert::try_convert(value)?);
|
103
105
|
}
|
104
106
|
|
105
|
-
let value: Value = kwargs.delete(
|
107
|
+
let value: Value = kwargs.delete(ruby.to_symbol("byte_fallback"))?;
|
106
108
|
if !value.is_nil() {
|
107
109
|
builder = builder.byte_fallback(TryConvert::try_convert(value)?);
|
108
110
|
}
|
109
111
|
|
110
112
|
if !kwargs.is_empty() {
|
111
113
|
// TODO improve message
|
112
|
-
return Err(Error::new(
|
114
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
113
115
|
}
|
114
116
|
|
115
117
|
builder.build().map(|v| v.into()).map_err(RbError::from)
|
116
118
|
}
|
117
119
|
|
118
|
-
pub fn new(
|
120
|
+
pub fn new(
|
121
|
+
vocab: Option<HashMap<String, u32>>,
|
122
|
+
merges: Option<Merges>,
|
123
|
+
kwargs: RHash,
|
124
|
+
) -> RbResult<RbModel> {
|
119
125
|
let mut builder = BPE::builder();
|
120
126
|
if let (Some(vocab), Some(merges)) = (vocab, merges) {
|
127
|
+
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
|
121
128
|
builder = builder.vocab_and_merges(vocab, merges);
|
122
129
|
}
|
123
130
|
RbBPE::with_builder(builder, kwargs)
|
@@ -125,7 +132,7 @@ impl RbBPE {
|
|
125
132
|
|
126
133
|
pub fn from_file(vocab: String, merges: String, kwargs: RHash) -> RbResult<RbModel> {
|
127
134
|
let (vocab, merges) = BPE::read_file(&vocab, &merges).map_err(RbError::from)?;
|
128
|
-
|
135
|
+
let vocab = vocab.into_iter().collect();
|
129
136
|
RbBPE::new(Some(vocab), Some(merges), kwargs)
|
130
137
|
}
|
131
138
|
}
|
@@ -251,6 +258,7 @@ pub struct RbUnigram {}
|
|
251
258
|
|
252
259
|
impl RbUnigram {
|
253
260
|
fn new(
|
261
|
+
ruby: &Ruby,
|
254
262
|
vocab: Option<Vec<(String, f64)>>,
|
255
263
|
unk_id: Option<usize>,
|
256
264
|
byte_fallback: Option<bool>,
|
@@ -263,7 +271,7 @@ impl RbUnigram {
|
|
263
271
|
}
|
264
272
|
(None, None, _) => Ok(Unigram::default().into()),
|
265
273
|
_ => Err(Error::new(
|
266
|
-
|
274
|
+
ruby.exception_arg_error(),
|
267
275
|
"`vocab` and `unk_id` must be both specified",
|
268
276
|
)),
|
269
277
|
}
|
@@ -279,7 +287,7 @@ impl RbWordLevel {
|
|
279
287
|
) -> RbResult<RbModel> {
|
280
288
|
let mut builder = WordLevel::builder();
|
281
289
|
if let Some(vocab) = vocab {
|
282
|
-
builder = builder.vocab(vocab);
|
290
|
+
builder = builder.vocab(vocab.into_iter().collect());
|
283
291
|
}
|
284
292
|
if let Some(unk_token) = unk_token {
|
285
293
|
builder = builder.unk_token(unk_token);
|
@@ -287,13 +295,15 @@ impl RbWordLevel {
|
|
287
295
|
builder.build().map(|v| v.into()).map_err(RbError::from)
|
288
296
|
}
|
289
297
|
|
290
|
-
pub fn read_file(vocab: String) -> RbResult<
|
291
|
-
WordLevel::read_file(&vocab).map_err(RbError::from)
|
298
|
+
pub fn read_file(vocab: String) -> RbResult<HashMap<String, u32>> {
|
299
|
+
let vocab = WordLevel::read_file(&vocab).map_err(RbError::from)?;
|
300
|
+
let vocab: HashMap<_, _> = vocab.into_iter().collect();
|
301
|
+
Ok(vocab)
|
292
302
|
}
|
293
303
|
|
294
304
|
pub fn from_file(vocab: String, unk_token: Option<String>) -> RbResult<RbModel> {
|
295
305
|
let vocab = WordLevel::read_file(&vocab).map_err(RbError::from)?;
|
296
|
-
|
306
|
+
let vocab = vocab.into_iter().collect();
|
297
307
|
RbWordLevel::new(Some(vocab), unk_token)
|
298
308
|
}
|
299
309
|
}
|
@@ -302,24 +312,26 @@ pub struct RbWordPiece {}
|
|
302
312
|
|
303
313
|
impl RbWordPiece {
|
304
314
|
fn with_builder(mut builder: WordPieceBuilder, kwargs: RHash) -> RbResult<RbModel> {
|
305
|
-
let
|
315
|
+
let ruby = Ruby::get().unwrap();
|
316
|
+
|
317
|
+
let value: Value = kwargs.delete(ruby.to_symbol("unk_token"))?;
|
306
318
|
if !value.is_nil() {
|
307
319
|
builder = builder.unk_token(TryConvert::try_convert(value)?);
|
308
320
|
}
|
309
321
|
|
310
|
-
let value: Value = kwargs.delete(
|
322
|
+
let value: Value = kwargs.delete(ruby.to_symbol("max_input_chars_per_word"))?;
|
311
323
|
if !value.is_nil() {
|
312
324
|
builder = builder.max_input_chars_per_word(TryConvert::try_convert(value)?);
|
313
325
|
}
|
314
326
|
|
315
|
-
let value: Value = kwargs.delete(
|
327
|
+
let value: Value = kwargs.delete(ruby.to_symbol("continuing_subword_prefix"))?;
|
316
328
|
if !value.is_nil() {
|
317
329
|
builder = builder.continuing_subword_prefix(TryConvert::try_convert(value)?);
|
318
330
|
}
|
319
331
|
|
320
332
|
if !kwargs.is_empty() {
|
321
333
|
// TODO improve message
|
322
|
-
return Err(Error::new(
|
334
|
+
return Err(Error::new(ruby.exception_arg_error(), "unknown keyword"));
|
323
335
|
}
|
324
336
|
|
325
337
|
builder.build().map(|v| v.into()).map_err(RbError::from)
|
@@ -328,6 +340,7 @@ impl RbWordPiece {
|
|
328
340
|
pub fn new(vocab: Option<HashMap<String, u32>>, kwargs: RHash) -> RbResult<RbModel> {
|
329
341
|
let mut builder = WordPiece::builder();
|
330
342
|
if let Some(vocab) = vocab {
|
343
|
+
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
|
331
344
|
builder = builder.vocab(vocab);
|
332
345
|
}
|
333
346
|
RbWordPiece::with_builder(builder, kwargs)
|
@@ -336,7 +349,7 @@ impl RbWordPiece {
|
|
336
349
|
pub fn from_file(vocab: String, kwargs: RHash) -> RbResult<RbModel> {
|
337
350
|
let vocab = WordPiece::read_file(&vocab).map_err(RbError::from)?;
|
338
351
|
|
339
|
-
RbWordPiece::new(Some(vocab), kwargs)
|
352
|
+
RbWordPiece::new(Some(vocab.into_iter().collect()), kwargs)
|
340
353
|
}
|
341
354
|
}
|
342
355
|
|
data/ext/tokenizers/src/normalizers.rs
CHANGED
@@ -199,8 +199,7 @@ impl RbPrecompiled {
|
|
199
199
|
Precompiled::from(&precompiled_charsmap)
|
200
200
|
.map_err(|e| {
|
201
201
|
RbError::new_err(format!(
|
202
|
-
"Error while attempting to build Precompiled normalizer: {}"
|
203
|
-
e
|
202
|
+
"Error while attempting to build Precompiled normalizer: {e}"
|
204
203
|
))
|
205
204
|
})
|
206
205
|
.map(|v| v.into())
|