youtokentome 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/ext/youtokentome/ext.cpp +135 -0
- data/ext/youtokentome/extconf.rb +12 -0
- data/lib/youtokentome.rb +10 -0
- data/lib/youtokentome/bpe.rb +54 -0
- data/lib/youtokentome/ext.bundle +0 -0
- data/lib/youtokentome/version.rb +3 -0
- data/vendor/YouTokenToMe/LICENSE +19 -0
- data/vendor/YouTokenToMe/README.md +304 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/bpe.cpp +2185 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/bpe.h +86 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/third_party/LICENSE +23 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/third_party/flat_hash_map.h +1502 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/utf8.cpp +134 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/utf8.h +23 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/utils.cpp +119 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/utils.h +105 -0
- data/vendor/YouTokenToMe/youtokentome/cpp/yttm.pyx +182 -0
- metadata +133 -0
@@ -0,0 +1,2185 @@
|
|
1
|
+
#include "bpe.h"

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cctype>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_set>
#include <utility>
#include <vector>

#include "third_party/flat_hash_map.h"
#include "utf8.h"
#include "utils.h"
|
25
|
+
|
26
|
+
namespace vkcom {
|
27
|
+
using std::string;
|
28
|
+
using std::vector;
|
29
|
+
using std::unordered_set;
|
30
|
+
|
31
|
+
// Non-owning view over a byte range [begin, end) carrying a precomputed
// polynomial rolling hash; used as a hash-map key for words in the corpus.
struct VectorSegment {
  // Modulus and base of the polynomial rolling hash.
  constexpr static uint64_t MOD = 2032191299;
  constexpr static uint64_t P = 726328703;

  const char* begin;
  const char* end;
  uint64_t hash;

  // Hashes the byte range once at construction; the bytes are not copied,
  // so the underlying buffer must outlive this object.
  VectorSegment(const char* begin, const char* end): begin(begin), end(end) {
    hash = 0;
    for (const char* cur = begin; cur != end; ++cur) {
      hash = (hash * P + static_cast<unsigned char>(*cur)) % MOD;
    }
  }

  // Deep equality: cheap hash and length rejection first, then a byte-wise
  // comparison to rule out hash collisions.
  bool operator==(const VectorSegment &other) const {
    if (hash != other.hash) {
      return false;
    }
    if (end - begin != other.end - other.begin) {
      return false;
    }
    return std::equal(begin, end, other.begin);
  }
};
|
58
|
+
|
59
|
+
} // namespace vkcom
|
60
|
+
|
61
|
+
namespace std {
// Hash specialization for VectorSegment: the rolling hash is computed once
// at construction, so hashing a key is a constant-time field read.
template<>
struct hash<vkcom::VectorSegment> {
  uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash; }
};
} // namespace std
|
67
|
+
|
68
|
+
namespace vkcom {
|
69
|
+
|
70
|
+
// Reads the whole file `file_name` in binary mode into *file_content.
// Returns a non-zero Status if the file cannot be opened or an I/O error
// occurs while reading; on success *file_content holds exactly the file's
// bytes.
Status fast_read_file_utf8(const string &file_name, string *file_content) {
  static const int buf_size = 1000000;  // read granularity: ~1 MB per fread
  *file_content = "";
  auto fin = fopen(file_name.data(), "rb");
  if (fin == nullptr) {
    return Status(1, "Failed to open file: " + file_name);
  }
  while (true) {
    uint64_t cur_size = file_content->size();
    file_content->resize(cur_size + buf_size);
    // Read directly into the string's storage to avoid an extra copy.
    size_t buf_len = fread((void *) (file_content->data() + cur_size), 1, buf_size, fin);
    if (buf_len < static_cast<size_t>(buf_size)) {
      // Short read: trim the unused tail of the buffer.
      file_content->resize(file_content->size() - (buf_size - buf_len));
      // Bug fix: distinguish EOF from a read error. The original returned
      // success in both cases, silently handing back truncated content.
      if (ferror(fin)) {
        fclose(fin);
        return Status(1, "Failed to read file: " + file_name);
      }
      fclose(fin);
      return Status();
    }
  }
}
|
88
|
+
|
89
|
+
// Decodes a sequence of token ids back into a UTF-8 string using the
// id -> unicode-code-point table `id2char`. Every id must be present in
// the table (checked by assert in debug builds).
string token2word(const vector<uint32_t> &source,
                  const flat_hash_map<uint32_t, uint32_t> &id2char) {
  vector<uint32_t> res;
  res.reserve(source.size());
  // Bug fix: iterate as uint32_t — the original `for (int i : source)`
  // narrowed the token id, corrupting lookups for ids above INT_MAX.
  for (uint32_t token_id : source) {
    assert(id2char.count(token_id) == 1);
    res.push_back(id2char.at(token_id));
  }
  return encode_utf8(res);
}
|
98
|
+
|
99
|
+
bool is_space(uint32_t ch) {
|
100
|
+
return (ch < 256 && isspace(ch)) || (ch == SPACE_TOKEN);
|
101
|
+
}
|
102
|
+
|
103
|
+
// Packs two 32-bit token ids into a single 64-bit key: `a` in the high
// half, `b` in the low half.
uint64_t int2comb(uint32_t a, uint32_t b) {
  uint64_t high = static_cast<uint64_t>(a) << 32u;
  return high | b;  // low 32 bits of `high` are zero, so | equals +
}
|
106
|
+
|
107
|
+
// A candidate BPE merge: the pair (left_token, right_token) with its total
// occurrence count. Ordered so that the "best" candidate compares greatest:
// higher count wins; ties prefer smaller token ids (larger ids compare less).
struct MergeCandidate {
  uint64_t count{0};
  uint32_t left_token{0};
  uint32_t right_token{0};

  MergeCandidate() = default;

  MergeCandidate(uint64_t count, uint32_t left_token, uint32_t right_token)
      : count(count), left_token(left_token), right_token(right_token) {}

  bool operator<(const MergeCandidate &other) const {
    // Primary key: occurrence count.
    if (count != other.count) {
      return count < other.count;
    }
    // Tie-break on the unordered pair of ids: a larger max (then larger
    // min) id makes the candidate compare *smaller*, i.e. lower priority.
    const auto lhs = std::minmax(left_token, right_token);
    const auto rhs = std::minmax(other.left_token, other.right_token);
    if (lhs.second != rhs.second) {
      return lhs.second > rhs.second;
    }
    if (lhs.first != rhs.first) {
      return lhs.first > rhs.first;
    }
    // Final deterministic tie-break between (a,b) and (b,a).
    return left_token < other.left_token;
  }
};
|
135
|
+
|
136
|
+
struct UTF8Iterator {
|
137
|
+
UTF8Iterator(char* begin, char* end): begin(begin), end(end) {}
|
138
|
+
|
139
|
+
UTF8Iterator operator++() {
|
140
|
+
if (!state) {
|
141
|
+
parse();
|
142
|
+
}
|
143
|
+
begin += utf8_len;
|
144
|
+
state = false;
|
145
|
+
return *this;
|
146
|
+
}
|
147
|
+
|
148
|
+
uint32_t operator*() {
|
149
|
+
if (!state) {
|
150
|
+
parse();
|
151
|
+
}
|
152
|
+
return code_point;
|
153
|
+
}
|
154
|
+
|
155
|
+
char* get_ptr() {
|
156
|
+
return begin;
|
157
|
+
}
|
158
|
+
uint64_t get_utf8_len() {
|
159
|
+
return utf8_len;
|
160
|
+
}
|
161
|
+
|
162
|
+
bool empty() {
|
163
|
+
assert(begin <= end);
|
164
|
+
return begin == end;
|
165
|
+
}
|
166
|
+
private:
|
167
|
+
char *begin, *end;
|
168
|
+
uint32_t code_point = 0;
|
169
|
+
uint64_t utf8_len = 0;
|
170
|
+
bool state = false;
|
171
|
+
void parse() {
|
172
|
+
if (state) {
|
173
|
+
return;
|
174
|
+
}
|
175
|
+
assert(!empty());
|
176
|
+
code_point = chars_to_utf8(begin, end - begin, &utf8_len);
|
177
|
+
state = true;
|
178
|
+
}
|
179
|
+
};
|
180
|
+
|
181
|
+
|
182
|
+
// Location of a token-pair occurrence: word index plus node index inside
// that word's linked list. Ordered lexicographically (word, then position).
struct Position {
  uint64_t word_id, pos_id;

  Position(uint64_t word_id, uint64_t pos_id) : word_id(word_id), pos_id(pos_id) {}

  bool operator<(const Position &other) const {
    if (word_id != other.word_id) {
      return word_id < other.word_id;
    }
    return pos_id < other.pos_id;
  }
};
|
192
|
+
|
193
|
+
// Number of non-overlapping equal-token pairs inside a run of `x` identical
// tokens: floor(x / 2). Callers only ask for runs of length >= 2.
int pairsInSeg(int x) {
  assert(x >= 2);
  return x / 2;
}
|
197
|
+
|
198
|
+
// Occurrence list of one token pair together with its aggregate count.
struct PositionsCnt {
  vector<Position> positions;  // every (word, node) where the pair occurs
  uint64_t cnt;                // total weighted number of occurrences
};
|
202
|
+
// True when the just-applied merge rule could have changed the count of the
// pair (new_left, new_right): the rule's right part matches our left token,
// or its left part matches our right token.
bool rule_intersection(BPE_Rule rule, uint32_t new_left, uint32_t new_right) {
  const bool touches_left = (rule.y == new_left);
  const bool touches_right = (rule.x == new_right);
  return touches_left || touches_right;
}
|
205
|
+
|
206
|
+
// Bucket-based priority queue for merge candidates with "small" counts:
// queue[c] holds every candidate whose count is exactly c, and the maximum
// lives in the last non-empty bucket. push is O(1); the maximum bucket is
// found by popping empty buckets off the back.
struct SmallObjectQueue {

  vector<vector<MergeCandidate>> queue;  // queue[count] -> candidates
  bool flag_started{false};  // set after the first top()/empty() call
  uint64_t _size{0};         // total number of stored candidates

  SmallObjectQueue() = default;

  void push(const MergeCandidate &event) {
    if (queue.size() <= event.count) {
      queue.resize(event.count + 1);
    }
    // Once consumption has started, counts may only stay within the current
    // bucket range (candidates never grow).
    if (flag_started) {
      assert(event.count + 1 <= queue.size());
    };
    queue[event.count].push_back(event);
    _size++;
#ifdef DETERMINISTIC_QUEUE
    // Keep the active (topmost) bucket sorted so top() is deterministic.
    if (queue.size() - 1 == event.count && flag_started) {
      sort(queue.back().begin(), queue.back().end());
    }
#endif
  }

  // Drop empty buckets from the back so queue.back() is the max bucket.
  void process_empty_slots() {
#ifdef DETERMINISTIC_QUEUE
    bool moved_down = !flag_started;
#endif
    flag_started = true;

    for (; !queue.empty() && queue.back().empty(); queue.pop_back()) {
#ifdef DETERMINISTIC_QUEUE
      moved_down = true;
#endif
    }
#ifdef DETERMINISTIC_QUEUE
    // A new bucket became the top: sort it once for deterministic order.
    if (moved_down && !queue.empty()) {
      sort(queue.back().begin(), queue.back().end());
    }
#endif
  }

  bool empty() {
    process_empty_slots();
    return queue.empty();
  }

  // Candidate with the maximal count (last element of the top bucket).
  MergeCandidate top() {
    process_empty_slots();
    assert(!queue.empty());
    return queue.back().back();
  }

  // Remove the element returned by the last top(). Callers must have
  // called top()/empty() first so the back bucket is non-empty.
  void pop() {
    assert(!queue.empty());
    assert(!queue.back().empty());
    queue.back().pop_back();
    _size--;
  }

  uint64_t size() const {
    return _size;
  }
};
|
270
|
+
|
271
|
+
// Unordered holding area for merge candidates with "large" counts
// (>= big_event_bound). Counts are refreshed lazily inside top(): stale
// candidates whose count dropped below the bound are demoted into the
// companion SmallObjectQueue.
struct BigObjectQueue {
  vector<MergeCandidate> big_events;
  uint64_t big_event_bound;  // threshold separating big from small counts

  BigObjectQueue(uint64_t big_event_bound) : big_event_bound(big_event_bound) {}

  void push(const MergeCandidate &event) {
    big_events.push_back(event);
  }

  bool empty() const {
    return big_events.empty();
  }

