youtokentome 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/ext/youtokentome/ext.cpp +135 -0
- data/ext/youtokentome/extconf.rb +12 -0
- data/lib/youtokentome.rb +10 -0
- data/lib/youtokentome/bpe.rb +54 -0
- data/lib/youtokentome/ext.bundle +0 -0
- data/lib/youtokentome/version.rb +3 -0
- data/vendor/YouTokenToMe/LICENSE +19 -0
- data/vendor/YouTokenToMe/README.md +304 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/bpe.cpp +2185 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/bpe.h +86 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/third_party/LICENSE +23 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/third_party/flat_hash_map.h +1502 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/utf8.cpp +134 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/utf8.h +23 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/utils.cpp +119 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/utils.h +105 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/yttm.pyx +182 -0
- metadata +133 -0
@@ -0,0 +1,134 @@
|
|
1
|
+
#include "utf8.h"
|
2
|
+
#include <cassert>
|
3
|
+
#include <iostream>
|
4
|
+
#include <string>
|
5
|
+
#include <vector>
|
6
|
+
#include "utils.h"
|
7
|
+
|
8
|
+
namespace vkcom {
|
9
|
+
|
10
|
+
using std::string;
|
11
|
+
using std::vector;
|
12
|
+
|
13
|
+
|
14
|
+
// Returns true when `x` is a UTF-8 continuation byte, i.e. its two top bits
// are `10` (bit pattern 10xxxxxx).
bool check_byte(char x) {
  const uint8_t byte = static_cast<uint8_t>(x);
  return (byte >> 6u) == 0x2u;
}
|
15
|
+
|
16
|
+
// Returns true when `x` is below 0x110000 and lies outside the surrogate
// range 0xD800..0xDFFF.
bool check_codepoint(uint32_t x) {
  if (x >= 0x110000) {
    return false;
  }
  const bool is_surrogate = (0xd800 <= x) && (x <= 0xdfff);
  return !is_surrogate;
}
|
19
|
+
|
20
|
+
// Maps the first byte of a UTF-8 sequence to the total number of bytes that
// sequence is expected to occupy (1..4). Returns 0 when `ch` cannot start a
// sequence (a continuation byte, or a byte of the form 11111xxx).
uint64_t utf_length(char ch) {
  const uint8_t lead = static_cast<uint8_t>(ch);
  if (lead < 0x80u) {
    return 1;  // 0xxxxxxx
  }
  if ((lead >> 5u) == 0x06u) {
    return 2;  // 110xxxxx
  }
  if ((lead >> 4u) == 0x0eu) {
    return 3;  // 1110xxxx
  }
  if ((lead >> 3u) == 0x1eu) {
    return 4;  // 11110xxx
  }
  // Invalid utf-8
  return 0;
}
|
36
|
+
|
37
|
+
uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) {
|
38
|
+
uint64_t length = utf_length(begin[0]);
|
39
|
+
if (length == 1) {
|
40
|
+
*utf8_len = 1;
|
41
|
+
return static_cast<uint8_t>(begin[0]);
|
42
|
+
}
|
43
|
+
uint32_t code_point = 0;
|
44
|
+
if (size >= 2 && length == 2 && check_byte(begin[1])) {
|
45
|
+
code_point += (static_cast<uint8_t>(begin[0]) & 0x1fu) << 6u;
|
46
|
+
code_point += (static_cast<uint8_t>(begin[1]) & 0x3fu);
|
47
|
+
if (code_point >= 0x0080 && check_codepoint(code_point)) {
|
48
|
+
*utf8_len = 2;
|
49
|
+
return code_point;
|
50
|
+
}
|
51
|
+
} else if (size >= 3 && length == 3 && check_byte(begin[1]) &&
|
52
|
+
check_byte(begin[2])) {
|
53
|
+
code_point += (static_cast<uint8_t>(begin[0]) & 0x0fu) << 12u;
|
54
|
+
code_point += (static_cast<uint8_t>(begin[1]) & 0x3fu) << 6u;
|
55
|
+
code_point += (static_cast<uint8_t>(begin[2]) & 0x3fu);
|
56
|
+
if (code_point >= 0x0800 && check_codepoint(code_point)) {
|
57
|
+
*utf8_len = 3;
|
58
|
+
return code_point;
|
59
|
+
}
|
60
|
+
} else if (size >= 4 && length == 4 && check_byte(begin[1]) &&
|
61
|
+
check_byte(begin[2]) && check_byte(begin[3])) {
|
62
|
+
code_point += (static_cast<uint8_t>(begin[0]) & 0x07u) << 18u;
|
63
|
+
code_point += (static_cast<uint8_t>(begin[1]) & 0x3fu) << 12u;
|
64
|
+
code_point += (static_cast<uint8_t>(begin[2]) & 0x3fu) << 6u;
|
65
|
+
code_point += (static_cast<uint8_t>(begin[3]) & 0x3fu);
|
66
|
+
if (code_point >= 0x10000 && check_codepoint(code_point)) {
|
67
|
+
*utf8_len = 4;
|
68
|
+
return code_point;
|
69
|
+
}
|
70
|
+
}
|
71
|
+
// Invalid utf-8
|
72
|
+
*utf8_len = 1;
|
73
|
+
return INVALID_UNICODE;
|
74
|
+
}
|
75
|
+
|
76
|
+
// Appends the UTF-8 encoding of code point `x` (1 to 4 bytes) through `it`.
// `x` is expected to satisfy check_codepoint; this is only verified by an
// assert.
void utf8_to_chars(uint32_t x, std::back_insert_iterator<string> it) {
  assert(check_codepoint(x));

  if (x <= 0x7f) {
    *it++ = static_cast<char>(x);
    return;
  }

  // Pick the lead-byte marker and the number of continuation bytes.
  uint32_t lead_marker = 0xf0u;
  int n_continuation = 3;
  if (x <= 0x7ff) {
    lead_marker = 0xc0u;
    n_continuation = 1;
  } else if (x <= 0xffff) {
    lead_marker = 0xe0u;
    n_continuation = 2;
  }

  *it++ = static_cast<char>(lead_marker | (x >> (6u * n_continuation)));
  // Emit the remaining payload six bits at a time, most significant first.
  for (int shift = 6 * (n_continuation - 1); shift >= 0; shift -= 6) {
    *it++ = static_cast<char>(0x80u | ((x >> shift) & 0x3fu));
  }
}
|
102
|
+
|
103
|
+
string encode_utf8(const vector<uint32_t>& text) {
|
104
|
+
string utf8_text;
|
105
|
+
for (const uint32_t c : text) {
|
106
|
+
utf8_to_chars(c, std::back_inserter(utf8_text));
|
107
|
+
}
|
108
|
+
return utf8_text;
|
109
|
+
}
|
110
|
+
|
111
|
+
vector<uint32_t> decode_utf8(const char* begin, const char* end) {
|
112
|
+
vector<uint32_t> decoded_text;
|
113
|
+
uint64_t utf8_len = 0;
|
114
|
+
bool invalid_input = false;
|
115
|
+
for (; begin < end; begin += utf8_len) {
|
116
|
+
uint32_t code_point = chars_to_utf8(begin, end - begin, &utf8_len);
|
117
|
+
if (code_point != INVALID_UNICODE) {
|
118
|
+
decoded_text.push_back(code_point);
|
119
|
+
} else {
|
120
|
+
invalid_input = true;
|
121
|
+
}
|
122
|
+
}
|
123
|
+
if (invalid_input) {
|
124
|
+
std::cerr << "WARNING Input contains invalid unicode characters."
|
125
|
+
<< std::endl;
|
126
|
+
}
|
127
|
+
return decoded_text;
|
128
|
+
}
|
129
|
+
|
130
|
+
vector<uint32_t> decode_utf8(const string& utf8_text) {
|
131
|
+
return decode_utf8(utf8_text.data(), utf8_text.data() + utf8_text.size());
|
132
|
+
}
|
133
|
+
|
134
|
+
} // namespace vkcom
|
@@ -0,0 +1,23 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include "utils.h"
|
4
|
+
|
5
|
+
namespace vkcom {

// Sentinel returned by chars_to_utf8 when the input bytes do not form a
// valid sequence.
static constexpr uint32_t INVALID_UNICODE = 0x0fffffff;

// Decodes one code point from at most `size` bytes at `begin`. Stores the
// number of bytes to advance in `*utf8_len` and returns the code point, or
// INVALID_UNICODE on malformed input.
uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len);

