youtokentome 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,134 @@
1
+ #include "utf8.h"
2
+ #include <cassert>
3
+ #include <iostream>
4
+ #include <string>
5
+ #include <vector>
6
+ #include "utils.h"
7
+
8
+ namespace vkcom {
9
+
10
+ using std::string;
11
+ using std::vector;
12
+
13
+
14
// Returns true when `x` is a UTF-8 continuation byte (bit pattern 10xxxxxx).
bool check_byte(char x) {
  const auto byte = static_cast<uint8_t>(x);
  return (byte >> 6u) == 0x2u;
}
15
+
16
// Returns true when `x` is below 0x110000 and lies outside the range
// [0xD800, 0xDFFF] (the UTF-16 surrogate block).
bool check_codepoint(uint32_t x) {
  const bool below_limit = x < 0x110000;
  const bool in_surrogate_block = 0xd800 <= x && x <= 0xdfff;
  return below_limit && !in_surrogate_block;
}
19
+
20
// Returns the byte length (1-4) of the UTF-8 sequence announced by lead byte
// `ch`, judged by its high bits only. Returns 0 when the high bits match no
// lead pattern (10xxxxxx continuation bytes and 0xF8..0xFF).
uint64_t utf_length(char ch) {
  const auto lead = static_cast<uint8_t>(ch);
  if (lead < 0x80u) return 1;            // 0xxxxxxx
  if ((lead >> 5u) == 0x06u) return 2;   // 110xxxxx
  if ((lead >> 4u) == 0x0eu) return 3;   // 1110xxxx
  if ((lead >> 3u) == 0x1eu) return 4;   // 11110xxx
  return 0;                              // Invalid utf-8
}
36
+
37
+ uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) {
38
+ uint64_t length = utf_length(begin[0]);
39
+ if (length == 1) {
40
+ *utf8_len = 1;
41
+ return static_cast<uint8_t>(begin[0]);
42
+ }
43
+ uint32_t code_point = 0;
44
+ if (size >= 2 && length == 2 && check_byte(begin[1])) {
45
+ code_point += (static_cast<uint8_t>(begin[0]) & 0x1fu) << 6u;
46
+ code_point += (static_cast<uint8_t>(begin[1]) & 0x3fu);
47
+ if (code_point >= 0x0080 && check_codepoint(code_point)) {
48
+ *utf8_len = 2;
49
+ return code_point;
50
+ }
51
+ } else if (size >= 3 && length == 3 && check_byte(begin[1]) &&
52
+ check_byte(begin[2])) {
53
+ code_point += (static_cast<uint8_t>(begin[0]) & 0x0fu) << 12u;
54
+ code_point += (static_cast<uint8_t>(begin[1]) & 0x3fu) << 6u;
55
+ code_point += (static_cast<uint8_t>(begin[2]) & 0x3fu);
56
+ if (code_point >= 0x0800 && check_codepoint(code_point)) {
57
+ *utf8_len = 3;
58
+ return code_point;
59
+ }
60
+ } else if (size >= 4 && length == 4 && check_byte(begin[1]) &&
61
+ check_byte(begin[2]) && check_byte(begin[3])) {
62
+ code_point += (static_cast<uint8_t>(begin[0]) & 0x07u) << 18u;
63
+ code_point += (static_cast<uint8_t>(begin[1]) & 0x3fu) << 12u;
64
+ code_point += (static_cast<uint8_t>(begin[2]) & 0x3fu) << 6u;
65
+ code_point += (static_cast<uint8_t>(begin[3]) & 0x3fu);
66
+ if (code_point >= 0x10000 && check_codepoint(code_point)) {
67
+ *utf8_len = 4;
68
+ return code_point;
69
+ }
70
+ }
71
+ // Invalid utf-8
72
+ *utf8_len = 1;
73
+ return INVALID_UNICODE;
74
+ }
75
+
76
+ void utf8_to_chars(uint32_t x, std::back_insert_iterator<string> it) {
77
+ assert(check_codepoint(x));
78
+
79
+ if (x <= 0x7f) {
80
+ *(it++) = x;
81
+ return;
82
+ }
83
+
84
+ if (x <= 0x7ff) {
85
+ *(it++) = 0xc0u | (x >> 6u);
86
+ *(it++) = 0x80u | (x & 0x3fu);
87
+ return;
88
+ }
89
+
90
+ if (x <= 0xffff) {
91
+ *(it++) = 0xe0u | (x >> 12u);
92
+ *(it++) = 0x80u | ((x >> 6u) & 0x3fu);
93
+ *(it++) = 0x80u | (x & 0x3fu);
94
+ return;
95
+ }
96
+
97
+ *(it++) = 0xf0u | (x >> 18u);
98
+ *(it++) = 0x80u | ((x >> 12u) & 0x3fu);
99
+ *(it++) = 0x80u | ((x >> 6u) & 0x3fu);
100
+ *(it++) = 0x80u | (x & 0x3fu);
101
+ }
102
+
103
+ string encode_utf8(const vector<uint32_t>& text) {
104
+ string utf8_text;
105
+ for (const uint32_t c : text) {
106
+ utf8_to_chars(c, std::back_inserter(utf8_text));
107
+ }
108
+ return utf8_text;
109
+ }
110
+
111
+ vector<uint32_t> decode_utf8(const char* begin, const char* end) {
112
+ vector<uint32_t> decoded_text;
113
+ uint64_t utf8_len = 0;
114
+ bool invalid_input = false;
115
+ for (; begin < end; begin += utf8_len) {
116
+ uint32_t code_point = chars_to_utf8(begin, end - begin, &utf8_len);
117
+ if (code_point != INVALID_UNICODE) {
118
+ decoded_text.push_back(code_point);
119
+ } else {
120
+ invalid_input = true;
121
+ }
122
+ }
123
+ if (invalid_input) {
124
+ std::cerr << "WARNING Input contains invalid unicode characters."
125
+ << std::endl;
126
+ }
127
+ return decoded_text;
128
+ }
129
+
130
+ vector<uint32_t> decode_utf8(const string& utf8_text) {
131
+ return decode_utf8(utf8_text.data(), utf8_text.data() + utf8_text.size());
132
+ }
133
+
134
+ } // namespace vkcom
@@ -0,0 +1,23 @@
1
+ #pragma once
2
+
3
+ #include "utils.h"
4
+
5
+ namespace vkcom {
6
+
7
// Sentinel returned by chars_to_utf8 when the input is not a valid sequence.
static constexpr uint32_t INVALID_UNICODE = 0x0fffffff;

// Decodes one UTF-8 sequence starting at `begin` (at most `size` bytes
// available). Stores the number of bytes to advance in *utf8_len and returns
// the code point, or INVALID_UNICODE on malformed input.
uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len);