  // Refresh all counts via check_cnt, demote shrunken candidates to
  // *small_object_queue, and place the current maximum at big_events.back().
  // Returns false when nothing big remains; otherwise writes the max to ret.
  // Candidates whose pair intersects last_rule are NOT refreshed here (their
  // counts are being updated concurrently by the workers).
  bool top(std::function<uint64_t(uint64_t)> &check_cnt, MergeCandidate &ret, SmallObjectQueue *small_object_queue,
           BPE_Rule last_rule) {
    for (uint64_t i = 0; i < big_events.size();) {
      if (!rule_intersection(last_rule, big_events[i].left_token, big_events[i].right_token)) {
        uint64_t comb = int2comb(big_events[i].left_token, big_events[i].right_token);
        // Counts only ever decrease between refreshes.
        assert(big_events[i].count >= check_cnt(comb));
        big_events[i].count = check_cnt(comb);
      }

      if (big_events[i].count < big_event_bound) {
        // Demote: move into the small queue, swap-erase from this one.
        small_object_queue->push(big_events[i]);
        big_events[i] = big_events.back();
        big_events.pop_back();
      } else {
        i++;
      }
    }
#ifdef DETERMINISTIC_QUEUE
    sort(big_events.begin(), big_events.end()); /// TODO remove unoptimal code
#else
    // Single max-selection pass: after it the largest candidate is at back().
    for (auto &big_event : big_events) {
      if (big_event.count > big_events.back().count) {
        std::swap(big_event, big_events.back());
      }
    }
#endif

    if (big_events.empty()) {
      return false;
    }
    ret = big_events.back();
    return true;
  }

  // Remove the element selected by the preceding top() call.
  void pop() {
    assert(!big_events.empty());
    big_events.pop_back();
  }

  uint64_t size() const {
    return big_events.size();
  }
};
|
328
|
+
|
329
|
+
// Two-tier priority queue over merge candidates: counts below
// sqrt(dataset_size) live in a bucket queue (SmallObjectQueue), larger ones
// in a lazily-refreshed list (BigObjectQueue). top() and pop() must be
// called as a pair: top() may demote big candidates into the small queue,
// and pop() relies on the max still sitting at the back of whichever queue
// produced it.
struct PriorityQueue {
  SmallObjectQueue small_queue;
  BigObjectQueue big_queue;
  uint64_t big_event_bound;  // = sqrt(dataset_size), the tier threshold

  explicit PriorityQueue(uint64_t dataset_size) : big_queue(static_cast<uint64_t>(sqrt(dataset_size))),
                                                  big_event_bound(static_cast<uint64_t>(sqrt(dataset_size))) {}

  void push(const MergeCandidate &event) {
    // Zero-count candidates can never be the best merge; drop them.
    if (event.count == 0) {
      return;
    }
    if (event.count < big_event_bound) {
      small_queue.push(event);
    } else {
      big_queue.push(event);
    }
  }

  bool empty() {
    return big_queue.empty() && small_queue.empty();
  }

  // Best candidate overall; refreshes big-queue counts via check_cnt first
  // (see BigObjectQueue::top for the last_rule exclusion).
  MergeCandidate top(std::function<uint64_t(uint64_t)> &check_cnt, BPE_Rule last_rule) {
    MergeCandidate res;
    bool has_top = big_queue.top(check_cnt, res, &small_queue, last_rule);
    if (has_top) {
      return res;
    }
    return small_queue.top();
  }

  // Removes the element the preceding top() returned: if the big queue is
  // non-empty, top() came from it; otherwise from the small queue.
  void pop() {
    if (!big_queue.empty()) {
      big_queue.pop();
    } else {
      small_queue.pop();
    }
  }

  uint64_t size() const {
    return big_queue.size() + small_queue.size();
  }
};
|
373
|
+
|
374
|
+
// Builds the character alphabet: assigns ids to every character frequent
// enough to satisfy bpe_config.character_coverage, and records the rest in
// removed_chars. Ids start after the reserved special-token ids; SPACE_TOKEN
// always gets the first id, then characters in decreasing frequency order.
flat_hash_map<uint32_t, uint32_t> compute_alphabet_helper(
    const flat_hash_map<uint32_t, uint64_t> &char_cnt, uint64_t data_len,
    flat_hash_set<uint32_t> &removed_chars, const BpeConfig &bpe_config) {
  // (count, char) pairs sorted ascending, so the rarest come first.
  vector<std::pair<uint64_t, uint32_t>> frequencies;

  for (auto x : char_cnt) {
    frequencies.emplace_back(x.second, x.first);
  }
  sort(frequencies.begin(), frequencies.end());

  // Greedily drop the rarest characters while the remaining characters still
  // cover more than `character_coverage` of the data.
  uint64_t cur = 0;
  uint64_t n_removed = 0;
  for (; cur < frequencies.size() &&
         (data_len - n_removed - frequencies[cur].first) >
             data_len * bpe_config.character_coverage;
       cur++) {
    n_removed += frequencies[cur].first;
  }
  std::cerr << "number of unique characters in the training data: "
            << frequencies.size() << std::endl;
  std::cerr << "number of deleted characters: " << cur << std::endl;
  std::cerr << "number of unique characters left: " << frequencies.size() - cur
            << std::endl;

  flat_hash_map<uint32_t, uint32_t> char2id;
  // Character ids start right after the ids reserved for special tokens.
  uint64_t used_ids = bpe_config.special_tokens.n_special_tokens();
  char2id[SPACE_TOKEN] = used_ids++;

  // frequencies[0 .. cur) are the dropped (too rare) characters.
  for (uint64_t i = 0; i < cur; i++) {
    removed_chars.insert(frequencies[i].second);
  }

  // Assign ids from most to least frequent; whitespace maps to SPACE_TOKEN
  // elsewhere, so space-like characters get no id of their own.
  for (int i = frequencies.size() - 1; i >= static_cast<int>(cur); i--) {
    if (!is_space(frequencies[i].second)) {
      assert(char2id.count(frequencies[i].second) == 0);
      char2id[frequencies[i].second] = used_ids++;
    }
  }
  return char2id;
}
|
414
|
+
|
415
|
+
// Erases from `data` every code point that appears in `removed_chars`
// (classic erase-remove). No-op when nothing was removed from the alphabet.
void remove_rare_chars(vector<uint32_t> &data,
                       const flat_hash_set<uint32_t> &removed_chars) {
  if (removed_chars.empty()) {
    return;
  }
  auto is_removed = [&removed_chars](uint32_t code_point) {
    return removed_chars.count(code_point) != 0;
  };
  data.erase(std::remove_if(data.begin(), data.end(), is_removed), data.end());
}
|
425
|
+
|
426
|
+
char* remove_rare_chars(char* begin, char* end, const flat_hash_set<uint32_t> &removed_chars) {
|
427
|
+
if (removed_chars.empty()) {
|
428
|
+
return end;
|
429
|
+
}
|
430
|
+
char* end_candidate = begin;
|
431
|
+
bool invalid_input = false;
|
432
|
+
UTF8Iterator utf8_iter(begin, end);
|
433
|
+
for (; !utf8_iter.empty(); ++utf8_iter) {
|
434
|
+
if (*utf8_iter != INVALID_UNICODE) {
|
435
|
+
if (removed_chars.count(*utf8_iter) == 0) {
|
436
|
+
char* token_begin = utf8_iter.get_ptr();
|
437
|
+
uint64_t token_len = utf8_iter.get_utf8_len();
|
438
|
+
memcpy(end_candidate, token_begin, token_len);
|
439
|
+
end_candidate += token_len;
|
440
|
+
}
|
441
|
+
} else {
|
442
|
+
invalid_input = true;
|
443
|
+
}
|
444
|
+
}
|
445
|
+
if (invalid_input) {
|
446
|
+
std::cerr << "WARNING Input contains invalid unicode characters." << std::endl;
|
447
|
+
}
|
448
|
+
return end_candidate;
|
449
|
+
}
|
450
|
+
|
451
|
+
// A word (as a sequence of character token ids, space-token prefixed) and
// how many times it occurs in the training data.
struct WordCount {
  vector<uint32_t> word;  // token-id sequence for the word
  uint64_t cnt;           // occurrence count in the corpus
};
|
455
|
+
|
456
|
+
|
457
|
+
// Splits the UTF-8 text [sbegin, send) into whitespace-separated words and
// counts each distinct word. Keys are VectorSegment views into the original
// buffer (so the buffer must stay alive); a word is decoded into token ids
// (prefixed with the SPACE_TOKEN id) only on its first occurrence.
flat_hash_map<VectorSegment, WordCount> compute_word_count(
    char* sbegin, char* send,
    const flat_hash_map<uint32_t, uint32_t> &char2id) {
  flat_hash_map<VectorSegment, WordCount> hash2wordcnt;
  vector<uint32_t> word;  // scratch buffer, reused across words
  UTF8Iterator utf8_iter(sbegin, send);

  for (;!utf8_iter.empty();) {
    // Skip the whitespace run before the next word.
    for (; !utf8_iter.empty() && is_space(*utf8_iter); ++utf8_iter);
    if (utf8_iter.empty()) {
      break;
    }
    // Scan to the end of the word: [begin_of_word, end_of_word).
    char* begin_of_word = utf8_iter.get_ptr();
    for (; !utf8_iter.empty() && !is_space(*utf8_iter); ++utf8_iter);
    char* end_of_word = utf8_iter.get_ptr();
    VectorSegment word_hash(begin_of_word, end_of_word);
    auto it = hash2wordcnt.find(word_hash);
    if (it == hash2wordcnt.end()) {
      // First occurrence: decode the word into token ids once.
      word.clear();
      word.push_back(char2id.at(SPACE_TOKEN));
      UTF8Iterator word_iter(begin_of_word, end_of_word);
      for (; !word_iter.empty(); ++word_iter) {
        word.push_back(char2id.at(*word_iter));
      }
      hash2wordcnt[word_hash] = {word, 1};
    } else {
      it->second.cnt++;
    }
  }
  return hash2wordcnt;
}
|
488
|
+
|
489
|
+
|
490
|
+
// Node of the doubly linked list representing one word during training.
// Each node stores a run of `seg_len` consecutive occurrences of the same
// token `val` (run-length encoding of repeated tokens).
struct NodeEncoder {
  uint32_t val;  // token id; 0 marks a dead (merged-away) node
  int prev;      // index of the previous live node in the word, -1 if none
  int next;      // index of the next live node in the word, -1 if none
  int seg_len;   // run length; 0 marks a dead node

  NodeEncoder(uint32_t val, int prev, int next, int seg_len)
      : val(val), prev(prev), next(next), seg_len(seg_len) {}