// Encodes a sequence of code points as a UTF-8 string.
std::string encode_utf8(const std::vector<uint32_t> &utext);

// Decodes the byte range [begin, end) into code points.
std::vector<uint32_t> decode_utf8(const char *begin, const char *end);

// Decodes a whole UTF-8 string into code points.
std::vector<uint32_t> decode_utf8(const std::string &utf8_text);

}  // namespace vkcom
|
21
|
+
|
22
|
+
|
23
|
+
|
@@ -0,0 +1,119 @@
|
|
1
|
+
#include "utils.h"
|
2
|
+
#include <cassert>
|
3
|
+
#include <fstream>
|
4
|
+
#include <iostream>
|
5
|
+
#include <string>
|
6
|
+
#include <vector>
|
7
|
+
|
8
|
+
namespace vkcom {
|
9
|
+
using std::string;
|
10
|
+
using std::vector;
|
11
|
+
|
12
|
+
void SpecialTokens::dump(std::ofstream &fout) {
|
13
|
+
fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
|
14
|
+
<< std::endl;
|
15
|
+
}
|
16
|
+
|
17
|
+
void SpecialTokens::load(std::ifstream &fin) {
|
18
|
+
fin >> unk_id >> pad_id >> bos_id >> eos_id;
|
19
|
+
}
|
20
|
+
|
21
|
+
// Returns the largest of the four special-token ids, or 0 when none of them
// is positive.
uint32_t SpecialTokens::max_id() const {
  const int unk_or_pad = std::max(unk_id, pad_id);
  const int bos_or_eos = std::max(bos_id, eos_id);
  const int largest = std::max(unk_or_pad, bos_or_eos);
  return std::max(0, largest);
}
|
29
|
+
|
30
|
+
bool SpecialTokens::taken_id(int id) const {
|
31
|
+
return id == unk_id || id == pad_id || id == bos_id || id == eos_id;
|
32
|
+
}
|
33
|
+
|
34
|
+
// Counts how many of the four special tokens are enabled, i.e. have an id
// other than -1.
uint64_t SpecialTokens::n_special_tokens() const {
  const int ids[] = {unk_id, pad_id, bos_id, eos_id};
  uint64_t enabled = 0;
  for (const int id : ids) {
    if (id != -1) {
      ++enabled;
    }
  }
  return enabled;
}
|
42
|
+
|
43
|
+
SpecialTokens::SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id)
|
44
|
+
: pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {}
|
45
|
+
|
46
|
+
// Two rules are equal when all three components match.
bool BPE_Rule::operator==(const BPE_Rule &other) const {
  if (x != other.x) {
    return false;
  }
  if (y != other.y) {
    return false;
  }
  return z == other.z;
}
|
49
|
+
|
50
|
+
BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}
|
51
|
+
|
52
|
+
// Writes the state to `file_name` as whitespace-separated text:
//   line 1:  <char2id.size()> <rules.size()>
//   then one "<key> <value>" line per char2id entry,
//   then one "<x> <y> <z>" line per rule,
//   then the special-token line produced by SpecialTokens::dump.
//
// If the file cannot be opened, a message is printed to stderr and the
// assert fires. In builds where asserts are compiled out the function now
// returns immediately instead of running the write loops against a failed
// stream.
void BPEState::dump(const string &file_name) {
  std::ofstream fout(file_name, std::ios::out);
  if (fout.fail()) {
    std::cerr << "Can't open file: " << file_name << std::endl;
    assert(false);
    return;
  }

  // '\n' instead of std::endl: same bytes in the file, but no flush per
  // line. SpecialTokens::dump flushes once at the end.
  fout << char2id.size() << " " << rules.size() << '\n';
  for (const auto &entry : char2id) {
    fout << entry.first << " " << entry.second << '\n';
  }

  for (const auto &rule : rules) {
    fout << rule.x << " " << rule.y << " " << rule.z << '\n';
  }
  special_tokens.dump(fout);
  fout.close();
}
|
69
|
+
|
70
|
+
// Replaces the current state with the contents of `file_name`, which must be
// in the format written by BPEState::dump.
//
// Returns a default (OK) Status on success. Returns Status with code 1 when
// the file cannot be opened or when any expected number is missing or
// unparsable; in the latter case the state is left empty rather than
// partially filled.
Status BPEState::load(const string &file_name) {
  char2id.clear();
  rules.clear();
  std::ifstream fin(file_name, std::ios::in);
  if (fin.fail()) {
    return Status(1, "Can not open file with model: " + file_name);
  }

  // Shared failure path: drop whatever was read so far and report the error.
  const auto corrupted = [&]() {
    char2id.clear();
    rules.clear();
    special_tokens = SpecialTokens();
    return Status(1, "Model file is corrupted or truncated: " + file_name);
  };

  // Initialised so a failed extraction can never leave an indeterminate loop
  // bound behind.
  int n = 0;
  int m = 0;
  fin >> n >> m;
  if (fin.fail()) {
    return corrupted();
  }

  for (int i = 0; i < n; i++) {
    // First number on the line is the map key, second is the mapped value
    // (same order as dump writes them).
    uint32_t key = 0;
    uint32_t value = 0;
    fin >> key >> value;
    if (fin.fail()) {
      return corrupted();
    }
    char2id[key] = value;
  }

  for (int i = 0; i < m; i++) {
    uint32_t x = 0;
    uint32_t y = 0;
    uint32_t z = 0;
    fin >> x >> y >> z;
    if (fin.fail()) {
      return corrupted();
    }
    rules.emplace_back(x, y, z);
  }

  special_tokens.load(fin);
  if (fin.fail()) {
    return corrupted();
  }
  // fin is closed by its destructor.
  return Status();
}
|
94
|
+
|
95
|
+
BpeConfig::BpeConfig(double _character_coverage, int _n_threads,
|
96
|
+
const SpecialTokens &_special_tokens)
|
97
|
+
: character_coverage(_character_coverage),
|
98
|
+
n_threads(_n_threads),
|
99
|
+
special_tokens(_special_tokens) {}
|
100
|
+
|
101
|
+
// Reads lines from std::cin until `*processed` reaches `batch_limit` or
// input ends.
//
// batch_limit - stop once the running total in `*processed` is at or above
//               this value. The check happens before each read, so the final
//               line may push the total past the limit.
// processed   - in/out running total; increased by the length of every line
//               read (the line terminator is not counted).
//
// Returns the lines read during this call, without terminators.
vector<string> read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) {
  vector<string> lines;
  string line;
  for (;;) {
    if (*processed >= batch_limit) {
      break;
    }
    if (!getline(std::cin, line)) {
      break;
    }
    *processed += line.size();
    lines.push_back(std::move(line));
  }
  return lines;
}
|
110
|
+
|
111
|
+
Status::Status(int code, std::string message) : code(code), message(std::move(message)) {}
|
112
|
+
|
113
|
+
const std::string &Status::error_message() const {
|
114
|
+
return message;
|
115
|
+
}
|
116
|
+
bool Status::ok() const {
|
117
|
+
return code == 0;
|
118
|
+
}
|
119
|
+
} // namespace vkcom
|
@@ -0,0 +1,105 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include <iostream>
|
4
|
+
#include <string>
|
5
|
+
#include <vector>
|
6
|
+
#include "third_party/flat_hash_map.h"
|
7
|
+
|
8
|
+
namespace vkcom {
|
9
|
+
// Code point U+2581 (LOWER ONE EIGHTH BLOCK), decimal 9601.
// NOTE(review): presumably the marker substituted for whitespace in the
// tokenizer; confirm against bpe.cpp.
const uint32_t SPACE_TOKEN = 0x2581;
|
10
|
+
|
11
|
+
// One merge rule of the form "x + y -> z". All three fields default to 0.
struct BPE_Rule {
  uint32_t x = 0;  // left operand
  uint32_t y = 0;  // right operand
  uint32_t z = 0;  // result of the merge