// Encodes a sequence of code points as a UTF-8 string.
std::string encode_utf8(const std::vector<uint32_t>& utext);

// Decodes the byte range [begin, end) into code points.
std::vector<uint32_t> decode_utf8(const char* begin, const char* end);

// Decodes the whole string into code points.
std::vector<uint32_t> decode_utf8(const std::string& utf8_text);
16
+
17
+
18
+
19
+
20
+ } // namespace vkcom
21
+
22
+
23
+
@@ -0,0 +1,119 @@
1
+ #include "utils.h"
2
+ #include <cassert>
3
+ #include <fstream>
4
+ #include <iostream>
5
+ #include <string>
6
+ #include <vector>
7
+
8
+ namespace vkcom {
9
+ using std::string;
10
+ using std::vector;
11
+
12
+ void SpecialTokens::dump(std::ofstream &fout) {
13
+ fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
14
+ << std::endl;
15
+ }
16
+
17
+ void SpecialTokens::load(std::ifstream &fin) {
18
+ fin >> unk_id >> pad_id >> bos_id >> eos_id;
19
+ }
20
+
21
+ uint32_t SpecialTokens::max_id() const {
22
+ int ret = 0;
23
+ ret = std::max(ret, unk_id);
24
+ ret = std::max(ret, pad_id);
25
+ ret = std::max(ret, bos_id);
26
+ ret = std::max(ret, eos_id);
27
+ return ret;
28
+ }
29
+
30
+ bool SpecialTokens::taken_id(int id) const {
31
+ return id == unk_id || id == pad_id || id == bos_id || id == eos_id;
32
+ }
33
+
34
+ uint64_t SpecialTokens::n_special_tokens() const {
35
+ uint64_t cnt = 0;
36
+ cnt += (unk_id != -1);
37
+ cnt += (pad_id != -1);
38
+ cnt += (bos_id != -1);
39
+ cnt += (eos_id != -1);
40
+ return cnt;
41
+ }
42
+
43
+ SpecialTokens::SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id)
44
+ : pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {}
45
+
46
+ bool BPE_Rule::operator==(const BPE_Rule &other) const {
47
+ return x == other.x && y == other.y && z == other.z;
48
+ }
49
+
50
+ BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}
51
+
52
+ void BPEState::dump(const string &file_name) {
53
+ std::ofstream fout(file_name, std::ios::out);
54
+ if (fout.fail()) {
55
+ std::cerr << "Can't open file: " << file_name << std::endl;
56
+ assert(false);
57
+ }
58
+ fout << char2id.size() << " " << rules.size() << std::endl;
59
+ for (auto s : char2id) {
60
+ fout << s.first << " " << s.second << std::endl;
61
+ }
62
+
63
+ for (auto rule : rules) {
64
+ fout << rule.x << " " << rule.y << " " << rule.z << std::endl;
65
+ }
66
+ special_tokens.dump(fout);
67
+ fout.close();
68
+ }
69
+
70
+ Status BPEState::load(const string &file_name) {
71
+ char2id.clear();
72
+ rules.clear();
73
+ std::ifstream fin(file_name, std::ios::in);
74
+ if (fin.fail()) {
75
+ return Status(1, "Can not open file with model: " + file_name);
76
+ }
77
+ int n, m;
78
+ fin >> n >> m;
79
+ for (int i = 0; i < n; i++) {
80
+ uint32_t inner_id;
81
+ uint32_t utf32_id;
82
+ fin >> inner_id >> utf32_id;
83
+ char2id[inner_id] = utf32_id;
84
+ }
85
+ for (int i = 0; i < m; i++) {
86
+ uint32_t x, y, z;
87
+ fin >> x >> y >> z;
88
+ rules.emplace_back(x, y, z);
89
+ }
90
+ special_tokens.load(fin);
91
+ fin.close();
92
+ return Status();
93
+ }
94
+
95
+ BpeConfig::BpeConfig(double _character_coverage, int _n_threads,
96
+ const SpecialTokens &_special_tokens)
97
+ : character_coverage(_character_coverage),
98
+ n_threads(_n_threads),
99
+ special_tokens(_special_tokens) {}
100
+
101
// Reads lines from stdin until *processed reaches `batch_limit` or stdin is
// exhausted. *processed is increased by the length of every line read, so the
// counter carries over between calls. Returns the lines read by this call.
std::vector<std::string> read_lines_from_stdin(uint64_t batch_limit,
                                               uint64_t *processed) {
  std::vector<std::string> batch;
  for (std::string line;
       *processed < batch_limit && std::getline(std::cin, line);) {
    *processed += line.size();
    batch.push_back(std::move(line));
  }
  return batch;
}
110
+
111
+ Status::Status(int code, std::string message) : code(code), message(std::move(message)) {}
112
+
113
+ const std::string &Status::error_message() const {
114
+ return message;
115
+ }
116
+ bool Status::ok() const {
117
+ return code == 0;
118
+ }
119
+ } // namespace vkcom
@@ -0,0 +1,105 @@
1
+ #pragma once
2
+
3
+ #include <iostream>
4
+ #include <string>
5
+ #include <vector>
6
+ #include "third_party/flat_hash_map.h"
7
+
8
+ namespace vkcom {
9
// 9601 == U+2581 (LOWER ONE EIGHTH BLOCK).
// NOTE(review): presumably the code point that stands in for a space in
// tokens — confirm against the encoder.
const uint32_t SPACE_TOKEN = 0x2581;
10
+
11
// One merge rule: x + y -> z. All fields default to 0.
struct BPE_Rule {
  uint32_t x = 0;
  uint32_t y = 0;
  uint32_t z = 0;

  BPE_Rule() = default;

  BPE_Rule(uint32_t x, uint32_t y, uint32_t z);

  // Field-wise equality.
  bool operator==(const BPE_Rule &other) const;
};
23
+
24
// Ids of the four special tokens. An id of -1 means "not set".
struct SpecialTokens {
  int pad_id{-1};
  int unk_id{-1};
  int bos_id{-1};
  int eos_id{-1};

  SpecialTokens() = default;

  SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id);

  // Writes the ids to `fout` as one text line.
  void dump(std::ofstream &fout);

  // Reads the ids back from `fin`.
  void load(std::ifstream &fin);

  // Largest id among the four, or 0 when none is positive.
  uint32_t max_id() const;

  // True when `id` equals one of the four ids.
  bool taken_id(int id) const;

  // Number of ids that differ from -1.
  uint64_t n_special_tokens() const;
};
44
+
45
+ struct BpeConfig {
46
+ double character_coverage = 1;
47
+ int n_threads = 0;
48
+ SpecialTokens special_tokens;
49
+
50
+ BpeConfig() = default;
51
+
52
+ BpeConfig(double character_coverage, int n_threads,
53
+ const SpecialTokens &special_tokens);
54
+ };
55
+
56
// Result of an operation: code 0 means success; `message` carries the error
// text otherwise.
struct Status {
  int code = 0;
  std::string message;