  // A node is alive iff it still holds a token; dead nodes have both val
  // and seg_len zeroed (invariant checked in debug builds).
  bool is_alive() const {
    assert((val == 0) == (seg_len == 0));
    return val != 0;
  }
};
|
504
|
+
|
505
|
+
// Converts each counted word into a run-length-encoded doubly linked list of
// NodeEncoder nodes (list[i]), and seeds the pair statistics:
//   pair2pos[comb] — every (word, node) position where the pair occurs,
//   pair2cnt[comb] — total weighted occurrence count of the pair,
// where comb packs the two token ids via int2comb. A run of k identical
// tokens contributes floor(k/2) non-overlapping (t, t) self-pairs.
void build_linked_list(const vector<WordCount> &word_cnt,
                       vector<vector<NodeEncoder>> &list,
                       flat_hash_map<uint64_t, vector<Position>> &pair2pos,
                       flat_hash_map<uint64_t, uint64_t> &pair2cnt) {
  list.resize(word_cnt.size());
  for (uint64_t i = 0; i < word_cnt.size(); i++) {
    // Run-length encode the word: repeated tokens extend the last node.
    for (uint32_t ch : word_cnt[i].word) {
      if (!list[i].empty() && list[i].back().val == ch) {
        list[i].back().seg_len++;
      } else {
        int list_size = list[i].size();
        list[i].emplace_back(ch, list_size - 1, list_size + 1, 1);
      }
    }

    // Terminate the list (the loop above left next pointing past the end).
    list[i].back().next = -1;
    for (uint64_t j = 0; j < list[i].size(); j++) {
      // Cross-node pair between node j and its successor.
      if (j + 1 < list[i].size()) {
        uint64_t comb = int2comb(list[i][j].val, list[i][j + 1].val);
        auto it = pair2pos.find(comb);
        if (it == pair2pos.end()) {
          pair2pos[comb] = {{i, j}};
        } else {
          it->second.emplace_back(i, j);
        }
        pair2cnt[comb] += word_cnt[i].cnt;
      }
      assert(list[i][j].seg_len >= 1);

      // Self-pair inside a run of length >= 2: floor(len/2) pairs, weighted
      // by the word's occurrence count.
      if (list[i][j].seg_len > 1) {
        uint64_t comb = int2comb(list[i][j].val, list[i][j].val);
        auto it = pair2pos.find(comb);
        uint64_t cc = word_cnt[i].cnt * pairsInSeg(list[i][j].seg_len);
        if (it == pair2pos.end()) {
          pair2pos[comb] = {{i, j}};
        } else {
          it->second.emplace_back(i, j);
        }
        pair2cnt[comb] += cc;
      }
    }
  }
}
|
548
|
+
|
549
|
+
std::chrono::steady_clock::time_point last_time_stamp;
|
550
|
+
|
551
|
+
void time_check(const string &message) {
|
552
|
+
auto cur_moment = std::chrono::steady_clock::now();
|
553
|
+
if (!message.empty()) {
|
554
|
+
std::cerr << "## time " << message << " ... "
|
555
|
+
<< std::chrono::duration_cast<std::chrono::microseconds>(
|
556
|
+
cur_moment - last_time_stamp)
|
557
|
+
.count() *
|
558
|
+
1.0 / 1e6
|
559
|
+
<< std::endl;
|
560
|
+
}
|
561
|
+
last_time_stamp = cur_moment;
|
562
|
+
}
|
563
|
+
|
564
|
+
double time_check_silent() {
|
565
|
+
auto cur_moment = std::chrono::steady_clock::now();
|
566
|
+
double ret = std::chrono::duration_cast<std::chrono::microseconds>(
|
567
|
+
cur_moment - last_time_stamp)
|
568
|
+
.count() *
|
569
|
+
1.0 / 1e6;
|
570
|
+
last_time_stamp = cur_moment;
|
571
|
+
return ret;
|
572
|
+
}
|
573
|
+
|
574
|
+
// Global count of still-alive tokens; not read or written anywhere in the
// visible portion of this file — presumably maintained by the training
// loop further down. TODO(review): confirm usage and consider removing.
int alive_tokens;
|
575
|
+
|
576
|
+
void init_recipe(const flat_hash_map<uint32_t, uint32_t> &char2id,
|
577
|
+
flat_hash_map<uint32_t, vector<uint32_t>> &recipe,
|
578
|
+
flat_hash_map<uint32_t, string> &recipe_s) {
|
579
|
+
for (auto token_id : char2id) {
|
580
|
+
uint32_t ch = token_id.first;
|
581
|
+
uint32_t id = token_id.second;
|
582
|
+
recipe[id] = {id};
|
583
|
+
recipe_s[id] = encode_utf8({ch});
|
584
|
+
}
|
585
|
+
}
|
586
|
+
|
587
|
+
// Training worker thread body. Each worker owns a shard of the words
// (its slice of lists_of_tokens, plus its own pair2pos and
// pair2cnt_g[thread_id]) and repeatedly applies the merge rule the main
// loop publishes in task_order[rule % 2]. After applying a rule
// x + y -> z it reports, for every neighbor token t it touched, the new
// counts of (t, z) and (z, t) pairs via left_tokens_submit /
// right_tokens_submit, then signals completion through results_ready and
// main_loop_cv.
//
// Synchronization (as visible here): the worker sleeps on cv[thread_id]
// until the double-buffered task slot holds its next rule id, and inside
// the merge loops additionally waits for thread_use_hs[thread_id] before
// touching shared hash maps. It exits when cur_token_rule reaches
// real_n_tokens.
void worker_doing_merge(
    uint64_t thread_id, vector<vector<NodeEncoder>> &lists_of_tokens,
    vector<flat_hash_map<uint64_t, uint64_t>> &pair2cnt_g,
    flat_hash_map<uint64_t, vector<Position>> &pair2pos,
    vector<uint64_t> &word_freq, vector<std::mutex> &mt,
    vector<std::condition_variable> &cv, vector<BPE_Rule> &task_order,
    vector<std::atomic_bool> &thread_use_hs,
    flat_hash_map<uint32_t, uint32_t> &char2id,
    vector<vector<flat_hash_map<uint32_t, uint64_t>>> &left_tokens_submit,
    vector<vector<flat_hash_map<uint32_t, uint64_t>>> &right_tokens_submit,
    std::atomic<uint32_t> &real_n_tokens,
    vector<std::atomic<uint32_t>> &results_ready, const BpeConfig &bpe_config,
    std::mutex &main_loop_mt, std::condition_variable &main_loop_cv) {
  auto &pair2cnt = pair2cnt_g[thread_id];
  flat_hash_set<uint32_t> left_tokens;   // tokens seen to the left of new z
  flat_hash_set<uint32_t> right_tokens;  // tokens seen to the right of new z

  // First rule id to process: ids below this belong to the alphabet and
  // special tokens.
  uint32_t cur_token_rule =
      char2id.size() + bpe_config.special_tokens.n_special_tokens();
  // Packed key for the pair (node p1, its successor).
  auto get_pair_code = [&](uint64_t word_id, uint64_t p1) {
    int p2 = lists_of_tokens[word_id][p1].next;
    return int2comb(lists_of_tokens[word_id][p1].val,
                    lists_of_tokens[word_id][p2].val);
  };

  // Packed key for the self-pair (val, val) of node p1.
  auto get_self_code = [&](uint64_t word_id, uint64_t p1) {
    return int2comb(lists_of_tokens[word_id][p1].val,
                    lists_of_tokens[word_id][p1].val);
  };

  // Decrease the count of the pair starting at pos_id by this word's freq.
  auto remove_pair = [&](int word_id, int pos_id) {
    pair2cnt[get_pair_code(word_id, pos_id)] -= word_freq[word_id];
  };

  // Register a new occurrence of the pair starting at pos_id and bump its
  // count by the word's frequency.
  auto add_pair = [&](uint64_t word_id, uint64_t pos_id) {
    uint64_t comb = get_pair_code(word_id, pos_id);
    auto it = pair2pos.find(comb);
    if (it == pair2pos.end()) {
      pair2pos[comb] = {{word_id, pos_id}};
    } else {
      it->second.emplace_back(word_id, pos_id);
    }
    pair2cnt[comb] += word_freq[word_id];
  };

  // Record a position for a pair whose count is already accounted for.
  auto add_empty_pair = [&](uint64_t word_id, uint64_t pos_id) {
    auto it = pair2pos.find(get_pair_code(word_id, pos_id));
    assert(it != pair2pos.end());
    it->second.emplace_back(word_id, pos_id);
  };

  // Register the self-pairs of a run of identical tokens: a run of length
  // L contributes floor(L/2) pairs.
  auto add_self_pair = [&](uint64_t word_id, uint64_t pos_id) {
    int seg_len = lists_of_tokens[word_id][pos_id].seg_len;
    assert(seg_len >= 2);
    uint64_t comb = get_self_code(word_id, pos_id);
    auto it = pair2pos.find(comb);
    uint64_t real_cnt = word_freq[word_id] * pairsInSeg(seg_len);
    if (it == pair2pos.end()) {
      pair2pos[comb] = {{word_id, pos_id}};
      assert(pair2pos[comb].size() == 1);
    } else {
      it->second.emplace_back(word_id, pos_id);
    }
    pair2cnt[comb] += real_cnt;
  };

  // After two runs coalesce, the number of self-pairs can shrink; subtract
  // the difference from the pair count.
  auto add_merge_compensation = [&](uint64_t word_id, uint64_t pos_id,
                                    int score_diff) {
    assert(score_diff > 0);
    uint64_t comb = get_self_code(word_id, pos_id);
    pair2cnt[comb] -= score_diff * word_freq[word_id];
  };

  // Shrink a run by one token; when the run length goes from even to odd,
  // it loses one self-pair.
  auto seg_len_decrement = [&](uint64_t word_id, uint64_t pos_id) {
    int seg_len = lists_of_tokens[word_id][pos_id].seg_len;
    assert(seg_len >= 2);
    lists_of_tokens[word_id][pos_id].seg_len--;
    if (seg_len % 2 == 1) {
      return;
    }
    uint64_t comb = get_self_code(word_id, pos_id);
    pair2cnt[comb] -= word_freq[word_id];
  };

  // Remove all self-pairs contributed by this run.
  auto self_full_remove = [&](uint64_t word_id, uint64_t pos_id) {
    uint64_t comb = get_self_code(word_id, pos_id);
    uint32_t real_cnt = word_freq[word_id] *
                        pairsInSeg(lists_of_tokens[word_id][pos_id].seg_len);
    pair2cnt[comb] -= real_cnt;
  };

  // If two adjacent nodes now hold the same token, coalesce them into one
  // run, fixing self-pair counts and relinking the list.
  auto try_merge = [&](uint64_t word_id, uint64_t pos1, uint64_t pos2) {
    vector<NodeEncoder> &cur_list = lists_of_tokens[word_id];
    if (cur_list[pos1].val == cur_list[pos2].val) {
      int score_before =
          (cur_list[pos1].seg_len / 2) + (cur_list[pos2].seg_len / 2) + 1;
      cur_list[pos1].seg_len += cur_list[pos2].seg_len;
      int score_after = cur_list[pos1].seg_len / 2;
      if (score_before != score_after) {
        add_merge_compensation(word_id, pos1, score_before - score_after);
      }

      cur_list[pos1].next = cur_list[pos2].next;
      cur_list[pos2] = {0, -1, -1, 0};  // kill the absorbed node
      if (cur_list[pos1].next != -1) {
        cur_list[cur_list[pos1].next].prev = pos1;
        add_empty_pair(word_id, pos1);
      }
    }
  };
  while (true) {
    {
      // Wait for the main loop to publish our next rule (double-buffered in
      // task_order[rule % 2]) or to announce termination.
      std::unique_lock<std::mutex> ul(mt[thread_id]);
      cv[thread_id].wait(ul, [&] {
        return task_order[cur_token_rule % 2].z == cur_token_rule ||
               cur_token_rule >= real_n_tokens;
      });
      assert(cur_token_rule <= real_n_tokens);
      if (cur_token_rule == real_n_tokens) {
        break;
      }
    }

    // Current rule: merge x followed by y into the new token z.
    uint32_t x = task_order[cur_token_rule % 2].x;
    uint32_t y = task_order[cur_token_rule % 2].y;
    uint32_t z = task_order[cur_token_rule % 2].z;

    left_tokens.clear();
    right_tokens.clear();

    left_tokens.insert(z);
    int real_merge = 0;      // occurrences actually merged
    int not_real_merge = 0;  // stale positions skipped

    if (x == y) {
      // Self-merge: the pair lives inside a single run node.
      const vector<Position> &merge_candidates = pair2pos[int2comb(x, y)];

      std::unique_lock<std::mutex> lk(mt[thread_id]);
      for (auto word_pos : merge_candidates) {
        cv[thread_id].wait(lk, [&] { return thread_use_hs[thread_id].load(); });
        not_real_merge++;

        // p0 <-> p1 <-> p3 -- ids of nodes in linked list.
        // merge will happen inside p1

        int word_id = word_pos.word_id;
        vector<NodeEncoder> &cur_list = lists_of_tokens[word_id];
        int p1 = word_pos.pos_id;
        // Skip stale positions: the node may have been rewritten since this
        // occurrence was recorded.
        if (cur_list[p1].val != x || cur_list[p1].seg_len < 2) {
          continue;
        }
        real_merge++;

        int p0 = cur_list[p1].prev;
        int p3 = cur_list[p1].next;

        // Retract all counts the old configuration contributed.
        self_full_remove(word_id, p1);
        if (p0 != -1) {
          remove_pair(word_id, p0);
        }
        if (p3 != -1) {
          remove_pair(word_id, p1);
        }
        int seg_len = cur_list[p1].seg_len;
        if (seg_len % 2 == 0) {
          // Even run: x^seg_len becomes z^(seg_len/2) in place.
          cur_list[p1] = {z, p0, p3, seg_len / 2};

          if (p0 != -1) {
            add_pair(word_id, p0);
            left_tokens.insert(cur_list[p0].val);
          }

          if (p3 != -1) {
            add_pair(word_id, p1);
            right_tokens.insert(cur_list[p3].val);
          }
          if (seg_len / 2 >= 2) {
            add_self_pair(word_id, p1);
          }
        } else {
          // Odd run: one trailing x is left over in a fresh node p2.
          cur_list.emplace_back(x, p1, p3, 1);
          int p2 = static_cast<int>(cur_list.size() - 1);
          cur_list[p1] = {z, p0, p2, seg_len / 2};
          if (p0 != -1) {
            add_pair(word_id, p0);
            left_tokens.insert(cur_list[p0].val);
          }

          add_pair(word_id, p1);
          right_tokens.insert(cur_list[p2].val);

          if (p3 != -1) {
            cur_list[p3].prev = p2;
            add_pair(word_id, p2);
          }

          if (seg_len / 2 >= 2) {
            add_self_pair(word_id, p1);
          }
        }
      }
    } else {
      // Distinct tokens: the pair spans two adjacent nodes p1, p2.
      std::unique_lock<std::mutex> lk(mt[thread_id]);
      for (auto word_pos : pair2pos[int2comb(x, y)]) {
        not_real_merge++;
        cv[thread_id].wait(lk, [&] { return thread_use_hs[thread_id].load(); });
        // p0 <-> p1 <-> p2 <-> p3 -- ids of nodes in linked list.
        // merge will happen between p1 p2
        int word_id = word_pos.word_id;

        int p1 = word_pos.pos_id;
        vector<NodeEncoder> &cur_list = lists_of_tokens[word_id];
        int p2 = cur_list[p1].next;
        // Skip stale positions recorded before earlier merges.
        if (cur_list[p1].val != x || p2 == -1 || cur_list[p2].val != y) {
          continue;
        }
        real_merge++;

        int p0 = cur_list[p1].prev;
        int p3 = cur_list[p2].next;
        // Retract counts of pairs destroyed by the merge. Boundary pairs
        // (p0,p1) / (p2,p3) only die if the whole run is consumed.
        remove_pair(word_id, p1);
        if (p0 != -1 && cur_list[p1].seg_len == 1) {
          remove_pair(word_id, p0);
        }
        if (p3 != -1 && cur_list[p2].seg_len == 1) {
          remove_pair(word_id, p2);
        }

        // Four cases by whether x's run and y's run are fully consumed.
        if (cur_list[p1].seg_len > 1 && cur_list[p2].seg_len > 1) {
          // Both runs survive: insert a new z node between them.
          cur_list.emplace_back(z, p1, p2, 1);
          int p12 = static_cast<int>(cur_list.size() - 1);

          seg_len_decrement(word_id, p1);
          seg_len_decrement(word_id, p2);

          cur_list[p1].next = p12;
          cur_list[p2].prev = p12;

          add_pair(word_id, p1);
          left_tokens.insert(cur_list[p1].val);

          add_pair(word_id, p12);
          right_tokens.insert(cur_list[p2].val);
        } else if (cur_list[p1].seg_len > 1 && cur_list[p2].seg_len == 1) {
          // y's run consumed: p2 becomes the z node.
          cur_list[p2] = {z, p1, p3, 1};

          seg_len_decrement(word_id, p1);

          add_pair(word_id, p1);
          left_tokens.insert(cur_list[p1].val);

          if (p3 != -1) {
            add_pair(word_id, p2);
            right_tokens.insert(cur_list[p3].val);
            try_merge(word_id, p2, p3);
          }
        } else if (cur_list[p1].seg_len == 1 && cur_list[p2].seg_len > 1) {
          // x's run consumed: p1 becomes the z node.
          cur_list[p1] = {z, p0, p2, 1};

          seg_len_decrement(word_id, p2);

          if (p0 != -1) {
            add_pair(word_id, p0);
            left_tokens.insert(cur_list[p0].val);
          }

          add_pair(word_id, p1);
          right_tokens.insert(cur_list[p2].val);
          if (p0 != -1) {
            try_merge(word_id, p0, p1);
          }
        } else {
          // Both runs consumed: p1 becomes z, p2 dies, list is relinked.
          assert(cur_list[p1].seg_len == 1 && cur_list[p2].seg_len == 1);

          cur_list[p1] = {z, p0, p3, 1};
          cur_list[p2] = {0, -1, -1, 0};
          if (p3 != -1) {
            cur_list[p3].prev = p1;
          }

          if (p0 != -1) {
            add_pair(word_id, p0);
            left_tokens.insert(cur_list[p0].val);
          }
          if (p3 != -1) {
            add_pair(word_id, p1);
            right_tokens.insert(cur_list[p3].val);
          }
          if (p0 != -1) {
            try_merge(word_id, p0, p1);
          }
          if (p3 != -1) {
            try_merge(word_id, p1, p3);
          }
        }
      }
    }
    // The merged pair no longer exists anywhere in this shard.
    pair2pos.erase(int2comb(x, y));
    {
      // Publish this shard's counts for all pairs involving the new token z.
      std::unique_lock<std::mutex> lk(mt[thread_id]);

      left_tokens_submit[cur_token_rule % 2][thread_id].clear();
      right_tokens_submit[cur_token_rule % 2][thread_id].clear();

      for (auto token : left_tokens) {
        left_tokens_submit[cur_token_rule % 2][thread_id][token] =
            pair2cnt[int2comb(token, z)];
      }

      for (auto token : right_tokens) {
        right_tokens_submit[cur_token_rule % 2][thread_id][token] =
            pair2cnt[int2comb(z, token)];
      }
    }
    {
      // Tell the main loop this rule is done for this shard.
      std::lock_guard<std::mutex> lg(main_loop_mt);
      results_ready[thread_id] = cur_token_rule;
    }
    main_loop_cv.notify_one();
    cur_token_rule++;
  }
}
|
909
|
+
|
910
|
+
// Remaps provisional token ids to their final ids so that special tokens end
// up on the ids reserved for them. During training, alphabet/merge tokens are
// numbered starting right after the special-token block (`cur` starts at
// n_special_tokens()); the final numbering instead interleaves them with the
// special tokens, skipping every id a special token has taken. The renaming
// table maps provisional id -> final id and is applied in place to both the
// alphabet map and every merge rule.
void rename_tokens(flat_hash_map<uint32_t, uint32_t> &char2id,
                   vector<BPE_Rule> &rules, const SpecialTokens &special_tokens,
                   uint32_t n_tokens) {
  flat_hash_map<uint32_t, uint32_t> renaming;
  // `cur` enumerates provisional ids; `i` enumerates candidate final ids,
  // skipping the ones claimed by special tokens.
  uint32_t cur = special_tokens.n_special_tokens();
  for (uint32_t i = 0; i < n_tokens; i++) {
    if (!special_tokens.taken_id(i)) {
      renaming[cur++] = i;
    }
  }
  // Rewrite alphabet ids. Every mapped id must have a renaming entry.
  for (auto &node : char2id) {
    assert(renaming.count(node.second));
    node.second = renaming[node.second];
  }

  // Rewrite the ids inside every merge rule (x + y -> z).
  for (auto &rule : rules) {
    assert(renaming.count(rule.x));
    assert(renaming.count(rule.y));
    assert(renaming.count(rule.z));
    rule.x = renaming[rule.x];
    rule.y = renaming[rule.y];
    rule.z = renaming[rule.z];
  }
}
|
934
|
+
|
935
|
+
// Tallies per-codepoint frequencies for the UTF-8 byte range [begin, end).
// Whitespace codepoints are decoded but not counted into char_cnt; malformed
// sequences (INVALID_UNICODE) are tolerated and trigger a single warning on
// stderr. Returns the total number of decoded codepoints, valid or not.
uint64_t compute_char_count(flat_hash_map<uint32_t, uint64_t>& char_cnt, char* begin, char* end) {
  uint64_t total_chars = 0;
  bool saw_invalid = false;
  UTF8Iterator utf8_iter(begin, end);
  while (!utf8_iter.empty()) {
    uint32_t code_point = *utf8_iter;
    if (code_point == INVALID_UNICODE) {
      saw_invalid = true;
    } else if (!is_space(code_point)) {
      char_cnt[code_point]++;
    }
    total_chars++;
    ++utf8_iter;
  }
  if (saw_invalid) {
    std::cerr << "WARNING Input contains invalid unicode characters."
              << std::endl;
  }
  return total_chars;
}
|
954
|
+
|
955
|
+
// Learns a BPE model from an in-memory UTF-8 corpus using n_threads workers.
//
// The function runs a hand-rolled fork/join protocol: the main thread and the
// workers alternate via per-thread mutex/condvar pairs (thread_awake_main /
// thread_wait_main on the worker side, main_wait_threads / main_awake_threads
// on the main side). Phases:
//   1. each worker counts characters on its text slice; main merges the counts
//      and computes the alphabet (rare chars go to removed_chars);
//   2. each worker strips rare chars and builds per-slice word counts; main
//      merges them into word_cnt_global and seeds the merge priority queue;
//   3. each worker builds linked-list token representations; then workers run
//      worker_doing_merge while main drives the merge loop below.
// On success writes the model to output_file and fills *bpe_state.
Status learn_bpe_from_string(string &text_utf8, int n_tokens,
                             const string &output_file,
                             BpeConfig bpe_config, BPEState *bpe_state) {
  assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1);
  uint64_t n_threads = bpe_config.n_threads;
  // Split the corpus into n_threads chunks, nudging each boundary forward to
  // the next whitespace so no word is cut between threads.
  vector<uint64_t> split_pos;
  split_pos.push_back(0);
  for (uint64_t i = 1; i <= n_threads; i++) {
    uint64_t candidate = text_utf8.size() * i / n_threads;
    for (; candidate < text_utf8.size() && !is_space(text_utf8[candidate]);
         candidate++) {
    }

    split_pos.push_back(candidate);
  }