  BPE_Rule() = default;

  BPE_Rule(uint32_t x, uint32_t y, uint32_t z);

  // Component-wise equality.
  bool operator==(const BPE_Rule &other) const;
};
|
23
|
+
|
24
|
+
// Ids of the four special tokens. An id of -1 marks a token as not enabled
// (n_special_tokens counts only ids different from -1).
struct SpecialTokens {
  int pad_id{-1};
  int unk_id{-1};
  int bos_id{-1};
  int eos_id{-1};

  SpecialTokens() = default;

  // Note the argument order: pad, unk, bos, eos.
  SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id);

  // Writes the four ids to `fout`.
  void dump(std::ofstream &fout);

  // Reads the four ids from `fin`.
  void load(std::ifstream &fin);

  // Largest of the four ids (never less than 0).
  uint32_t max_id() const;

  // True when `id` equals one of the four ids.
  bool taken_id(int id) const;

  // Number of ids that are not -1.
  uint64_t n_special_tokens() const;
};
|
44
|
+
|
45
|
+
struct BpeConfig {
|
46
|
+
double character_coverage = 1;
|
47
|
+
int n_threads = 0;
|
48
|
+
SpecialTokens special_tokens;
|
49
|
+
|
50
|
+
BpeConfig() = default;
|
51
|
+
|
52
|
+
BpeConfig(double character_coverage, int n_threads,
|
53
|
+
const SpecialTokens &special_tokens);
|
54
|
+
};
|
55
|
+
|
56
|
+
// Result of an operation: a numeric code plus a human-readable message.
// A default-constructed Status has code 0 and an empty message, which ok()
// reports as success.
struct Status {
  int code = 0;
  std::string message;