  Status() = default;
  Status(int code, std::string message);

  // Returns the stored message.
  const std::string &error_message() const;
  // True when code == 0.
  bool ok() const;
};
65
+
66
+ struct BPEState {
67
+ flat_hash_map<uint32_t, uint32_t> char2id;
68
+ std::vector<BPE_Rule> rules;
69
+ SpecialTokens special_tokens;
70
+
71
+ void dump(const std::string &file_name);
72
+
73
+ Status load(const std::string &file_name);
74
+ };
75
+
76
// Pair of parallel-looking outputs: integer ids and string pieces.
// NOTE(review): presumably ids[i] corresponds to pieces[i] — confirm at the
// producer.
struct DecodeResult {
  std::vector<int> ids;
  std::vector<std::string> pieces;
};
80
+
81
// Per-call encoding options. Plain aggregate: the members have no default
// initializers, so every field must be set by the creator.
struct EncodingConfig {
  bool bos;             // NOTE(review): presumably "add begin-of-sentence id"
  bool eos;             // NOTE(review): presumably "add end-of-sentence id"
  bool reverse;         // NOTE(review): presumably "reverse the output order"
  double dropout_prob;  // the Python wrapper requires this to be in [0, 1]
};
87
+
88
// NOTE(review): defined elsewhere; presumably reports whether code point `ch`
// is whitespace — confirm at the definition.
bool is_space(uint32_t ch);