  vector<flat_hash_map<uint32_t, uint64_t>> shared_char_cnt(n_threads);

  // Per-thread handshake state for the fork/join protocol described above.
  vector<std::mutex> mt(n_threads);
  vector<std::condition_variable> cv(n_threads);
  vector<char> thread_finished(n_threads, 0);
  vector<char> main_finished(n_threads, 0);
  vector<uint64_t> text_len(n_threads);

  flat_hash_set<uint32_t> removed_chars;
  flat_hash_map<uint32_t, uint32_t> char2id;

  vector<flat_hash_map<VectorSegment, WordCount>> hash2wordcnt(n_threads);
  int error_flag = 0;

  // recipe[z] = sequence of alphabet ids token z expands to;
  // recipe_s[z] = the same as a printable string (for progress logging).
  flat_hash_map<uint32_t, vector<uint32_t>> recipe;
  flat_hash_map<uint32_t, string> recipe_s;
  vector<flat_hash_map<uint64_t, uint64_t>> pair2cnt_g(n_threads);
  PriorityQueue merge_order(1);
  vector<uint64_t> split_word_cnt;
  vector<WordCount> word_cnt_global;

  // Unpack a packed pair key (see int2comb) back into its two token ids.
  auto comb2int = [](uint64_t a, uint32_t &b, uint32_t &c) {
    b = static_cast<uint32_t>(a >> 32u);
    c = static_cast<uint32_t>(a & UINT32_MAX);
  };

  // Double-buffered (index = rule parity) per-thread result maps: counts of
  // pairs formed to the left/right of the newly created token.
  vector<vector<flat_hash_map<uint32_t, uint64_t>>> left_tokens_submit(
      2, vector<flat_hash_map<uint32_t, uint64_t>>(n_threads));
  vector<vector<flat_hash_map<uint32_t, uint64_t>>> right_tokens_submit(
      2, vector<flat_hash_map<uint32_t, uint64_t>>(n_threads));
  vector<std::atomic<uint32_t>> results_ready(n_threads);
  for (uint64_t i = 0; i < n_threads; i++) {
    results_ready[i] = 0;
  }

  vector<char> thread_stopped(n_threads);
  vector<char> thread_ready_to_run(n_threads);
  vector<std::atomic_bool> thread_use_hs(n_threads);
  vector<BPE_Rule> task_order(2);  // double-buffered rules handed to workers

  std::atomic<uint32_t> real_n_tokens(n_tokens);

  std::mutex main_loop_mt;
  std::condition_variable main_loop_cv;

  vector<std::thread> threads;
  for (uint64_t i = 0; i < n_threads; i++) {
    threads.emplace_back(
        [&](uint64_t thread_id) {
          // threads are working 1

          // Signal the main thread that this worker finished its phase.
          auto thread_awake_main = [&]() {
            {
              std::lock_guard<std::mutex> lk(mt[thread_id]);
              thread_finished[thread_id] = 1;
            }
            cv[thread_id].notify_one();
          };

          // Block until the main thread releases this worker for the next phase.
          auto thread_wait_main = [&]() {
            std::unique_lock<std::mutex> lk(mt[thread_id]);
            cv[thread_id].wait(lk, [&] { return main_finished[thread_id]; });
            main_finished[thread_id] = 0;
          };

          // Phase 1: character statistics on this worker's slice.
          flat_hash_map<uint32_t, uint64_t> char_cnt;
          uint64_t char_count = compute_char_count(char_cnt, &text_utf8[0] + split_pos[thread_id], &text_utf8[0] + split_pos[thread_id + 1]);
          text_len[thread_id] = char_count;
          shared_char_cnt[thread_id] = char_cnt;

          thread_awake_main();
          // main is working 1
          thread_wait_main();
          // threads are working 2
          // Phase 2: drop rare chars (decided by main) and count words.
          char* seg_begin = &text_utf8[0] + split_pos[thread_id];
          char* seg_end = &text_utf8[0] + split_pos[thread_id + 1];
          char* new_seg_end = remove_rare_chars(seg_begin, seg_end, removed_chars);
          seg_end = new_seg_end;

          hash2wordcnt[thread_id] = compute_word_count(seg_begin, seg_end, char2id);

          thread_awake_main();
          // main is working 2
          thread_wait_main();
          // threads are working 3
          if (error_flag != 0) {
            return;  // main detected a fatal config error; bail out
          }

          // Phase 3: linked-list encoding of this worker's share of the
          // global word list, plus the pair->positions index used by merging.
          flat_hash_map<uint64_t, vector<Position>> pair2pos;
          vector<vector<NodeEncoder>> lists_of_tokens;
          vector<uint64_t> word_freq;
          build_linked_list(
              {word_cnt_global.begin() + split_word_cnt[thread_id],
               word_cnt_global.begin() + split_word_cnt[thread_id + 1]},
              lists_of_tokens, pair2pos, pair2cnt_g[thread_id]);

          std::transform(
              word_cnt_global.begin() + split_word_cnt[thread_id],
              word_cnt_global.begin() + split_word_cnt[thread_id + 1],
              std::back_inserter(word_freq),
              [](const WordCount &x) { return x.cnt; });

          thread_awake_main();
          // main is working 3
          // threads are working 4

          // Phase 4: apply merge rules streamed from the main loop below.
          worker_doing_merge(thread_id, lists_of_tokens, pair2cnt_g, pair2pos,
                             word_freq, mt, cv, task_order, thread_use_hs,
                             char2id, left_tokens_submit, right_tokens_submit,
                             real_n_tokens, results_ready, bpe_config,
                             main_loop_mt, main_loop_cv);
        },
        i);
  }