  Status() = default;
  Status(int code, std::string message);

  // Returns the stored message.
  const std::string &error_message() const;
  // True when code is 0.
  bool ok() const;
};
|
65
|
+
|
66
|
+
struct BPEState {
|
67
|
+
flat_hash_map<uint32_t, uint32_t> char2id;
|
68
|
+
std::vector<BPE_Rule> rules;
|
69
|
+
SpecialTokens special_tokens;
|
70
|
+
|
71
|
+
void dump(const std::string &file_name);
|
72
|
+
|
73
|
+
Status load(const std::string &file_name);
|
74
|
+
};
|
75
|
+
|
76
|
+
// Pair of parallel outputs: integer ids and string pieces.
// NOTE(review): how the two vectors correspond is not shown in this file;
// confirm against the users in bpe.cpp.
struct DecodeResult {
  std::vector<int> ids;
  std::vector<std::string> pieces;
};
|
80
|
+
|
81
|
+
// Options passed to the encoder. The members have no default initializers,
// so a default-initialized instance holds indeterminate values; always set
// all four fields.
struct EncodingConfig {
  bool bos;             // NOTE(review): presumably "add a BOS token" - confirm
  bool eos;             // NOTE(review): presumably "add an EOS token" - confirm
  bool reverse;         // NOTE(review): presumably "reverse the output" - confirm
  double dropout_prob;  // the Python wrapper restricts this to [0, 1]
};
|
87
|
+
|
88
|
+
// Declared here; the definition is not in utils.cpp.
// NOTE(review): presumably tests whether code point `ch` is whitespace -
// confirm against the definition.
bool is_space(uint32_t ch);