// Reads lines from stdin until *processed reaches `batch_limit` or input
// ends; adds the length of each line read to *processed.
std::vector<std::string> read_lines_from_stdin(uint64_t batch_limit,
                                               uint64_t *processed);
91
+
92
// Prints each inner vector on its own line of stdout, every token followed by
// a single space (so lines carry a trailing space). Flushes stdout at the end
// only when `flush` is true.
template <typename T>
void write_to_stdout(const std::vector<std::vector<T>> &sentences, bool flush) {
  for (auto row = sentences.begin(); row != sentences.end(); ++row) {
    for (auto token = row->begin(); token != row->end(); ++token) {
      std::cout << *token << " ";
    }
    std::cout << "\n";
  }
  if (flush) {
    std::cout.flush();
  }
}
104
+
105
+ } // namespace vkcom
@@ -0,0 +1,182 @@
1
+ from libcpp.vector cimport vector
2
+ from libcpp.unordered_set cimport unordered_set
3
+ from libcpp.string cimport string
4
+ from libcpp cimport bool
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Collection
8
+
9
+
10
# C++ declarations from bpe.h (namespace vkcom) that the wrapper below uses.
cdef extern from "bpe.h" namespace "vkcom":

    # Ids of the four special tokens.
    cdef cppclass SpecialTokens:
        int pad_id
        int unk_id
        int bos_id
        int eos_id

    # Training options.
    cdef cppclass BpeConfig:
        double character_coverage
        int n_threads
        SpecialTokens special_tokens

    # Operation result; code == 0 means success.
    cdef cppclass Status:
        int code
        string message

    Status train_bpe(const string &source_path, const string& model_path, int vocab_size, const BpeConfig& bpe_config)

    cdef cppclass BaseEncoder:
        BaseEncoder(const string& model_path, int n_threads, Status* status)

        Status encode_as_ids(const vector[string] &sentences, vector[vector[int]]* ids, bool bos, bool eos, bool reverse, double dropout_prob) const
        Status encode_as_subwords(const vector[string]& sentences, vector[vector[string]]* subwords, bool bos, bool eos, bool reverse, double dropout_prob) const

        Status encode_cli(string output_type, bool stream, bool bos, bool eos, bool reverse, double dropout_prob) const

        Status decode_cli(const unordered_set[int]* ignore_ids) const

        void vocab_cli(bool verbose) const

        Status id_to_subword(int id, string* subword) const

        int subword_to_id(const string &subword) const
        Status decode(const vector[vector[int]]& ids, vector[string]* output, const unordered_set[int]* ignore_ids) const
        int vocab_size() const
        vector[string] vocabulary() const