  // Wait until every worker reports the end of its current phase.
  auto main_wait_threads = [&]() {
    for (uint64_t i = 0; i < n_threads; i++) {
      std::unique_lock<std::mutex> lk(mt[i]);
      cv[i].wait(lk, [&] { return thread_finished[i]; });
      thread_finished[i] = 0;
    }
  };

  // Release every worker into its next phase.
  auto main_awake_threads = [&]() {
    for (uint64_t i = 0; i < n_threads; i++) {
      std::lock_guard<std::mutex> lk(mt[i]);
      main_finished[i] = 1;
    }
    for (uint64_t i = 0; i < n_threads; i++) {
      cv[i].notify_one();
    }
  };

  main_wait_threads();

  // main is working 1
  // Fold per-thread character statistics into slot 0.
  for (uint64_t i = 1; i < n_threads; i++) {
    for (auto x : shared_char_cnt[i]) {
      shared_char_cnt[0][x.first] += x.second;
    }
    text_len[0] += text_len[i];
  }

  char2id = compute_alphabet_helper(shared_char_cnt[0], text_len[0],
                                    removed_chars, bpe_config);

  main_awake_threads();
  // threads are working 2

  main_wait_threads();
  // main is working 2

  // Merge per-thread word counts into slot 0 (same word may appear in
  // several slices; counts are summed).
  for (uint64_t i = 1; i < n_threads; i++) {
    for (const auto &x : hash2wordcnt[i]) {
      auto it = hash2wordcnt[0].find(x.first);
      if (it == hash2wordcnt[0].end()) {
        hash2wordcnt[0][x.first] = x.second;
      } else {
        it->second.cnt += x.second.cnt;
      }
    }
    hash2wordcnt[i].clear();
  }

  word_cnt_global.resize(hash2wordcnt[0].size());
  std::transform(
      hash2wordcnt[0].begin(), hash2wordcnt[0].end(), word_cnt_global.begin(),
      [](const std::pair<VectorSegment, WordCount> &x) { return x.second; });

  hash2wordcnt.shrink_to_fit();
  text_utf8.shrink_to_fit();

  merge_order = PriorityQueue(text_len[0]);

  // ids consumed so far: one per alphabet char plus the special tokens.
  uint64_t used_ids =
      char2id.size() + bpe_config.special_tokens.n_special_tokens();
  if (used_ids > (uint64_t) n_tokens) {
    string error_message = "Incorrect arguments. Vocabulary size too small. Set vocab_size>=";
    error_message += std::to_string(used_ids) + ". Current value for vocab_size=" + std::to_string(n_tokens);
    // Propagate the failure to the workers (they check error_flag in phase 3)
    // and join them before returning.
    error_flag = 1;
    main_awake_threads();
    for (auto &t : threads) {
      t.join();
    }
    return Status(1, error_message);
  }

  init_recipe(char2id, recipe, recipe_s);

  // Partition the global word list across workers for phase 3.
  split_word_cnt.push_back(0);
  for (uint64_t i = 1; i <= n_threads; i++) {
    split_word_cnt.push_back(word_cnt_global.size() * i / n_threads);
  }

  main_awake_threads();
  // threads are working 3

  main_wait_threads();
  // main is working 3
  // Seed the merge priority queue with global pair counts.
  flat_hash_map<uint64_t, uint64_t> real_pair_cnt;

  for (uint64_t i = 0; i < n_threads; i++) {
    for (const auto &x : pair2cnt_g[i]) {
      real_pair_cnt[x.first] += x.second;
    }
  }

  for (const auto &x : real_pair_cnt) {
    uint32_t ka, kb;
    comb2int(x.first, ka, kb);
    merge_order.push({x.second, ka, kb});
  }
  vector<BPE_Rule> rules;

  // Concatenate the expansions of x and y to get the expansion of x+y.
  auto get_recipe = [&](uint32_t x, uint32_t y) {
    assert(recipe.count(x));
    assert(recipe.count(y));
    vector<uint32_t> new_recipe = recipe[x];
    new_recipe.insert(new_recipe.end(), recipe[y].begin(), recipe[y].end());
    return new_recipe;
  };

  // Current true count of a packed pair = sum over the workers' local maps.
  std::function<uint64_t(uint64_t)> check_cnt = [&](uint64_t mask) {
    uint64_t ret = 0;
    for (uint64_t i = 0; i < n_threads; i++) {
      auto it = pair2cnt_g[i].find(mask);
      if (it != pair2cnt_g[i].end()) {
        ret += it->second;
      }
    }
    return ret;
  };

  // finished_cur trails used_ids by at most 2: rules are pipelined, a rule is
  // "finished" once every worker has submitted its count updates for it.
  uint64_t finished_cur = used_ids;
  uint64_t last_failed_try = 0;

  flat_hash_map<uint32_t, uint64_t> all_res;
  vector<char> local_check_list(n_threads);  // workers already collected for finished_cur
  flat_hash_map<uint32_t, uint64_t> global_ht_update_left;
  flat_hash_map<uint32_t, uint64_t> global_ht_update_right;

  int inter_fail = 0;  // debug counters: rule rejected due to overlap with previous rule
  int equal_fail = 0;
  vector<std::pair<int, double>> progress_debug;
  // Main merge loop: pick the next best pair, hand it to the workers, and
  // collect their count updates, pipelining up to two rules at a time.
  while (used_ids < (uint64_t) n_tokens) {
    uint32_t x, y, z;
    assert(finished_cur <= used_ids && used_ids <= finished_cur + 2);
    bool progress = false;

    // Issue a new rule only if fewer than two rules are in flight and the
    // previous attempt at this pipeline position didn't fail.
    if (used_ids < (uint64_t) n_tokens && used_ids - finished_cur < 2 &&
        last_failed_try < finished_cur) {
      progress = true;
      for (uint64_t i = 0; i < n_threads; i++) {
        thread_use_hs[i] = false;
      }
      {
        // Freeze all workers while consulting/modifying the shared queue.
        vector<std::lock_guard<std::mutex>> lg(mt.begin(), mt.end());

        uint64_t real_cnt = 0;
        while (true) {
          if (merge_order.empty()) {
            if (finished_cur == used_ids) {
              // Nothing left to merge at all: stop with a smaller vocab.
              std::cerr << "WARNING merged only: " << used_ids
                        << " pairs of tokens" << std::endl;
              x = UINT32_MAX;
              y = UINT32_MAX;
              z = UINT32_MAX;
              real_n_tokens = used_ids;
              break;
            } else {
              // Queue is empty only because updates for the in-flight rule
              // haven't arrived yet; retry after the next collection round.
              x = y = z = 0;
              last_failed_try = finished_cur;
              break;
            }
          }
          BPE_Rule last_rule = (used_ids - finished_cur == 1)
                                   ? rules.back()
                                   : BPE_Rule({0, 0, 0});

          auto merge_event = merge_order.top(check_cnt, last_rule);
          // A candidate that touches the still-unfinished previous rule (or a
          // self-pair x==y) can't be pipelined; stall this slot.
          if ((used_ids - finished_cur == 1) &&
              (merge_event.left_token == rules.back().y ||
               merge_event.right_token == rules.back().x ||
               (!rules.empty() && rules.back().x == rules.back().y))) {
            inter_fail += merge_event.left_token == rules.back().y ||
                          merge_event.right_token == rules.back().x;
            equal_fail += !rules.empty() && rules.back().x == rules.back().y &&
                          used_ids - finished_cur == 1;

            last_failed_try = finished_cur;
            x = y = z = 0;
            break;
          }

          merge_order.pop();
          // Queue entries can be stale; re-validate against the live counts
          // and re-insert with the corrected count if necessary.
          real_cnt = check_cnt(
              int2comb(merge_event.left_token, merge_event.right_token));
          assert(real_cnt <= merge_event.count);

          if (real_cnt != merge_event.count) {
            if (real_cnt > 0) {
              merge_event.count = real_cnt;
              merge_order.push(merge_event);
            }
            continue;
          }

          if (real_cnt == 0) {
            continue;
          }

          x = merge_event.left_token;
          y = merge_event.right_token;
          z = used_ids;
          break;
        }
        if (last_failed_try != finished_cur && x != UINT32_MAX) {
          // Publish the new rule to the workers via the parity-indexed slot.
          task_order[used_ids % 2] = {x, y, z};
          recipe[z] = get_recipe(x, y);
          recipe_s[z] = recipe_s[x] + recipe_s[y];

          // Periodic progress report (column-aligned by hand).
          if (used_ids % 1000 == 0) {
            int used_symbols = 0;
            std::cerr << "id: " << z << "=" << x << "+" << y;
            used_symbols += std::to_string(z).size();
            used_symbols += 1;
            used_symbols += std::to_string(x).size();
            used_symbols += 1;
            used_symbols += std::to_string(y).size();
            for (int j = used_symbols; j < 22 + 4; j++) {
              std::cerr << " ";
            }
            used_symbols = 0;
            std::cerr << "freq: " << real_cnt;
            used_symbols += 5;
            used_symbols += std::to_string(real_cnt).size();

            for (int j = used_symbols; j < 15; j++) {
              std::cerr << " ";
            }
            std::cerr << " subword: " << recipe_s[z] << "="
                      << recipe_s[x] + "+" + recipe_s[y] << std::endl;
          }
          used_ids++;
          rules.emplace_back(x, y, z);
        }

        for (uint64_t i = 0; i < n_threads; i++) {
          thread_use_hs[i] = true;
        }
      }
      // Wake any worker that was blocked while the queue was frozen.
      for (auto &cond_value : cv) {
        cond_value.notify_one();
      }
      if (x == UINT32_MAX) {
        break;  // training exhausted before reaching n_tokens
      }
    }

    // collect results

    // Gather per-worker count updates for rule `finished_cur`; the rule is
    // complete only once every worker has reported (full_epoch).
    bool full_epoch = true;
    for (uint64_t i = 0; i < n_threads; i++) {
      if (!local_check_list[i]) {
        if (results_ready[i] >= finished_cur) {
          progress = true;
          local_check_list[i] = 1;

          for (auto token_cnt : left_tokens_submit[finished_cur % 2][i]) {
            global_ht_update_left[token_cnt.first] += token_cnt.second;
          }

          for (auto token_cnt : right_tokens_submit[finished_cur % 2][i]) {
            global_ht_update_right[token_cnt.first] += token_cnt.second;
          }
        } else {
          full_epoch = false;
        }
      }
    }

    if (full_epoch) {
      // All workers reported: push the pairs newly formed around token z
      // into the queue and advance the pipeline.
      for (auto left_token : global_ht_update_left) {
        merge_order.push({left_token.second, left_token.first,
                          task_order[finished_cur % 2].z});
      }
      for (auto right_token : global_ht_update_right) {
        merge_order.push({right_token.second, task_order[finished_cur % 2].z,
                          right_token.first});
      }
      local_check_list.assign(n_threads, 0);
      global_ht_update_left.clear();
      global_ht_update_right.clear();
      finished_cur++;
    }
    if (!progress) {
      // Nothing issued and nothing collected: sleep until some worker
      // publishes results for the rule we are waiting on.
      std::unique_lock<std::mutex> ul(main_loop_mt);
      main_loop_cv.wait(ul, [&] {
        for (uint64_t i = 0; i < n_threads; i++) {
          if (!local_check_list[i] && results_ready[i] >= finished_cur)
            return true;
        }
        return false;
      });
    }
  }
  for (auto &t : threads) {
    t.join();
  }

  // Convert provisional ids to the final numbering and persist the model.
  rename_tokens(char2id, rules, bpe_config.special_tokens, n_tokens);