// Reads lines from stdin until `*processed` reaches `batch_limit` or input
// ends; `*processed` grows by the length of each line read.
std::vector<std::string> read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed);
|
91
|
+
|
92
|
+
// Prints each inner vector on its own line of std::cout. Every token is
// followed by a single space (including the last one) and every line ends
// with "\n". When `flush` is true the stream is flushed once at the end.
template<typename T>
void write_to_stdout(const std::vector<std::vector<T>> &sentences, bool flush) {
  for (auto row = sentences.begin(); row != sentences.end(); ++row) {
    for (auto token = row->begin(); token != row->end(); ++token) {
      std::cout << *token << " ";
    }
    std::cout << "\n";
  }
  if (flush) {
    std::cout.flush();
  }
}
|
104
|
+
|
105
|
+
} // namespace vkcom
|
@@ -0,0 +1,182 @@
|
|
1
|
+
from libcpp.vector cimport vector
|
2
|
+
from libcpp.unordered_set cimport unordered_set
|
3
|
+
from libcpp.string cimport string
|
4
|
+
from libcpp cimport bool
|
5
|
+
import os
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Collection
|
8
|
+
|
9
|
+
|
10
|
+
# C++ declarations imported from bpe.h (namespace vkcom).
# NOTE(review): none of these are declared with `except +`, so a C++
# exception thrown by them is not translated into a Python exception.
cdef extern from "bpe.h" namespace "vkcom":

    # Ids of the four special tokens.
    cdef cppclass SpecialTokens:
        int pad_id
        int unk_id
        int bos_id
        int eos_id

    # Training configuration.
    cdef cppclass BpeConfig:
        double character_coverage
        int n_threads
        SpecialTokens special_tokens

    # Result of a C++ call; `code` is compared against 0 by the wrapper.
    cdef cppclass Status:
        int code
        string message

    Status train_bpe(const string &source_path, const string& model_path, int vocab_size, const BpeConfig& bpe_config)

    cdef cppclass BaseEncoder:
        BaseEncoder(const string& model_path, int n_threads, Status* status)

        Status encode_as_ids(const vector[string] &sentences, vector[vector[int]]* ids, bool bos, bool eos, bool reverse, double dropout_prob) const
        Status encode_as_subwords(const vector[string]& sentences, vector[vector[string]]* subwords, bool bos, bool eos, bool reverse, double dropout_prob) const

        Status encode_cli(string output_type, bool stream, bool bos, bool eos, bool reverse, double dropout_prob) const

        Status decode_cli(const unordered_set[int]* ignore_ids) const

        void vocab_cli(bool verbose) const

        Status id_to_subword(int id, string* subword) const

        int subword_to_id(const string &subword) const
        Status decode(const vector[vector[int]]& ids, vector[string]* output, const unordered_set[int]* ignore_ids) const
        int vocab_size() const
        vector[string] vocabulary() const