50
+
51
+
52
cdef class BPE:
    """Python wrapper that owns one C++ ``BaseEncoder`` loaded from a model file.

    Every method that receives a non-zero ``Status`` from the C++ side raises
    ``ValueError`` carrying the status message.
    """
    cdef BaseEncoder* encoder

    def __dealloc__(self):
        del self.encoder

    def __init__(self, model_path, n_threads=-1):
        """Load the model stored at ``model_path``.

        Raises ValueError when the C++ constructor reports a non-zero status.
        """
        cdef Status status
        # __init__ can be invoked again on a live object; release the previous
        # encoder first so it is not leaked.
        if self.encoder != NULL:
            del self.encoder
            self.encoder = NULL
        self.encoder = new BaseEncoder(model_path.encode(), n_threads, &status)
        if status.code != 0:
            raise ValueError(status.message.decode())

    @staticmethod
    def train(data,
              model,
              vocab_size,
              coverage=1.0,
              n_threads=-1,
              pad_id=0,
              unk_id=1,
              bos_id=2,
              eos_id=3):
        """Train a model on the file ``data`` and write it to the path ``model``.

        Raises ValueError when training reports a non-zero status.
        """
        cdef BpeConfig bpe_config
        bpe_config.character_coverage = coverage
        bpe_config.n_threads = n_threads
        bpe_config.special_tokens.pad_id = pad_id
        bpe_config.special_tokens.unk_id = unk_id
        bpe_config.special_tokens.bos_id = bos_id
        bpe_config.special_tokens.eos_id = eos_id

        cdef Status status = train_bpe(data.encode(), model.encode(), vocab_size, bpe_config)
        if status.code != 0:
            raise ValueError(status.message.decode())

    def encode(self, sentences, output_type, bos, eos, reverse, dropout_prob):
        """Encode one sentence (``str``) or a list/tuple of sentences.

        ``output_type`` is ``'id'`` or ``'subword'``. A single ``str`` input
        yields one flat list; a list/tuple yields a list of lists.

        Raises ValueError for ``dropout_prob`` outside [0, 1], for an unknown
        ``output_type``, or when the encoder reports a non-zero status.
        """
        cdef vector[string] s
        cdef vector[vector[string]] ret_subwords
        cdef vector[vector[int]] ret_ids
        cdef Status status
        if dropout_prob < 0 or dropout_prob > 1:
            raise ValueError("dropout_prob value must be in the range [0, 1]. Current value of dropout_prob = " + str(dropout_prob))
        # Reject a bad output_type before looking at `sentences`, matching the
        # order in which the checks were originally performed.
        if output_type != 'id' and output_type != 'subword':
            raise ValueError('output_type must be equal to "id" or "subword"')

        single = isinstance(sentences, str)
        if single:
            s = [sentences.encode()]
        else:
            assert isinstance(sentences, (list, tuple))
            s = [x.encode() for x in sentences]

        if output_type == 'id':
            status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse, dropout_prob)
            if status.code != 0:
                raise ValueError(status.message.decode())
            if single:
                return ret_ids[0]
            return ret_ids

        status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse, dropout_prob)
        if status.code != 0:
            raise ValueError(status.message.decode())
        if single:
            assert len(ret_subwords) == 1
            return [piece.decode() for piece in ret_subwords[0]]
        return [[piece.decode() for piece in sentence] for sentence in ret_subwords]

    def subword_to_id(self, subword):
        """Return the id the encoder reports for ``subword``."""
        return self.encoder.subword_to_id(subword.encode())

    def id_to_subword(self, id):
        """Return the subword for ``id``; raises ValueError on a non-zero status."""
        cdef string subword
        cdef Status status = self.encoder.id_to_subword(id, &subword)
        if status.code != 0:
            raise ValueError(status.message.decode())
        return subword.decode()

    def decode(self, ids, ignore_ids):
        """Decode a list of ids, or a list of such lists, into strings.

        ``ids`` must be a list; a flat list of ints is treated as one sentence.
        ``ignore_ids`` is a Collection of ids passed to the encoder, or None
        for an empty set. Always returns a list of strings.

        Raises TypeError for wrong argument types and ValueError when the
        encoder reports a non-zero status.
        """

        if not isinstance(ids, list):
            raise TypeError(
                "{} is not a list instance".format(type(ids))
            )

        if not isinstance(ignore_ids, Collection) and ignore_ids is not None:
            raise TypeError(
                "{} is not a Collection instance".format(type(ignore_ids))
            )

        if len(ids) > 0 and isinstance(ids[0], int):
            ids = [ids]
        if ignore_ids is None:
            ignore_ids = set()

        cdef vector[string] sentences
        cdef unordered_set[int] c_ignore_ids = unordered_set[int](ignore_ids)
        cdef Status status = self.encoder.decode(ids, &sentences, &c_ignore_ids)
        if status.code != 0:
            raise ValueError(status.message.decode())
        return [sentence.decode() for sentence in sentences]

    def vocab_size(self):
        """Return the vocabulary size reported by the encoder."""
        return self.encoder.vocab_size()

    def vocab(self):
        """Return the vocabulary as a list of ``str``."""
        cdef vector[string] vocab = self.encoder.vocabulary()
        return [token.decode() for token in vocab]

    def encode_cli(self, output_type, stream, bos, eos, reverse, dropout_prob):
        """Run the encoder's CLI encode routine; raises ValueError on failure."""
        cdef Status status = self.encoder.encode_cli(output_type.encode(), stream, bos, eos, reverse, dropout_prob)
        if status.code != 0:
            raise ValueError(status.message.decode())

    def decode_cli(self, ignore_ids):
        """Run the encoder's CLI decode routine; ``None`` means no ignored ids."""
        if ignore_ids is None:
            ignore_ids = set()
        cdef unordered_set[int] c_ignore_ids = unordered_set[int](ignore_ids)
        cdef Status status = self.encoder.decode_cli(&c_ignore_ids)
        if status.code != 0:
            raise ValueError(status.message.decode())

    def vocab_cli(self, verbose):
        """Run the encoder's CLI vocabulary routine."""
        self.encoder.vocab_cli(verbose)
182
+