  *bpe_state = {char2id, rules, bpe_config.special_tokens};
  bpe_state->dump(output_file);
  std::cerr << "model saved to: " << output_file << std::endl;
  return Status();
}
|
1390
|
+
|
1391
|
+
// Builds, for every word, a doubly linked list of token nodes and records
// every adjacent token pair in pair2poscnt.
//
// Consecutive repeats of the same character are collapsed into a single node
// with seg_len > 1 (run-length encoding); pairs formed *inside* such a run
// are accounted for separately via pairsInSeg. For each pair key (packed by
// int2comb), pair2poscnt stores the list of (word index, node index)
// occurrences and the total count weighted by word frequency.
void build_linked_list(
    const vector<WordCount> &word_cnt, vector<vector<NodeEncoder>> &list,
    flat_hash_map<uint64_t, PositionsCnt> &pair2poscnt) {
  list.resize(word_cnt.size());
  for (uint64_t i = 0; i < word_cnt.size(); i++) {
    // First pass: build run-length-encoded nodes linked by index.
    for (uint32_t ch : word_cnt[i].word) {
      if (!list[i].empty() && list[i].back().val == ch) {
        // Same char as the previous node: extend its run instead of appending.
        list[i].back().seg_len++;
      } else {
        // prev/next are plain indices; next is fixed up for the tail below.
        int list_size = list[i].size();
        list[i].emplace_back(ch, list_size - 1, list_size + 1, 1);
      }
    }

    list[i].back().next = -1;  // terminate the list (-1 == no neighbor)
    // Second pass: register every cross-node pair and every intra-run pair.
    for (uint64_t j = 0; j < list[i].size(); j++) {
      if (j + 1 < list[i].size()) {
        uint64_t comb = int2comb(list[i][j].val, list[i][j + 1].val);
        auto it = pair2poscnt.find(comb);
        if (it == pair2poscnt.end()) {
          pair2poscnt[comb] = {{{i, j}}, word_cnt[i].cnt};
        } else {
          it->second.positions.emplace_back(i, j);
          it->second.cnt += word_cnt[i].cnt;
        }
      }
      assert(list[i][j].seg_len >= 1);

      if (list[i][j].seg_len > 1) {
        // A run of length L contributes pairsInSeg(L) (char,char) pairs.
        uint64_t comb = int2comb(list[i][j].val, list[i][j].val);
        auto it = pair2poscnt.find(comb);
        uint64_t cc = word_cnt[i].cnt * pairsInSeg(list[i][j].seg_len);
        if (it == pair2poscnt.end()) {
          pair2poscnt[comb] = {{{i, j}}, cc};
        } else {
          it->second.positions.emplace_back(i, j);
          it->second.cnt += cc;
        }
      }
    }
  }
}
|
1433
|
+
|
1434
|
+
// Builds the character -> id alphabet for a codepoint sequence: counts the
// frequency of every non-whitespace codepoint in `data`, then delegates the
// coverage-based selection to compute_alphabet_helper, which reports dropped
// rare characters through `removed_chars`.
flat_hash_map<uint32_t, uint32_t> compute_alphabet(
    const vector<uint32_t> &data, flat_hash_set<uint32_t> &removed_chars,
    const BpeConfig &bpe_config) {
  flat_hash_map<uint32_t, uint64_t> frequencies;
  for (uint32_t code_point : data) {
    if (is_space(code_point)) {
      continue;  // whitespace never enters the alphabet
    }
    frequencies[code_point]++;
  }
  assert(!frequencies.empty());
  return compute_alphabet_helper(frequencies, data.size(), removed_chars,
                                 bpe_config);
}
|
1447
|
+
|
1448
|
+
// Validates and normalizes the training configuration in place.
//
// Checks that character_coverage lies in (0, 1], that unk_id fits in
// [0, vocab_size) (unk is mandatory, so -1 is not allowed), that each of
// pad/bos/eos is either disabled (-1) or in [0, vocab_size), and that all
// enabled special-token ids are pairwise distinct. Finally resolves
// n_threads == -1 to the hardware concurrency and clamps it to [1, 8].
// Returns an error Status describing the first violated constraint.
Status check_config(BpeConfig &bpe_config, int vocab_size) {
  if (bpe_config.character_coverage <= 0 || bpe_config.character_coverage > 1) {
    return Status(1, "coverage value must be in the range (0, 1]. Current value of coverage = " +
        std::to_string(bpe_config.character_coverage));
  }
  if (bpe_config.special_tokens.unk_id < 0 ||
      bpe_config.special_tokens.unk_id >= vocab_size) {
    return Status(1,
                  "unk_id: must be in the range [0, vocab_size - 1]. Current value of vocab_size = "
                      + std::to_string(vocab_size) + "; unk_id = " + std::to_string(bpe_config.special_tokens.unk_id));
  }

  if (bpe_config.special_tokens.pad_id < -1 ||
      bpe_config.special_tokens.pad_id >= vocab_size) {
    return Status(1, "pad_id must be in the range [-1, vocab_size - 1]. Current value of vocab_size = " +
        std::to_string(vocab_size) + "; pad_id = " + std::to_string(bpe_config.special_tokens.pad_id));
  }

  if (bpe_config.special_tokens.bos_id < -1 ||
      bpe_config.special_tokens.bos_id >= vocab_size) {
    return Status(1, "bos_id must be in the range [-1, vocab_size - 1]. Current value of vocab_size = " +
        std::to_string(vocab_size) + "; bos_id = " + std::to_string(bpe_config.special_tokens.bos_id));
  }

  if (bpe_config.special_tokens.eos_id < -1 ||
      bpe_config.special_tokens.eos_id >= vocab_size) {
    // Fix: use the same "; eos_id = " separator as the pad/bos diagnostics
    // (the message previously read "... vocab_size = N eos_id = M").
    return Status(1, "eos_id must be in the range [-1, vocab_size - 1]. Current value of vocab_size = " +
        std::to_string(vocab_size) + "; eos_id = " + std::to_string(bpe_config.special_tokens.eos_id));
  }

  // Every enabled special token must occupy a distinct id: insert them all
  // into a set and compare against the number of insertions attempted.
  flat_hash_set<int> ids;
  uint64_t cnt_add = 0;
  if (bpe_config.special_tokens.pad_id != -1) {
    ids.insert(bpe_config.special_tokens.pad_id);
    cnt_add++;
  }
  if (bpe_config.special_tokens.bos_id != -1) {
    ids.insert(bpe_config.special_tokens.bos_id);
    cnt_add++;
  }
  if (bpe_config.special_tokens.eos_id != -1) {
    ids.insert(bpe_config.special_tokens.eos_id);
    cnt_add++;
  }
  ids.insert(bpe_config.special_tokens.unk_id);
  cnt_add++;
  if (ids.size() != cnt_add) {
    return Status(1, "All ids of special tokens must be different.");
  }

  if (bpe_config.n_threads == -1) {
    bpe_config.n_threads = std::thread::hardware_concurrency();
  }
  // Cap at 8 workers; the training pipeline does not scale beyond that here.
  bpe_config.n_threads = std::min(8, std::max(1, bpe_config.n_threads));
  return Status();
}
|
1504
|
+
|
1505
|
+
void print_config(const string &input_path, const string &model_path,
|
1506
|
+
int vocab_size, BpeConfig bpe_config) {
|
1507
|
+
std::cerr << "Training parameters" << std::endl;
|
1508
|
+
std::cerr << " input: " << input_path << std::endl;
|
1509
|
+
std::cerr << " model: " << model_path << std::endl;
|
1510
|
+
std::cerr << " vocab_size: " << vocab_size << std::endl;
|
1511
|
+
std::cerr << " n_threads: " << bpe_config.n_threads << std::endl;
|
1512
|
+
std::cerr << " character_coverage: " << bpe_config.character_coverage
|
1513
|
+
<< std::endl;
|
1514
|
+
std::cerr << " pad: " << bpe_config.special_tokens.pad_id << std::endl;
|
1515
|
+
std::cerr << " unk: " << bpe_config.special_tokens.unk_id << std::endl;
|
1516
|
+
std::cerr << " bos: " << bpe_config.special_tokens.bos_id << std::endl;
|
1517
|
+
std::cerr << " eos: " << bpe_config.special_tokens.eos_id << std::endl;
|
1518
|
+
std::cerr << std::endl;
|
1519
|
+
}
|
1520
|
+
|
1521
|
+
// End-to-end BPE training entry point: validates/normalizes the
// configuration, prints it, reads the whole training corpus into memory and
// learns the merge rules, writing the resulting model to model_path.
// Returns the first failing Status, or an OK Status on success.
Status train_bpe(const string &input_path, const string &model_path,
                 int vocab_size, BpeConfig bpe_config) {
  Status status = check_config(bpe_config, vocab_size);
  if (!status.ok()) {
    return status;
  }

  print_config(input_path, model_path, vocab_size, bpe_config);

  std::cerr << "reading file..." << std::endl;
  string corpus;
  status = fast_read_file_utf8(input_path, &corpus);
  if (!status.ok()) {
    return status;
  }

  std::cerr << "learning bpe..." << std::endl;
  BPEState bpe_state;
  status = learn_bpe_from_string(corpus, vocab_size, model_path, bpe_config, &bpe_state);
  if (!status.ok()) {
    return status;
  }

  return Status();
}
|
1542
|
+
|
1543
|
+
|
1544
|
+
// Minimal priority-queue interface that lets the encoding loop switch at
// runtime between a deterministic heap (STLQueue) and a randomized one
// (DropoutQueue) depending on the dropout probability.
template<typename T>
class BasePriorityQueue {
 public:
  // Insert an element.
  virtual void push(T x) = 0;
  // Remove the next element into `x`; returns false when the queue is empty.
  virtual bool pop(T& x) = 0;
  virtual ~BasePriorityQueue() {}
};
|
1551
|
+
|
1552
|
+
template<typename T>
|
1553
|
+
class STLQueue : public BasePriorityQueue<T> {
|
1554
|
+
std::priority_queue<T> q;
|
1555
|
+
void push(T x) override {
|
1556
|
+
q.push(x);
|
1557
|
+
}
|
1558
|
+
bool pop(T& x) override {
|
1559
|
+
if (q.empty()) {
|
1560
|
+
return false;
|
1561
|
+
}
|
1562
|
+
x = q.top();
|
1563
|
+
q.pop();
|
1564
|
+
return true;
|
1565
|
+
}
|
1566
|
+
};
|
1567
|
+
|
1568
|
+
// Process-wide PRNG used by DropoutQueue for its skip decisions.
// NOTE(review): default-constructed (fixed seed) and accessed without
// synchronization — presumably encoding with dropout is not run from
// multiple threads concurrently; confirm before relying on that.
std::mt19937 rnd;
|
1569
|
+
|
1570
|
+
// Randomized priority queue implementing BPE-dropout style sampling: each
// time the top element is examined during pop, it is skipped with
// probability skip_prob. Skipped elements are held aside and pushed back
// into the heap once a survivor is found (or the heap drains), so a skip
// only affects the current pop, not future ones.
template<typename T>
class DropoutQueue : public BasePriorityQueue<T> {
  double skip_prob;                       // probability of skipping the current top
  std::uniform_real_distribution<> dist;  // uniform [0, 1) for the skip coin flip
  std::priority_queue<T> q;
  vector<T> skipped_elements;             // tops set aside during the current pop
 public:
  explicit DropoutQueue(double _skip_prob):skip_prob(_skip_prob), dist(std::uniform_real_distribution<>(0, 1)) {}
  void push(T x) override {
    q.push(x);
  }
  // Pops the highest-priority element that survives the dropout coin flips.
  // Returns false only when the queue is truly empty. Draws randomness from
  // the file-global `rnd` engine.
  bool pop(T& x) override {
    assert(skipped_elements.empty());
    while (true) {
      if (q.empty()) {
        // Everything got skipped: restore the set-aside elements so the
        // queue's contents are unchanged, and report "empty" for this pop.
        for (auto y: skipped_elements) {
          q.push(y);
        }
        skipped_elements.clear();
        return false;
      }
      T temp = q.top();
      q.pop();
      if (dist(rnd) < skip_prob) {
        skipped_elements.push_back(temp);
      }
      else {
        // Survivor found: put the skipped elements back and return it.
        for (auto y: skipped_elements) {
          q.push(y);
        }
        skipped_elements.clear();
        x = temp;
        return true;
      }
    }
  }
};
|
1607
|
+
|
1608
|
+
// Encodes one sentence into token ids or subword pieces.
//
// Each whitespace-delimited word (prefixed with the SPACE_TOKEN marker) is
// turned into a doubly linked list of character tokens; merge rules are then
// applied greedily in rule-priority order via a priority queue (a dropout
// queue when encoding_config.dropout_prob > 0). Runs of characters absent
// from the alphabet are assigned synthetic ids >= new_tokens_start and come
// out as unk_id (ID mode) or their original text (subword mode). Honors the
// bos/eos/reverse flags of encoding_config.
DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
                                          const EncodingConfig &encoding_config,
                                          OutputType output_type) const {
  // Linked-list node: token id plus prev/next indices (-1 = none).
  struct NodeDecoder {
    uint32_t token_id;
    int prev, next;

    NodeDecoder(uint32_t _val, uint64_t cur_pos)
        : token_id(_val),
          prev(static_cast<int>(cur_pos) - 1),
          next(static_cast<int>(cur_pos) + 1) {}

    NodeDecoder(uint32_t _val, int _prev, int _next)
        : token_id(_val), prev(_prev), next(_next) {}
  };

  // Queue entry: rule id (lower = applied earlier) and the left position of
  // the candidate pair. Ordering is inverted so the *lowest* rule id wins.
  struct MergeEvent2 {
    int priority;
    int pos;

    bool operator<(const MergeEvent2 &other) const {
      return priority > other.priority ||
             (priority == other.priority && pos > other.pos);
    }
  };

  vector<int> output_ids;
  vector<string> output_pieces;

  if (encoding_config.bos) {
    if (output_type == ID) {
      output_ids.push_back(bpe_state.special_tokens.bos_id);
    } else {
      output_pieces.push_back(BOS_TOKEN);
    }
  }

  vector<NodeDecoder> list;
  // synthetic id -> original UTF-8 text of an out-of-alphabet run
  flat_hash_map<uint32_t, string> unrecognized_tokens;

  auto text = decode_utf8(sentence_utf8.data(),
                          sentence_utf8.data() + sentence_utf8.size());

  assert(bpe_state.char2id.count(SPACE_TOKEN));

  // Strip trailing whitespace so the word loop terminates cleanly.
  for (; !text.empty() && is_space(text.back()); text.pop_back()) {
  }

  const int new_tokens_start = static_cast<int>(
      1e9); // just some number that bigger than any subword id
  for (auto it_text = text.begin(); it_text != text.end();) {
    list.clear();
    unrecognized_tokens.clear();

    // Locate the next word: skip leading spaces, then scan to the next space.
    auto begin_of_word = std::find_if_not(it_text, text.end(), is_space);
    auto end_of_word = std::find_if(begin_of_word, text.end(), is_space);
    it_text = end_of_word;

    uint32_t new_token_cur = new_tokens_start;
    // Every word starts with the space marker token.
    list.emplace_back(bpe_state.char2id.at(SPACE_TOKEN), 0);

    for (auto it_char_in_word = begin_of_word; it_char_in_word < end_of_word;) {
      if (bpe_state.char2id.count(*it_char_in_word) == 0) {
        // Collapse a maximal run of unknown characters into one synthetic
        // token, remembering its original text for subword output.
        auto it_unrecognized_word = std::find_if(
            it_char_in_word, end_of_word,
            [&](uint32_t ch) { return bpe_state.char2id.count(ch); });

        unrecognized_tokens[new_token_cur] =
            encode_utf8({it_char_in_word, it_unrecognized_word});
        it_char_in_word = it_unrecognized_word;

        list.emplace_back(new_token_cur, list.size());
        new_token_cur++;
      } else {
        list.emplace_back(bpe_state.char2id.at(*it_char_in_word), list.size());
        ++it_char_in_word;
      }
    }
    list.back().next = -1;


    // Packed key of the pair starting at first_pos (see int2comb).
    auto pair_code = [&](uint64_t first_pos) {
      auto second_pos = list[first_pos].next;
      return int2comb(list[first_pos].token_id, list[second_pos].token_id);
    };

    std::unique_ptr<BasePriorityQueue<MergeEvent2>> queue(nullptr);
    if (encoding_config.dropout_prob == 0) {
      queue.reset(new STLQueue<MergeEvent2>());
    }
    else {
      queue.reset(new DropoutQueue<MergeEvent2>(encoding_config.dropout_prob));
    }

    // Enqueue the position if a merge rule exists for its pair.
    auto push_in_queue_if_rule_exist = [&](uint64_t pos) {
      auto it = rule2id.find(pair_code(pos));
      if (it != rule2id.end()) {
        queue->push({it->second, static_cast<int>(pos)});
      }
    };

    for (uint64_t j = 0; j + 1 < list.size(); j++) {
      push_in_queue_if_rule_exist(j);
    }

    // Greedy merging: repeatedly apply the best applicable rule. Queue
    // entries may be stale (a neighbor was merged away since enqueueing),
    // so each event is re-validated against the current list state.
    while (true) {
      MergeEvent2 event;
      if (!queue->pop(event)) {
        break;
      }
      int rule_id = event.priority;
      int pos_1 = event.pos;
      int pos_2 = list[pos_1].next;
      assert(pos_1 != pos_2);
      if (list[pos_1].token_id != bpe_state.rules[rule_id].x || pos_2 == -1 ||
          list[pos_2].token_id != bpe_state.rules[rule_id].y) {
        continue;  // stale event; the pair no longer exists
      }

      int pos_0 = list[pos_1].prev;
      int pos_3 = list[pos_2].next;

      // Merge pos_1+pos_2 into pos_1 (token z); pos_2 becomes a dead node.
      list[pos_2] = {0, -1, -1};
      list[pos_1] = {bpe_state.rules[rule_id].z, pos_0, pos_3};
      if (pos_3 != -1) {
        list[pos_3].prev = pos_1;
      }

      // The merge created up to two new adjacent pairs; try to enqueue them.
      if (pos_0 != -1) {
        push_in_queue_if_rule_exist(pos_0);
      }
      if (pos_3 != -1) {
        push_in_queue_if_rule_exist(pos_1);
      }
    }

    // Walk the surviving nodes (dead nodes have token_id == 0) and emit them.
    auto it_alive_token = std::find_if(
        list.begin(), list.end(),
        [](const NodeDecoder &node) { return node.token_id != 0; });

    assert(it_alive_token != list.end());
    int alive_token = std::distance(list.begin(), it_alive_token);
    for (; alive_token != -1; alive_token = list[alive_token].next) {
      int token_id = list[alive_token].token_id;
      if (token_id >= new_tokens_start) {
        // Synthetic id for an out-of-alphabet run.
        if (output_type == ID) {
          output_ids.push_back(bpe_state.special_tokens.unk_id);
        } else {
          assert(unrecognized_tokens.count(token_id));
          output_pieces.push_back(unrecognized_tokens[token_id]);
        }
      } else {
        if (output_type == ID) {
          output_ids.push_back(token_id);
        } else {
          assert(recipe.count(token_id));
          output_pieces.push_back(token2word(recipe.at(token_id), id2char));
        }
      }
    }
  }
  if (encoding_config.eos) {
    if (output_type == ID) {
      output_ids.push_back(bpe_state.special_tokens.eos_id);
    } else {
      output_pieces.push_back(EOS_TOKEN);
    }
  }

  if (encoding_config.reverse) {
    if (output_type == ID) {
      std::reverse(output_ids.begin(), output_ids.end());
    } else {
      std::reverse(output_pieces.begin(), output_pieces.end());
    }
  }
  return {output_ids, output_pieces};
}
|
1786
|
+
|
1787
|
+
// Construct an encoder from an already-loaded BPE model state.
// _n_threads must be >= 1, or -1 to use every available hardware thread.
BaseEncoder::BaseEncoder(BPEState _bpe_state, int _n_threads)
    : bpe_state(std::move(_bpe_state)), n_threads(_n_threads) {
  assert(n_threads >= 1 || n_threads == -1);
  if (n_threads == -1) {
    // hardware_concurrency() may report 0; never drop below one thread.
    n_threads = std::max(1, static_cast<int>(std::thread::hardware_concurrency()));
  }
  fill_from_state();
}
|
1795
|
+
|
1796
|
+
// Construct an encoder by loading a serialized BPE model from model_path.
// On load failure the error is reported through *ret_status and the
// encoder is left partially initialized; callers must check the status.
BaseEncoder::BaseEncoder(const string &model_path, int _n_threads, Status *ret_status)
    : n_threads(_n_threads) {
  Status load_status = bpe_state.load(model_path);
  if (!load_status.ok()) {
    *ret_status = load_status;
    return;
  }
  assert(n_threads >= 1 || n_threads == -1);
  if (n_threads == -1) {
    // hardware_concurrency() may report 0; never drop below one thread.
    n_threads = std::max(1, static_cast<int>(std::thread::hardware_concurrency()));
  }
  fill_from_state();
  *ret_status = Status();
}
|
1810
|
+
|
1811
|
+
// Return a new vector holding the elements of `a` followed by those of `b`.
template<typename T>
vector<T> concat_vectors(const vector<T> &a, const vector<T> &b) {
  vector<T> result;
  // One allocation up front; the two inserts then never reallocate.
  result.reserve(a.size() + b.size());
  result.insert(result.end(), a.begin(), a.end());
  result.insert(result.end(), b.begin(), b.end());
  return result;
}
|
1819
|
+
|
1820
|
+
void BaseEncoder::fill_from_state() {
|
1821
|
+
for (auto x : bpe_state.char2id) {
|
1822
|
+
id2char[x.second] = x.first;
|
1823
|
+
}
|
1824
|
+
|
1825
|
+
for (int i = 0; i < (int) bpe_state.rules.size(); i++) {
|
1826
|
+
rule2id[int2comb(bpe_state.rules[i].x, bpe_state.rules[i].y)] = i;
|
1827
|
+
}
|
1828
|
+
|
1829
|
+
for (auto x : id2char) {
|
1830
|
+
recipe[x.first] = {x.first};
|
1831
|
+
}
|
1832
|
+
|
1833
|
+
for (auto rule : bpe_state.rules) {
|
1834
|
+
recipe[rule.z] = concat_vectors(recipe[rule.x], recipe[rule.y]);
|
1835
|
+
}
|
1836
|
+
|
1837
|
+
for (const auto &id_to_recipe : recipe) {
|
1838
|
+
reversed_recipe[token2word(id_to_recipe.second, id2char)] =
|
1839
|
+
id_to_recipe.first;
|
1840
|
+
}
|
1841
|
+
reversed_recipe[BOS_TOKEN] = bpe_state.special_tokens.bos_id;
|
1842
|
+
reversed_recipe[EOS_TOKEN] = bpe_state.special_tokens.eos_id;
|
1843
|
+
}
|
1844
|
+
|
1845
|
+
int BaseEncoder::vocab_size() const {
|
1846
|
+
return bpe_state.rules.size() + bpe_state.char2id.size() +
|
1847
|
+
bpe_state.special_tokens.n_special_tokens();
|
1848
|
+
}
|
1849
|
+
|
1850
|
+
// Encode a batch of sentences, splitting the work across n_threads when the
// batch is large enough. Results are written into *decoder_results,
// index-aligned with `sentences`.
// Returns a non-ok Status if BOS/EOS was requested but the model was
// trained without the corresponding special token.
Status BaseEncoder::encode_parallel(
    const std::vector<std::string> &sentences,
    const EncodingConfig &encoding_config, OutputType output_type,
    std::vector<DecodeResult> *decoder_results
) const {
  // Fail fast: can't emit special tokens the model doesn't define.
  if (encoding_config.bos && bpe_state.special_tokens.bos_id == -1) {
    return Status(1, "Can't add <BOS> token. Model was trained without it.");
  }
  if (encoding_config.eos && bpe_state.special_tokens.eos_id == -1) {
    return Status(1, "Can't add <EOS> token. Model was trained without it.");
  }

  decoder_results->assign(sentences.size(), DecodeResult());
  if (sentences.size() <= static_cast<uint64_t>(n_threads) * 3 ||
      n_threads == 1) { // Not too many sentences. It's better to solve it
                        // without threads (thread startup would dominate).
    for (uint64_t i = 0; i < sentences.size(); i++) {
      decoder_results->at(i) = encode_sentence(sentences[i], encoding_config, output_type);
    }
    return Status();
  }
  // Static partition: thread t handles the contiguous chunk
  // [t * ceil(n / n_threads), min((t + 1) * ceil(n / n_threads), n)).
  vector<std::thread> threads;
  for (int i = 0; i < n_threads; i++) {
    threads.emplace_back(
        [&](uint64_t this_thread) {
          uint64_t tasks_for_thread =
              (sentences.size() + n_threads - 1) / n_threads;
          uint64_t first_task = tasks_for_thread * this_thread;
          uint64_t last_task =
              std::min(tasks_for_thread * (this_thread + 1), static_cast<uint64_t>(sentences.size()));
          // Threads write to disjoint slots of *decoder_results, so no
          // synchronization is needed beyond the final join.
          for (uint64_t j = first_task; j < last_task; j++) {
            decoder_results->at(j) =
                encode_sentence(sentences[j], encoding_config, output_type);
          }
        },
        i);
  }
  // Capturing by reference is safe: all threads are joined before return.
  for (auto &thread : threads) {
    thread.join();
  }
  return Status();
}
|
1892
|
+
|
1893
|
+
// Encode each sentence into a sequence of token ids.
// bos/eos: prepend/append the corresponding special token ids.
// reverse: reverse each output sequence. dropout_prob: BPE-dropout
// probability in [0, 1]; 0 disables dropout.
// On success *ids is index-aligned with `sentences`.
Status BaseEncoder::encode_as_ids(const vector<string> &sentences, vector<vector<int>> *ids,
                                  bool bos, bool eos,
                                  bool reverse, double dropout_prob) const {
  EncodingConfig encoding_config = {bos, eos, reverse, dropout_prob};

  std::vector<DecodeResult> decode_results;
  Status status = encode_parallel(sentences, encoding_config, ID, &decode_results);
  if (!status.ok()) {
    return status;
  }
  ids->assign(decode_results.size(), vector<int>());
  for (uint64_t i = 0; i < decode_results.size(); i++) {
    // Qualified std::move — the original unqualified `move` only resolved
    // via ADL (clang's -Wunqualified-std-cast-call flags this).
    ids->at(i) = std::move(decode_results[i].ids);
  }
  return Status();
}
|
1909
|
+
|
1910
|
+
// Encode each sentence into a sequence of subword strings.
// Parameters mirror encode_as_ids; output pieces are the surface forms.
// On success *subwords is index-aligned with `sentences`.
Status BaseEncoder::encode_as_subwords(
    const vector<string> &sentences,
    vector<vector<string>> *subwords,
    bool bos, bool eos, bool reverse, double dropout_prob) const {
  time_check("");  // resets the debug timer (empty label = no printout)
  EncodingConfig encoding_config = {bos, eos, reverse, dropout_prob};
  std::vector<DecodeResult> decode_results;
  Status status = encode_parallel(sentences, encoding_config, SUBWORD, &decode_results);
  if (!status.ok()) {
    return status;
  }
  subwords->assign(decode_results.size(), vector<string>());
  for (uint64_t i = 0; i < decode_results.size(); i++) {
    // Qualified std::move — the original unqualified `move` only resolved
    // via ADL (clang's -Wunqualified-std-cast-call flags this).
    subwords->at(i) = std::move(decode_results[i].pieces);
  }
  return Status();
}
|
1927
|
+
|
1928
|
+
// Map a token id back to its subword string.
// replace_space: render a leading space-marker character as a literal " ".
// Returns a non-ok Status when id is outside [0, vocab_size()).
Status BaseEncoder::id_to_subword(int id, string *subword, bool replace_space) const {
  if (id < 0 || vocab_size() <= id) {
    return Status(1, "id must be in the range [0, vocab_size - 1]. Current value: vocab_size = " +
                         std::to_string(vocab_size()) +
                         "; id=" + std::to_string(id) + ";");
  }

  // Special tokens have fixed surface forms.
  const auto &special = bpe_state.special_tokens;
  if (id == special.unk_id) {
    *subword = UNK_TOKEN;
    return Status();
  }
  if (id == special.pad_id) {
    *subword = PAD_TOKEN;
    return Status();
  }
  if (id == special.bos_id) {
    *subword = BOS_TOKEN;
    return Status();
  }
  if (id == special.eos_id) {
    *subword = EOS_TOKEN;
    return Status();
  }

  assert(recipe.count(id));
  if (replace_space) {
    const auto &symbols = recipe.at(id);
    if (id2char.at(symbols[0]) == SPACE_TOKEN) {
      // Render the space marker as a real leading space.
      *subword = " " + token2word({symbols.begin() + 1, symbols.end()}, id2char);
      return Status();
    }
  }
  *subword = token2word(recipe.at(id), id2char);
  return Status();
}
|
1962
|
+
|
1963
|
+
int BaseEncoder::subword_to_id(const string &token) const {
|
1964
|
+
if (UNK_TOKEN == token) {
|
1965
|
+
return bpe_state.special_tokens.unk_id;
|
1966
|
+
}
|
1967
|
+
if (PAD_TOKEN == token) {
|
1968
|
+
return bpe_state.special_tokens.pad_id;
|
1969
|
+
}
|
1970
|
+
if (BOS_TOKEN == token) {
|
1971
|
+
return bpe_state.special_tokens.bos_id;
|
1972
|
+
}
|
1973
|
+
if (EOS_TOKEN == token) {
|
1974
|
+
return bpe_state.special_tokens.eos_id;
|
1975
|
+
}
|
1976
|
+
if (reversed_recipe.count(token)) {
|
1977
|
+
return reversed_recipe.at(token);
|
1978
|
+
}
|
1979
|
+
return bpe_state.special_tokens.unk_id;
|
1980
|
+
}
|
1981
|
+
|
1982
|
+
// Decode each id sequence into a sentence appended to *sentences.
// ignore_ids: optional set of ids to skip entirely (may be null).
// Stops and returns the first non-ok Status from the per-sentence decode.
Status BaseEncoder::decode(const vector<vector<int>> &ids,
                           vector<string> *sentences,
                           const unordered_set<int> *ignore_ids) const {
  // NOTE: removed `vector<string> ret;` — it was declared but never used.
  for (const auto &sentence_ids : ids) {
    string decode_output;
    Status status = decode(sentence_ids, &decode_output, ignore_ids);
    if (!status.ok()) {
      return status;
    }
    // std::move qualified; unqualified `move` only resolved via ADL.
    sentences->push_back(std::move(decode_output));
  }
  return Status();
}
|
1996
|
+
|
1997
|
+
// Decode one id sequence into *sentence (appended to its current content).
// ignore_ids: optional set of ids to skip (may be null).
// A leading space produced by the first decoded subword is stripped.
Status BaseEncoder::decode(const vector<int> &ids, string *sentence, const unordered_set<int> *ignore_ids) const {
  bool first_iter = true;
  for (auto id : ids) {
    string subword;

    if (!ignore_ids || ignore_ids->count(id) == 0) {
      Status status = id_to_subword(id, &subword, true);
      if (!status.ok()) {
        return status;
      }
      *sentence += subword;
      // Guard against an empty sentence before at(0): the original code
      // could throw std::out_of_range if the first appended subword was
      // empty.
      if (first_iter && !sentence->empty() && sentence->at(0) == ' ') {
        *sentence = sentence->substr(1);
      }
      first_iter = false;
    }
  }
  return Status();
}
|
2016
|
+
|
2017
|
+
// Decode lines of whitespace-separated ids (e.g. "12 7 95") into sentences.
// Each decoded sentence is appended to *sentences.
Status BaseEncoder::decode(const vector<string> &data,
                           vector<string> *sentences,
                           const std::unordered_set<int> *ignore_ids) const {
  for (const auto &line : data) {
    // Extract every integer token id from the line.
    vector<int> ids;
    std::stringstream stream(line);
    for (int token_id; stream >> token_id;) {
      ids.push_back(token_id);
    }
    string sentence;
    Status status = decode(ids, &sentence, ignore_ids);
    if (!status.ok()) {
      return status;
    }
    sentences->push_back(sentence);
  }
  return Status();
}
|
2037
|
+
|
2038
|
+
// Return the full vocabulary, indexed by token id.
vector<string> BaseEncoder::vocabulary() const {
  const int n = vocab_size();
  vector<string> vocab;
  vocab.reserve(n);
  for (int id = 0; id < n; id++) {
    string subword;
    Status status = id_to_subword(id, &subword);
    // Every id in [0, vocab_size) must resolve.
    assert(status.ok());
    vocab.push_back(std::move(subword));
  }
  return vocab;
}
|
2049
|
+
|
2050
|
+
// Print the vocabulary to stdout, one "id<TAB>token" line per token.
// With verbose=true, each merged token also shows its parts and their ids,
// e.g. "id  xy=x+y          <id_x>+<id_y>" padded for rough column alignment.
void BaseEncoder::vocab_cli(bool verbose) const {
  // Highest token id seen in the recipes or among special tokens, plus one.
  uint32_t n_tokens = 0;
  for (const auto &entry : recipe) {
    n_tokens = std::max(entry.first, n_tokens);
  }
  n_tokens = std::max(n_tokens, bpe_state.special_tokens.max_id());
  n_tokens++;

  // z -> (x, y): which pair each merged token was built from.
  flat_hash_map<uint32_t, std::pair<uint32_t, uint32_t>> reversed_rules;
  if (verbose) {
    for (auto rule : bpe_state.rules) {
      reversed_rules[rule.z] = {rule.x, rule.y};
    }
  }

  for (uint64_t i = 0; i < n_tokens; i++) {
    string token_z;
    Status status = id_to_subword(i, &token_z);
    assert(status.ok());
    std::cout << i << "\t" << token_z;
    if (verbose) {
      if (reversed_rules.count(i)) {
        // Count display width (in code points) already consumed so the
        // numeric pair column lines up across rows.
        int used_symbols = 0;
        auto comb = reversed_rules[i];
        string token_x;
        string token_y;
        status = id_to_subword(comb.first, &token_x);
        assert(status.ok());
        status = id_to_subword(comb.second, &token_y);
        assert(status.ok());

        // +1 accounts for the separator characters ('=' and '+').
        used_symbols += decode_utf8(token_z).size() + 1;
        used_symbols +=
            decode_utf8(token_x).size() + 1 + decode_utf8(token_y).size();

        std::cout << "=" << token_x << "+" << token_y;
        // Pad to column ~50, but always at least two spaces.
        for (int t = 0; t < std::max(2, 50 - used_symbols); t++) {
          std::cout << " ";
        }
        std::cout << comb.first << "+" << comb.second;
      }
    }
    std::cout << std::endl;
  }
}
|
2095
|
+
|
2096
|
+
// Encode stdin to stdout from the command line.
// output_type_str: "id" or "subword".
// stream=true: encode and flush line-by-line (low latency);
// stream=false: read and encode in ~10 MiB batches, reporting progress on
// stderr by rewriting a "bytes processed: N" line with backspaces.
Status BaseEncoder::encode_cli(const string &output_type_str, bool stream,
                               bool bos, bool eos, bool reverse, double dropout_prob) const {
  std::ios_base::sync_with_stdio(false);
  OutputType output_type;
  if (output_type_str == "id") {
    output_type = ID;
  } else {
    // Callers must pass one of the two known output types.
    assert(output_type_str == "subword");
    output_type = SUBWORD;
  }
  if (stream) {
    if (output_type == SUBWORD) {
      string sentence;
      while (getline(std::cin, sentence)) {
        vector<vector<string>> subwords;
        Status status = encode_as_subwords({sentence}, &subwords, bos, eos, reverse, dropout_prob);
        if (!status.ok()) {
          return status;
        }
        // true => flush after each line so output appears immediately.
        write_to_stdout(subwords, true);
      }
    } else {
      assert(output_type == ID);
      string sentence;
      while (getline(std::cin, sentence)) {
        vector<vector<int>> ids;
        Status status = encode_as_ids({sentence}, &ids, bos, eos, reverse, dropout_prob);
        if (!status.ok()) {
          return status;
        }
        write_to_stdout(ids, true);
      }
    }
  } else {
    time_check("");  // resets the debug timer (empty label = no printout)
    // Read stdin in batches of at most ~10 MiB.
    const uint64_t batch_limit = 10 * 1024 * 1024;
    uint64_t total_progress = 0;
    uint64_t processed;
    std::cerr << "n_threads: " << n_threads << std::endl;
    // Width of the progress message currently on stderr, so it can be
    // erased with backspaces before the next update.
    int chars_remove = 0;
    do {
      processed = 0;
      auto sentences = read_lines_from_stdin(batch_limit, &processed);
      if (output_type == SUBWORD) {
        vector<vector<string>> subwords;
        Status status = encode_as_subwords(sentences, &subwords, bos, eos, reverse, dropout_prob);
        if (!status.ok()) {
          return status;
        }
        write_to_stdout(subwords, false);
      } else {
        assert(output_type == ID);
        vector<vector<int>> ids;
        Status status = encode_as_ids(sentences, &ids, bos, eos, reverse, dropout_prob);
        if (!status.ok()) {
          return status;
        }
        write_to_stdout(ids, false);
      }
      total_progress += processed;

      // Erase the previous progress message, then print the new one.
      for (int i = 0; i < chars_remove; i++) {
        std::cerr << '\b';
      }
      chars_remove = 0;
      string message = "bytes processed: ";
      chars_remove += message.size();
      chars_remove += std::to_string(total_progress).length();
      std::cerr << message << total_progress;
    } while (processed >= batch_limit);  // a short batch means stdin is drained
    std::cerr << std::endl;
  }
  return Status();
}
|
2170
|
+
|
2171
|
+
// Decode stdin to stdout line-by-line: each input line is a sequence of
// whitespace-separated token ids; each output line is the decoded sentence.
// ignore_ids: optional set of ids to skip (may be null).
Status BaseEncoder::decode_cli(const std::unordered_set<int> *ignore_ids) const {
  std::ios_base::sync_with_stdio(false);
  string line;
  while (getline(std::cin, line)) {
    vector<string> decoded;
    Status status = decode({line}, &decoded, ignore_ids);
    if (!status.ok()) {
      return status;
    }
    std::cout << decoded[0] << "\n";
  }
  return Status();
}
|
2184
|
+
|
2185
|
+
} // namespace vkcom
|