|
50
|
+
|
51
|
+
|
52
|
+
cdef class BPE:
    """Python wrapper around the C++ ``vkcom::BaseEncoder``.

    The instance owns one heap-allocated encoder, created in ``__init__`` and
    released in ``__dealloc__``. Every C++ call that returns a ``Status``
    with a non-zero code is surfaced as ``ValueError`` carrying the status
    message.
    """
    cdef BaseEncoder* encoder

    def __dealloc__(self):
        # `encoder` is NULL if __init__ never ran; deleting NULL is a no-op.
        del self.encoder

    def __init__(self, model_path, n_threads=-1):
        """Load the model stored at ``model_path``.

        Raises ValueError when the encoder reports a non-zero status.
        """
        cdef Status status
        # __init__ can be invoked more than once on the same object; release
        # the previous encoder so it is not leaked.
        if self.encoder != NULL:
            del self.encoder
            self.encoder = NULL
        self.encoder = new BaseEncoder(model_path.encode(), n_threads, &status)
        if status.code != 0:
            raise ValueError(status.message.decode())

    @staticmethod
    def train(data,
              model,
              vocab_size,
              coverage=1.0,
              n_threads=-1,
              pad_id=0,
              unk_id=1,
              bos_id=2,
              eos_id=3):
        """Train a model on the file ``data`` and write it to ``model``.

        Raises ValueError when training reports a non-zero status.
        """
        cdef BpeConfig bpe_config
        bpe_config.character_coverage = coverage
        bpe_config.n_threads = n_threads
        bpe_config.special_tokens.pad_id = pad_id
        bpe_config.special_tokens.unk_id = unk_id
        bpe_config.special_tokens.bos_id = bos_id
        bpe_config.special_tokens.eos_id = eos_id

        cdef Status status = train_bpe(data.encode(), model.encode(), vocab_size, bpe_config)
        if status.code != 0:
            raise ValueError(status.message.decode())

    def encode(self, sentences, output_type, bos, eos, reverse, dropout_prob):
        """Encode one sentence or a list/tuple of sentences.

        ``output_type`` selects the result: ``'id'`` gives integer ids,
        ``'subword'`` gives decoded string pieces. A single ``str`` input
        yields one flat list; a list/tuple input yields a list of lists.

        Raises ValueError for a ``dropout_prob`` outside [0, 1], an unknown
        ``output_type``, or a non-zero status from the encoder.
        """
        cdef vector[string] s
        cdef vector[vector[string]] ret_subwords
        cdef vector[vector[int]] ret_ids
        cdef Status status
        if dropout_prob < 0 or dropout_prob > 1:
            raise ValueError("dropout_prob value must be in the range [0, 1]. Current value of dropout_prob = " + str(dropout_prob))
        if output_type not in ('id', 'subword'):
            raise ValueError('output_type must be equal to "id" or "subword"')

        # Normalise the input once instead of repeating it in every branch.
        single = isinstance(sentences, str)
        if single:
            s = [sentences.encode()]
        else:
            assert isinstance(sentences, list) or isinstance(sentences, tuple)
            s = [x.encode() for x in sentences]

        if output_type == 'id':
            status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse, dropout_prob)
            if status.code != 0:
                raise ValueError(status.message.decode())
            if single:
                return ret_ids[0]
            return ret_ids

        status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse, dropout_prob)
        if status.code != 0:
            raise ValueError(status.message.decode())
        if single:
            assert len(ret_subwords) == 1
            return [piece.decode() for piece in ret_subwords[0]]
        return [[piece.decode() for piece in sentence] for sentence in ret_subwords]

    def subword_to_id(self, subword):
        """Return the id the encoder reports for ``subword``."""
        return self.encoder.subword_to_id(subword.encode())

    def id_to_subword(self, id):
        """Return the subword for ``id``; ValueError on a non-zero status."""
        cdef string subword
        cdef Status status = self.encoder.id_to_subword(id, &subword)
        if status.code != 0:
            raise ValueError(status.message.decode())
        return subword.decode()

    def decode(self, ids, ignore_ids):
        """Decode ids back to text.

        ``ids`` must be a list: either a flat list of ints (treated as one
        sentence) or a list of such lists. ``ignore_ids`` is a Collection of
        ids handed to the encoder, or None for none. Always returns a list of
        strings.

        Raises TypeError for arguments of the wrong type and ValueError for a
        non-zero status from the encoder.
        """

        if not isinstance(ids, list):
            raise TypeError(
                "{} is not a list instance".format(type(ids))
            )

        if not isinstance(ignore_ids, Collection) and ignore_ids is not None:
            raise TypeError(
                "{} is not a Collection instance".format(type(ignore_ids))
            )

        # A flat list of ints is a single sentence; wrap it.
        if len(ids) > 0 and isinstance(ids[0], int):
            ids = [ids]
        if ignore_ids is None:
            ignore_ids = set()

        cdef vector[string] sentences
        cdef unordered_set[int] c_ignore_ids = unordered_set[int](ignore_ids)
        cdef Status status = self.encoder.decode(ids, &sentences, &c_ignore_ids)
        if status.code != 0:
            raise ValueError(status.message.decode())
        return [sentence.decode() for sentence in sentences]

    def vocab_size(self):
        """Return the vocabulary size reported by the encoder."""
        return self.encoder.vocab_size()

    def vocab(self):
        """Return the encoder's vocabulary as a list of str."""
        cdef vector[string] vocab = self.encoder.vocabulary()
        return [token.decode() for token in vocab]

    def encode_cli(self, output_type, stream, bos, eos, reverse, dropout_prob):
        """Run the encoder's CLI encode routine; ValueError on failure."""
        cdef Status status = self.encoder.encode_cli(output_type.encode(), stream, bos, eos, reverse, dropout_prob)
        if status.code != 0:
            raise ValueError(status.message.decode())

    def decode_cli(self, ignore_ids):
        """Run the encoder's CLI decode routine; ValueError on failure."""
        if ignore_ids is None:
            ignore_ids = set()
        cdef unordered_set[int] c_ignore_ids = unordered_set[int](ignore_ids)
        cdef Status status = self.encoder.decode_cli(&c_ignore_ids)
        if status.code != 0:
            raise ValueError(status.message.decode())

    def vocab_cli(self, verbose):
        """Run the encoder's CLI vocabulary routine."""
        self.encoder.vocab_cli(verbose)
|
182
|
+
|