youtokentome 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,2185 @@
1
#include "bpe.h"

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cctype>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_set>
#include <utility>
#include <vector>

#include "third_party/flat_hash_map.h"
#include "utf8.h"
#include "utils.h"
25
+
26
+ namespace vkcom {
27
+ using std::string;
28
+ using std::vector;
29
+ using std::unordered_set;
30
+
31
// A non-owning view over a byte range, hashed once on construction with a
// polynomial rolling hash so it can serve as a hash-map key.
struct VectorSegment {
  // Rolling-hash modulus and base.
  constexpr static uint64_t MOD = 2032191299;
  constexpr static uint64_t P = 726328703;

  const char* begin;
  const char* end;
  uint64_t hash;

  VectorSegment(const char* begin, const char* end) : begin(begin), end(end), hash(0) {
    for (const char* cur = begin; cur != end; ++cur) {
      hash = (hash * P + static_cast<unsigned char>(*cur)) % MOD;
    }
  }

  // Byte-wise equality; hash and length comparisons are cheap early-outs.
  bool operator==(const VectorSegment &other) const {
    if (hash != other.hash) {
      return false;
    }
    if (end - begin != other.end - other.begin) {
      return false;
    }
    return std::equal(begin, end, other.begin);
  }
};
58
+
59
+ } // namespace vkcom
60
+
61
namespace std {
// Allow VectorSegment to be used as a key in std::hash-based containers.
template<>
struct hash<vkcom::VectorSegment> {
  // The hash was precomputed in the constructor; just return it.
  uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash; }
};
} // namespace std
67
+
68
+ namespace vkcom {
69
+
70
+ Status fast_read_file_utf8(const string &file_name, string *file_content) {
71
+ static const int buf_size = 1000000;
72
+ *file_content = "";
73
+ auto fin = fopen(file_name.data(), "rb");
74
+ if (fin == nullptr) {
75
+ return Status(1, "Failed to open file: " + file_name);
76
+ }
77
+ while (true) {
78
+ uint64_t cur_size = file_content->size();
79
+ file_content->resize(cur_size + buf_size);
80
+ int buf_len = fread((void *) (file_content->data() + cur_size), 1, buf_size, fin);
81
+ if (buf_len < buf_size) {
82
+ file_content->resize(file_content->size() - (buf_size - buf_len));
83
+ fclose(fin);
84
+ return Status();
85
+ }
86
+ }
87
+ }
88
+
89
+ string token2word(const vector<uint32_t> &source,
90
+ const flat_hash_map<uint32_t, uint32_t> &id2char) {
91
+ vector<uint32_t> res;
92
+ for (int i : source) {
93
+ assert(id2char.count(i) == 1);
94
+ res.push_back(id2char.at(i));
95
+ }
96
+ return encode_utf8(res);
97
+ }
98
+
99
+ bool is_space(uint32_t ch) {
100
+ return (ch < 256 && isspace(ch)) || (ch == SPACE_TOKEN);
101
+ }
102
+
103
// Packs a token pair into one 64-bit key: a in the high word, b in the low.
uint64_t int2comb(uint32_t a, uint32_t b) {
  uint64_t packed = static_cast<uint64_t>(a);
  packed <<= 32u;
  return packed + b;
}
106
+
107
// A candidate pair merge together with its current occurrence count.
// Ordered by count first; ties prefer the pair whose larger (then smaller)
// token id is smaller, and finally break on left_token so that ordering is
// total and deterministic.
struct MergeCandidate {
  uint64_t count{0};
  uint32_t left_token{0};
  uint32_t right_token{0};

  MergeCandidate() = default;

  MergeCandidate(uint64_t count, uint32_t left_token, uint32_t right_token)
      : count(count), left_token(left_token), right_token(right_token) {}

  bool operator<(const MergeCandidate &other) const {
    if (count != other.count) {
      return count < other.count;
    }
    std::pair<uint32_t, uint32_t> ours = std::minmax(left_token, right_token);
    std::pair<uint32_t, uint32_t> theirs =
        std::minmax(other.left_token, other.right_token);
    if (ours.second != theirs.second) {
      return ours.second > theirs.second;
    }
    if (ours.first != theirs.first) {
      return ours.first > theirs.first;
    }
    return left_token < other.left_token;
  }
};
135
+
136
+ struct UTF8Iterator {
137
+ UTF8Iterator(char* begin, char* end): begin(begin), end(end) {}
138
+
139
+ UTF8Iterator operator++() {
140
+ if (!state) {
141
+ parse();
142
+ }
143
+ begin += utf8_len;
144
+ state = false;
145
+ return *this;
146
+ }
147
+
148
+ uint32_t operator*() {
149
+ if (!state) {
150
+ parse();
151
+ }
152
+ return code_point;
153
+ }
154
+
155
+ char* get_ptr() {
156
+ return begin;
157
+ }
158
+ uint64_t get_utf8_len() {
159
+ return utf8_len;
160
+ }
161
+
162
+ bool empty() {
163
+ assert(begin <= end);
164
+ return begin == end;
165
+ }
166
+ private:
167
+ char *begin, *end;
168
+ uint32_t code_point = 0;
169
+ uint64_t utf8_len = 0;
170
+ bool state = false;
171
+ void parse() {
172
+ if (state) {
173
+ return;
174
+ }
175
+ assert(!empty());
176
+ code_point = chars_to_utf8(begin, end - begin, &utf8_len);
177
+ state = true;
178
+ }
179
+ };
180
+
181
+
182
// A (word index, position-in-word) coordinate, ordered lexicographically.
struct Position {
  uint64_t word_id, pos_id;

  Position(uint64_t word_id, uint64_t pos_id) : word_id(word_id), pos_id(pos_id) {}

  bool operator<(const Position &other) const {
    if (word_id != other.word_id) {
      return word_id < other.word_id;
    }
    return pos_id < other.pos_id;
  }
};
192
+
193
// Number of non-overlapping (token, token) pairs inside a run of x equal
// tokens: x = 4 -> 2, x = 5 -> 2. Only defined for runs of length >= 2.
int pairsInSeg(int x) {
  assert(x >= 2);
  return x / 2;
}
197
+
198
// Occurrences of a token pair: every (word, position) where it appears,
// plus its total weighted count.
struct PositionsCnt {
  vector<Position> positions;
  uint64_t cnt;
};
202
+ bool rule_intersection(BPE_Rule rule, uint32_t new_left, uint32_t new_right) {
203
+ return rule.y == new_left || rule.x == new_right;
204
+ }
205
+
206
// Bucketed priority queue for merge candidates whose count is below the
// "big event" bound: queue[c] holds all candidates with count == c, so the
// highest non-empty bucket always contains the current maximum.
struct SmallObjectQueue {

  vector<vector<MergeCandidate>> queue;  // bucket index == candidate count
  bool flag_started{false};              // set once consumption has begun
  uint64_t _size{0};                     // total number of stored candidates

  SmallObjectQueue() = default;

  // Insert a candidate into the bucket matching its count.
  void push(const MergeCandidate &event) {
    if (queue.size() <= event.count) {
      queue.resize(event.count + 1);
    }
    if (flag_started) {
      // Sanity check: the bucket for this count exists after the resize.
      assert(event.count + 1 <= queue.size());
    };
    queue[event.count].push_back(event);
    _size++;
#ifdef DETERMINISTIC_QUEUE
    // Keep the top bucket sorted so tie-breaking is deterministic.
    if (queue.size() - 1 == event.count && flag_started) {
      sort(queue.back().begin(), queue.back().end());
    }
#endif
  }

  // Drop empty top buckets so queue.back() is the non-empty maximum.
  void process_empty_slots() {
#ifdef DETERMINISTIC_QUEUE
    bool moved_down = !flag_started;
#endif
    flag_started = true;

    for (; !queue.empty() && queue.back().empty(); queue.pop_back()) {
#ifdef DETERMINISTIC_QUEUE
      moved_down = true;
#endif
    }
#ifdef DETERMINISTIC_QUEUE
    // Re-sort whenever a new bucket becomes the top one.
    if (moved_down && !queue.empty()) {
      sort(queue.back().begin(), queue.back().end());
    }
#endif
  }

  bool empty() {
    process_empty_slots();
    return queue.empty();
  }

  // Any element of the top bucket is maximal (all share the same count).
  MergeCandidate top() {
    process_empty_slots();
    assert(!queue.empty());
    return queue.back().back();
  }

  // Remove the element returned by top(). Caller must ensure non-empty.
  void pop() {
    assert(!queue.empty());
    assert(!queue.back().empty());
    queue.back().pop_back();
    _size--;
  }

  uint64_t size() const {
    return _size;
  }
};
270
+
271
// Flat store for high-count merge candidates. Counts are validated lazily
// in top(): stale candidates are refreshed via check_cnt and demoted to the
// SmallObjectQueue once their true count falls below big_event_bound.
struct BigObjectQueue {
  vector<MergeCandidate> big_events;
  uint64_t big_event_bound;  // counts below this belong to the small queue

  BigObjectQueue(uint64_t big_event_bound) : big_event_bound(big_event_bound) {}

  void push(const MergeCandidate &event) {
    big_events.push_back(event);
  }

  bool empty() const {
    return big_events.empty();
  }

  // Refresh counts, demote candidates that dropped below the bound, then
  // move a maximal candidate into the back slot. Returns false (leaving
  // ret untouched) if nothing remains in this queue.
  bool top(std::function<uint64_t(uint64_t)> &check_cnt, MergeCandidate &ret, SmallObjectQueue *small_object_queue,
           BPE_Rule last_rule) {
    for (uint64_t i = 0; i < big_events.size();) {
      // Pairs overlapping the last applied rule keep their cached count:
      // their true count is not yet up to date.
      if (!rule_intersection(last_rule, big_events[i].left_token, big_events[i].right_token)) {
        uint64_t comb = int2comb(big_events[i].left_token, big_events[i].right_token);
        assert(big_events[i].count >= check_cnt(comb));
        big_events[i].count = check_cnt(comb);
      }

      if (big_events[i].count < big_event_bound) {
        // Swap-erase: hand the lightweight candidate to the small queue.
        small_object_queue->push(big_events[i]);
        big_events[i] = big_events.back();
        big_events.pop_back();
      } else {
        i++;
      }
    }
#ifdef DETERMINISTIC_QUEUE
    sort(big_events.begin(), big_events.end()); /// TODO remove unoptimal code
#else
    // Single pass that bubbles a maximal element into the back slot.
    for (auto &big_event : big_events) {
      if (big_event.count > big_events.back().count) {
        std::swap(big_event, big_events.back());
      }
    }
#endif

    if (big_events.empty()) {
      return false;
    }
    ret = big_events.back();
    return true;
  }

  // Remove the element placed at the back by top().
  void pop() {
    assert(!big_events.empty());
    big_events.pop_back();
  }

  uint64_t size() const {
    return big_events.size();
  }
};
328
+
329
+ struct PriorityQueue {
330
+ SmallObjectQueue small_queue;
331
+ BigObjectQueue big_queue;
332
+ uint64_t big_event_bound;
333
+
334
+ explicit PriorityQueue(uint64_t dataset_size) : big_queue(static_cast<uint64_t>(sqrt(dataset_size))),
335
+ big_event_bound(static_cast<uint64_t>(sqrt(dataset_size))) {}
336
+
337
+ void push(const MergeCandidate &event) {
338
+ if (event.count == 0) {
339
+ return;
340
+ }
341
+ if (event.count < big_event_bound) {
342
+ small_queue.push(event);
343
+ } else {
344
+ big_queue.push(event);
345
+ }
346
+ }
347
+
348
+ bool empty() {
349
+ return big_queue.empty() && small_queue.empty();
350
+ }
351
+
352
+ MergeCandidate top(std::function<uint64_t(uint64_t)> &check_cnt, BPE_Rule last_rule) {
353
+ MergeCandidate res;
354
+ bool has_top = big_queue.top(check_cnt, res, &small_queue, last_rule);
355
+ if (has_top) {
356
+ return res;
357
+ }
358
+ return small_queue.top();
359
+ }
360
+
361
+ void pop() {
362
+ if (!big_queue.empty()) {
363
+ big_queue.pop();
364
+ } else {
365
+ small_queue.pop();
366
+ }
367
+ }
368
+
369
+ uint64_t size() const {
370
+ return big_queue.size() + small_queue.size();
371
+ }
372
+ };
373
+
374
// Builds the character -> token-id table for the alphabet. The rarest
// characters are dropped (collected into removed_chars) while the kept
// characters still cover at least character_coverage of the data. Ids
// start right after the special tokens, with SPACE_TOKEN assigned first.
flat_hash_map<uint32_t, uint32_t> compute_alphabet_helper(
    const flat_hash_map<uint32_t, uint64_t> &char_cnt, uint64_t data_len,
    flat_hash_set<uint32_t> &removed_chars, const BpeConfig &bpe_config) {
  vector<std::pair<uint64_t, uint32_t>> frequencies;  // (count, code point)

  for (auto x : char_cnt) {
    frequencies.emplace_back(x.second, x.first);
  }
  sort(frequencies.begin(), frequencies.end());  // ascending by count

  // Greedily drop the rarest characters for as long as the remaining
  // characters still exceed the configured coverage threshold.
  uint64_t cur = 0;
  uint64_t n_removed = 0;
  for (; cur < frequencies.size() &&
         (data_len - n_removed - frequencies[cur].first) >
             data_len * bpe_config.character_coverage;
       cur++) {
    n_removed += frequencies[cur].first;
  }
  std::cerr << "number of unique characters in the training data: "
            << frequencies.size() << std::endl;
  std::cerr << "number of deleted characters: " << cur << std::endl;
  std::cerr << "number of unique characters left: " << frequencies.size() - cur
            << std::endl;

  flat_hash_map<uint32_t, uint32_t> char2id;
  uint64_t used_ids = bpe_config.special_tokens.n_special_tokens();
  char2id[SPACE_TOKEN] = used_ids++;

  // frequencies[0 .. cur) are the characters chosen for removal.
  for (uint64_t i = 0; i < cur; i++) {
    removed_chars.insert(frequencies[i].second);
  }

  // Assign ids from the most frequent character downwards; whitespace is
  // already represented by SPACE_TOKEN.
  for (int i = frequencies.size() - 1; i >= static_cast<int>(cur); i--) {
    if (!is_space(frequencies[i].second)) {
      assert(char2id.count(frequencies[i].second) == 0);
      char2id[frequencies[i].second] = used_ids++;
    }
  }
  return char2id;
}
414
+
415
+ void remove_rare_chars(vector<uint32_t> &data,
416
+ const flat_hash_set<uint32_t> &removed_chars) {
417
+ if (removed_chars.empty()) {
418
+ return;
419
+ }
420
+ auto it_first_rare_char =
421
+ std::remove_if(data.begin(), data.end(),
422
+ [&](uint32_t c) { return removed_chars.count(c) != 0; });
423
+ data.erase(it_first_rare_char, data.end());
424
+ }
425
+
426
+ char* remove_rare_chars(char* begin, char* end, const flat_hash_set<uint32_t> &removed_chars) {
427
+ if (removed_chars.empty()) {
428
+ return end;
429
+ }
430
+ char* end_candidate = begin;
431
+ bool invalid_input = false;
432
+ UTF8Iterator utf8_iter(begin, end);
433
+ for (; !utf8_iter.empty(); ++utf8_iter) {
434
+ if (*utf8_iter != INVALID_UNICODE) {
435
+ if (removed_chars.count(*utf8_iter) == 0) {
436
+ char* token_begin = utf8_iter.get_ptr();
437
+ uint64_t token_len = utf8_iter.get_utf8_len();
438
+ memcpy(end_candidate, token_begin, token_len);
439
+ end_candidate += token_len;
440
+ }
441
+ } else {
442
+ invalid_input = true;
443
+ }
444
+ }
445
+ if (invalid_input) {
446
+ std::cerr << "WARNING Input contains invalid unicode characters." << std::endl;
447
+ }
448
+ return end_candidate;
449
+ }
450
+
451
// A word encoded as a sequence of token ids, together with the number of
// times it occurs in the corpus.
struct WordCount {
  vector<uint32_t> word;
  uint64_t cnt;
};
455
+
456
+
457
// Splits the byte range [sbegin, send) into whitespace-delimited words and
// counts each distinct word. Keys are hashed views into the buffer; values
// hold the word encoded as token ids (prefixed with the space token) plus
// its occurrence count.
flat_hash_map<VectorSegment, WordCount> compute_word_count(
    char* sbegin, char* send,
    const flat_hash_map<uint32_t, uint32_t> &char2id) {
  flat_hash_map<VectorSegment, WordCount> hash2wordcnt;
  vector<uint32_t> word;  // reused scratch buffer for encoding
  UTF8Iterator utf8_iter(sbegin, send);

  for (;!utf8_iter.empty();) {
    // Skip the whitespace run preceding the next word.
    for (; !utf8_iter.empty() && is_space(*utf8_iter); ++utf8_iter);
    if (utf8_iter.empty()) {
      break;
    }
    char* begin_of_word = utf8_iter.get_ptr();
    // Consume the word itself.
    for (; !utf8_iter.empty() && !is_space(*utf8_iter); ++utf8_iter);
    char* end_of_word = utf8_iter.get_ptr();
    VectorSegment word_hash(begin_of_word, end_of_word);
    auto it = hash2wordcnt.find(word_hash);
    if (it == hash2wordcnt.end()) {
      // First occurrence: encode the word's code points as token ids.
      word.clear();
      word.push_back(char2id.at(SPACE_TOKEN));
      UTF8Iterator word_iter(begin_of_word, end_of_word);
      for (; !word_iter.empty(); ++word_iter) {
        word.push_back(char2id.at(*word_iter));
      }
      hash2wordcnt[word_hash] = {word, 1};
    } else {
      it->second.cnt++;
    }
  }
  return hash2wordcnt;
}
488
+
489
+
490
// Node of the doubly linked token list used during training. A run of
// identical tokens is collapsed into one node with seg_len > 1; a node
// with val == 0 (and seg_len == 0) is a tombstone for a deleted node.
struct NodeEncoder {
  uint32_t val;  // token id, 0 for a deleted node
  int prev;      // index of the previous node, -1 at the head
  int next;      // index of the next node, -1 at the tail
  int seg_len;   // length of the collapsed run of `val`

  NodeEncoder(uint32_t val, int prev, int next, int seg_len)
      : val(val), prev(prev), next(next), seg_len(seg_len) {}

  bool is_alive() const {
    bool alive = (val != 0);
    assert(alive == (seg_len != 0));
    return alive;
  }
};
504
+
505
// Converts each word into a doubly linked list of NodeEncoder nodes,
// collapsing runs of an identical token into one node with seg_len, and
// fills pair2pos (pair code -> positions where the pair occurs) and
// pair2cnt (pair code -> occurrence count weighted by word frequency).
void build_linked_list(const vector<WordCount> &word_cnt,
                       vector<vector<NodeEncoder>> &list,
                       flat_hash_map<uint64_t, vector<Position>> &pair2pos,
                       flat_hash_map<uint64_t, uint64_t> &pair2cnt) {
  list.resize(word_cnt.size());
  for (uint64_t i = 0; i < word_cnt.size(); i++) {
    for (uint32_t ch : word_cnt[i].word) {
      if (!list[i].empty() && list[i].back().val == ch) {
        // Same token repeated: extend the current run.
        list[i].back().seg_len++;
      } else {
        int list_size = list[i].size();
        list[i].emplace_back(ch, list_size - 1, list_size + 1, 1);
      }
    }

    list[i].back().next = -1;  // terminate the list
    for (uint64_t j = 0; j < list[i].size(); j++) {
      // Pair formed with the following (different) token.
      if (j + 1 < list[i].size()) {
        uint64_t comb = int2comb(list[i][j].val, list[i][j + 1].val);
        auto it = pair2pos.find(comb);
        if (it == pair2pos.end()) {
          pair2pos[comb] = {{i, j}};
        } else {
          it->second.emplace_back(i, j);
        }
        pair2cnt[comb] += word_cnt[i].cnt;
      }
      assert(list[i][j].seg_len >= 1);

      // A run of length >= 2 also contributes (token, token) self-pairs.
      if (list[i][j].seg_len > 1) {
        uint64_t comb = int2comb(list[i][j].val, list[i][j].val);
        auto it = pair2pos.find(comb);
        uint64_t cc = word_cnt[i].cnt * pairsInSeg(list[i][j].seg_len);
        if (it == pair2pos.end()) {
          pair2pos[comb] = {{i, j}};
        } else {
          it->second.emplace_back(i, j);
        }
        pair2cnt[comb] += cc;
      }
    }
  }
}
548
+
549
+ std::chrono::steady_clock::time_point last_time_stamp;
550
+
551
+ void time_check(const string &message) {
552
+ auto cur_moment = std::chrono::steady_clock::now();
553
+ if (!message.empty()) {
554
+ std::cerr << "## time " << message << " ... "
555
+ << std::chrono::duration_cast<std::chrono::microseconds>(
556
+ cur_moment - last_time_stamp)
557
+ .count() *
558
+ 1.0 / 1e6
559
+ << std::endl;
560
+ }
561
+ last_time_stamp = cur_moment;
562
+ }
563
+
564
+ double time_check_silent() {
565
+ auto cur_moment = std::chrono::steady_clock::now();
566
+ double ret = std::chrono::duration_cast<std::chrono::microseconds>(
567
+ cur_moment - last_time_stamp)
568
+ .count() *
569
+ 1.0 / 1e6;
570
+ last_time_stamp = cur_moment;
571
+ return ret;
572
+ }
573
+
574
+ int alive_tokens;
575
+
576
+ void init_recipe(const flat_hash_map<uint32_t, uint32_t> &char2id,
577
+ flat_hash_map<uint32_t, vector<uint32_t>> &recipe,
578
+ flat_hash_map<uint32_t, string> &recipe_s) {
579
+ for (auto token_id : char2id) {
580
+ uint32_t ch = token_id.first;
581
+ uint32_t id = token_id.second;
582
+ recipe[id] = {id};
583
+ recipe_s[id] = encode_utf8({ch});
584
+ }
585
+ }
586
+
587
// Worker-thread body for parallel BPE training. Each worker owns a shard of
// the corpus (lists_of_tokens / word_freq / pair2pos and its own slot of
// pair2cnt_g) and runs in lock-step with the main thread: it waits on
// mt/cv for the next rule in task_order, applies the merge (x, y) -> z to
// every occurrence in its shard while keeping the pair position/count
// tables consistent, publishes the affected neighbour-pair counts through
// left/right_tokens_submit, and signals completion via results_ready and
// main_loop_cv. Terminates when cur_token_rule reaches real_n_tokens.
void worker_doing_merge(
    uint64_t thread_id, vector<vector<NodeEncoder>> &lists_of_tokens,
    vector<flat_hash_map<uint64_t, uint64_t>> &pair2cnt_g,
    flat_hash_map<uint64_t, vector<Position>> &pair2pos,
    vector<uint64_t> &word_freq, vector<std::mutex> &mt,
    vector<std::condition_variable> &cv, vector<BPE_Rule> &task_order,
    vector<std::atomic_bool> &thread_use_hs,
    flat_hash_map<uint32_t, uint32_t> &char2id,
    vector<vector<flat_hash_map<uint32_t, uint64_t>>> &left_tokens_submit,
    vector<vector<flat_hash_map<uint32_t, uint64_t>>> &right_tokens_submit,
    std::atomic<uint32_t> &real_n_tokens,
    vector<std::atomic<uint32_t>> &results_ready, const BpeConfig &bpe_config,
    std::mutex &main_loop_mt, std::condition_variable &main_loop_cv) {
  auto &pair2cnt = pair2cnt_g[thread_id];
  flat_hash_set<uint32_t> left_tokens;   // tokens seen to the left of z
  flat_hash_set<uint32_t> right_tokens;  // tokens seen to the right of z

  // First rule id: ids below this belong to the alphabet / special tokens.
  uint32_t cur_token_rule =
      char2id.size() + bpe_config.special_tokens.n_special_tokens();
  // Key for the pair (node p1, its successor).
  auto get_pair_code = [&](uint64_t word_id, uint64_t p1) {
    int p2 = lists_of_tokens[word_id][p1].next;
    return int2comb(lists_of_tokens[word_id][p1].val,
                    lists_of_tokens[word_id][p2].val);
  };

  // Key for the (token, token) self-pair of node p1.
  auto get_self_code = [&](uint64_t word_id, uint64_t p1) {
    return int2comb(lists_of_tokens[word_id][p1].val,
                    lists_of_tokens[word_id][p1].val);
  };

  // Decrease the weighted count of the pair starting at pos_id.
  auto remove_pair = [&](int word_id, int pos_id) {
    pair2cnt[get_pair_code(word_id, pos_id)] -= word_freq[word_id];
  };

  // Register a new occurrence of the pair starting at pos_id.
  auto add_pair = [&](uint64_t word_id, uint64_t pos_id) {
    uint64_t comb = get_pair_code(word_id, pos_id);
    auto it = pair2pos.find(comb);
    if (it == pair2pos.end()) {
      pair2pos[comb] = {{word_id, pos_id}};
    } else {
      it->second.emplace_back(word_id, pos_id);
    }
    pair2cnt[comb] += word_freq[word_id];
  };

  // Record a position for an already-known pair without touching counts.
  auto add_empty_pair = [&](uint64_t word_id, uint64_t pos_id) {
    auto it = pair2pos.find(get_pair_code(word_id, pos_id));
    assert(it != pair2pos.end());
    it->second.emplace_back(word_id, pos_id);
  };

  // Register the self-pairs contributed by a run of length >= 2 at pos_id.
  auto add_self_pair = [&](uint64_t word_id, uint64_t pos_id) {
    int seg_len = lists_of_tokens[word_id][pos_id].seg_len;
    assert(seg_len >= 2);
    uint64_t comb = get_self_code(word_id, pos_id);
    auto it = pair2pos.find(comb);
    uint64_t real_cnt = word_freq[word_id] * pairsInSeg(seg_len);
    if (it == pair2pos.end()) {
      pair2pos[comb] = {{word_id, pos_id}};
      assert(pair2pos[comb].size() == 1);
    } else {
      it->second.emplace_back(word_id, pos_id);
    }
    pair2cnt[comb] += real_cnt;
  };

  // Subtract the self-pairs lost when two runs were fused by try_merge.
  auto add_merge_compensation = [&](uint64_t word_id, uint64_t pos_id,
                                    int score_diff) {
    assert(score_diff > 0);
    uint64_t comb = get_self_code(word_id, pos_id);
    pair2cnt[comb] -= score_diff * word_freq[word_id];
  };

  // Shrink a run by one token; an even run loses one self-pair.
  auto seg_len_decrement = [&](uint64_t word_id, uint64_t pos_id) {
    int seg_len = lists_of_tokens[word_id][pos_id].seg_len;
    assert(seg_len >= 2);
    lists_of_tokens[word_id][pos_id].seg_len--;
    if (seg_len % 2 == 1) {
      return;
    }
    uint64_t comb = get_self_code(word_id, pos_id);
    pair2cnt[comb] -= word_freq[word_id];
  };

  // Remove all self-pairs contributed by the run at pos_id.
  auto self_full_remove = [&](uint64_t word_id, uint64_t pos_id) {
    uint64_t comb = get_self_code(word_id, pos_id);
    uint32_t real_cnt = word_freq[word_id] *
                        pairsInSeg(lists_of_tokens[word_id][pos_id].seg_len);
    pair2cnt[comb] -= real_cnt;
  };

  // Fuse adjacent nodes pos1/pos2 if they now hold the same token,
  // compensating the self-pair count and fixing the linked-list pointers.
  auto try_merge = [&](uint64_t word_id, uint64_t pos1, uint64_t pos2) {
    vector<NodeEncoder> &cur_list = lists_of_tokens[word_id];
    if (cur_list[pos1].val == cur_list[pos2].val) {
      int score_before =
          (cur_list[pos1].seg_len / 2) + (cur_list[pos2].seg_len / 2) + 1;
      cur_list[pos1].seg_len += cur_list[pos2].seg_len;
      int score_after = cur_list[pos1].seg_len / 2;
      if (score_before != score_after) {
        add_merge_compensation(word_id, pos1, score_before - score_after);
      }

      cur_list[pos1].next = cur_list[pos2].next;
      cur_list[pos2] = {0, -1, -1, 0};
      if (cur_list[pos1].next != -1) {
        cur_list[cur_list[pos1].next].prev = pos1;
        add_empty_pair(word_id, pos1);
      }
    }
  };
  while (true) {
    {
      // Wait for the main thread to publish the next rule (double-buffered
      // in task_order), or for training to finish.
      std::unique_lock<std::mutex> ul(mt[thread_id]);
      cv[thread_id].wait(ul, [&] {
        return task_order[cur_token_rule % 2].z == cur_token_rule ||
               cur_token_rule >= real_n_tokens;
      });
      assert(cur_token_rule <= real_n_tokens);
      if (cur_token_rule == real_n_tokens) {
        break;
      }
    }

    uint32_t x = task_order[cur_token_rule % 2].x;
    uint32_t y = task_order[cur_token_rule % 2].y;
    uint32_t z = task_order[cur_token_rule % 2].z;

    left_tokens.clear();
    right_tokens.clear();

    left_tokens.insert(z);
    int real_merge = 0;
    int not_real_merge = 0;

    if (x == y) {
      // Self-merge: (x, x) -> z happens inside a single run node.
      const vector<Position> &merge_candidates = pair2pos[int2comb(x, y)];

      std::unique_lock<std::mutex> lk(mt[thread_id]);
      for (auto word_pos : merge_candidates) {
        // Pause whenever the main thread needs exclusive access.
        cv[thread_id].wait(lk, [&] { return thread_use_hs[thread_id].load(); });
        not_real_merge++;

        // p0 <-> p1 <-> p3 -- ids of nodes in linked list.
        // merge will happen inside p1

        int word_id = word_pos.word_id;
        vector<NodeEncoder> &cur_list = lists_of_tokens[word_id];
        int p1 = word_pos.pos_id;
        // Stale position: the node changed since this entry was recorded.
        if (cur_list[p1].val != x || cur_list[p1].seg_len < 2) {
          continue;
        }
        real_merge++;

        int p0 = cur_list[p1].prev;
        int p3 = cur_list[p1].next;

        // Retire the old counts around p1 before rewriting it.
        self_full_remove(word_id, p1);
        if (p0 != -1) {
          remove_pair(word_id, p0);
        }
        if (p3 != -1) {
          remove_pair(word_id, p1);
        }
        int seg_len = cur_list[p1].seg_len;
        if (seg_len % 2 == 0) {
          // Even run: it collapses entirely into a run of z.
          cur_list[p1] = {z, p0, p3, seg_len / 2};

          if (p0 != -1) {
            add_pair(word_id, p0);
            left_tokens.insert(cur_list[p0].val);
          }

          if (p3 != -1) {
            add_pair(word_id, p1);
            right_tokens.insert(cur_list[p3].val);
          }
          if (seg_len / 2 >= 2) {
            add_self_pair(word_id, p1);
          }
        } else {
          // Odd run: one trailing x survives as a new node p2 after the
          // new z-run.
          cur_list.emplace_back(x, p1, p3, 1);
          int p2 = static_cast<int>(cur_list.size() - 1);
          cur_list[p1] = {z, p0, p2, seg_len / 2};
          if (p0 != -1) {
            add_pair(word_id, p0);
            left_tokens.insert(cur_list[p0].val);
          }

          add_pair(word_id, p1);
          right_tokens.insert(cur_list[p2].val);

          if (p3 != -1) {
            cur_list[p3].prev = p2;
            add_pair(word_id, p2);
          }

          if (seg_len / 2 >= 2) {
            add_self_pair(word_id, p1);
          }
        }
      }
    } else {
      // Ordinary merge between two neighbouring nodes.
      std::unique_lock<std::mutex> lk(mt[thread_id]);
      for (auto word_pos : pair2pos[int2comb(x, y)]) {
        not_real_merge++;
        cv[thread_id].wait(lk, [&] { return thread_use_hs[thread_id].load(); });
        // p0 <-> p1 <-> p2 <-> p3 -- ids of nodes in linked list.
        // merge will happen between p1 p2
        int word_id = word_pos.word_id;

        int p1 = word_pos.pos_id;
        vector<NodeEncoder> &cur_list = lists_of_tokens[word_id];
        int p2 = cur_list[p1].next;
        // Stale position: the pair is no longer present here.
        if (cur_list[p1].val != x || p2 == -1 || cur_list[p2].val != y) {
          continue;
        }
        real_merge++;

        int p0 = cur_list[p1].prev;
        int p3 = cur_list[p2].next;
        // Retire counts that the rewrite below invalidates.
        remove_pair(word_id, p1);
        if (p0 != -1 && cur_list[p1].seg_len == 1) {
          remove_pair(word_id, p0);
        }
        if (p3 != -1 && cur_list[p2].seg_len == 1) {
          remove_pair(word_id, p2);
        }

        if (cur_list[p1].seg_len > 1 && cur_list[p2].seg_len > 1) {
          // Both runs keep a remainder: insert a fresh z node between them.
          cur_list.emplace_back(z, p1, p2, 1);
          int p12 = static_cast<int>(cur_list.size() - 1);

          seg_len_decrement(word_id, p1);
          seg_len_decrement(word_id, p2);

          cur_list[p1].next = p12;
          cur_list[p2].prev = p12;

          add_pair(word_id, p1);
          left_tokens.insert(cur_list[p1].val);

          add_pair(word_id, p12);
          right_tokens.insert(cur_list[p2].val);
        } else if (cur_list[p1].seg_len > 1 && cur_list[p2].seg_len == 1) {
          // Only y is fully consumed: p2 becomes the z node.
          cur_list[p2] = {z, p1, p3, 1};

          seg_len_decrement(word_id, p1);

          add_pair(word_id, p1);
          left_tokens.insert(cur_list[p1].val);

          if (p3 != -1) {
            add_pair(word_id, p2);
            right_tokens.insert(cur_list[p3].val);
            try_merge(word_id, p2, p3);
          }
        } else if (cur_list[p1].seg_len == 1 && cur_list[p2].seg_len > 1) {
          // Only x is fully consumed: p1 becomes the z node.
          cur_list[p1] = {z, p0, p2, 1};

          seg_len_decrement(word_id, p2);

          if (p0 != -1) {
            add_pair(word_id, p0);
            left_tokens.insert(cur_list[p0].val);
          }

          add_pair(word_id, p1);
          right_tokens.insert(cur_list[p2].val);
          if (p0 != -1) {
            try_merge(word_id, p0, p1);
          }
        } else {
          // Both nodes fully consumed: p1 becomes z, p2 is tombstoned.
          assert(cur_list[p1].seg_len == 1 && cur_list[p2].seg_len == 1);

          cur_list[p1] = {z, p0, p3, 1};
          cur_list[p2] = {0, -1, -1, 0};
          if (p3 != -1) {
            cur_list[p3].prev = p1;
          }

          if (p0 != -1) {
            add_pair(word_id, p0);
            left_tokens.insert(cur_list[p0].val);
          }
          if (p3 != -1) {
            add_pair(word_id, p1);
            right_tokens.insert(cur_list[p3].val);
          }
          if (p0 != -1) {
            try_merge(word_id, p0, p1);
          }
          if (p3 != -1) {
            try_merge(word_id, p1, p3);
          }
        }
      }
    }
    pair2pos.erase(int2comb(x, y));
    {
      // Publish the counts of all pairs involving the new token z.
      std::unique_lock<std::mutex> lk(mt[thread_id]);

      left_tokens_submit[cur_token_rule % 2][thread_id].clear();
      right_tokens_submit[cur_token_rule % 2][thread_id].clear();

      for (auto token : left_tokens) {
        left_tokens_submit[cur_token_rule % 2][thread_id][token] =
            pair2cnt[int2comb(token, z)];
      }

      for (auto token : right_tokens) {
        right_tokens_submit[cur_token_rule % 2][thread_id][token] =
            pair2cnt[int2comb(z, token)];
      }
    }
    {
      // Tell the main loop this rule is done on this shard.
      std::lock_guard<std::mutex> lg(main_loop_mt);
      results_ready[thread_id] = cur_token_rule;
    }
    main_loop_cv.notify_one();
    cur_token_rule++;
  }
}
909
+
910
// Remaps token ids into their final numbering: `renaming` maps the dense
// training-time id space (which starts right after the special tokens)
// onto the final id space that skips ids reserved by the special tokens.
// The mapping is applied to both the alphabet and the learned merge rules.
void rename_tokens(flat_hash_map<uint32_t, uint32_t> &char2id,
                   vector<BPE_Rule> &rules, const SpecialTokens &special_tokens,
                   uint32_t n_tokens) {
  flat_hash_map<uint32_t, uint32_t> renaming;
  uint32_t cur = special_tokens.n_special_tokens();
  for (uint32_t i = 0; i < n_tokens; i++) {
    if (!special_tokens.taken_id(i)) {
      renaming[cur++] = i;
    }
  }
  for (auto &node : char2id) {
    assert(renaming.count(node.second));
    node.second = renaming[node.second];
  }

  for (auto &rule : rules) {
    assert(renaming.count(rule.x));
    assert(renaming.count(rule.y));
    assert(renaming.count(rule.z));
    rule.x = renaming[rule.x];
    rule.y = renaming[rule.y];
    rule.z = renaming[rule.z];
  }
}
934
+
935
// Counts occurrences of every non-space code point in [begin, end) into
// char_cnt. Returns the total number of decoded positions (including
// spaces and invalid sequences); warns once if invalid UTF-8 was seen.
uint64_t compute_char_count(flat_hash_map<uint32_t, uint64_t>& char_cnt, char* begin, char* end) {
  bool invalid_input = false;
  UTF8Iterator utf8_iter(begin, end);
  uint64_t char_count = 0;
  for (; !utf8_iter.empty(); char_count++, ++utf8_iter) {
    if (*utf8_iter != INVALID_UNICODE) {
      if (!is_space(*utf8_iter)) {
        char_cnt[*utf8_iter]++;
      }
    } else {
      invalid_input = true;
    }
  }
  if (invalid_input) {
    std::cerr << "WARNING Input contains invalid unicode characters."
              << std::endl;
  }
  return char_count;
}
954
+
955
+ Status learn_bpe_from_string(string &text_utf8, int n_tokens,
956
+ const string &output_file,
957
+ BpeConfig bpe_config, BPEState *bpe_state) {
958
+ assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1);
959
+ uint64_t n_threads = bpe_config.n_threads;
960
+ vector<uint64_t> split_pos;
961
+ split_pos.push_back(0);
962
+ for (uint64_t i = 1; i <= n_threads; i++) {
963
+ uint64_t candidate = text_utf8.size() * i / n_threads;
964
+ for (; candidate < text_utf8.size() && !is_space(text_utf8[candidate]);
965
+ candidate++) {
966
+ }
967
+
968
+ split_pos.push_back(candidate);
969
+ }
970
+
971
+ vector<flat_hash_map<uint32_t, uint64_t>> shared_char_cnt(n_threads);
972
+
973
+ vector<std::mutex> mt(n_threads);
974
+ vector<std::condition_variable> cv(n_threads);
975
+ vector<char> thread_finished(n_threads, 0);
976
+ vector<char> main_finished(n_threads, 0);
977
+ vector<uint64_t> text_len(n_threads);
978
+
979
+ flat_hash_set<uint32_t> removed_chars;
980
+ flat_hash_map<uint32_t, uint32_t> char2id;
981
+
982
+ vector<flat_hash_map<VectorSegment, WordCount>> hash2wordcnt(n_threads);
983
+ int error_flag = 0;
984
+
985
+ flat_hash_map<uint32_t, vector<uint32_t>> recipe;
986
+ flat_hash_map<uint32_t, string> recipe_s;
987
+ vector<flat_hash_map<uint64_t, uint64_t>> pair2cnt_g(n_threads);
988
+ PriorityQueue merge_order(1);
989
+ vector<uint64_t> split_word_cnt;
990
+ vector<WordCount> word_cnt_global;
991
+
992
+ auto comb2int = [](uint64_t a, uint32_t &b, uint32_t &c) {
993
+ b = static_cast<uint32_t>(a >> 32u);
994
+ c = static_cast<uint32_t>(a & UINT32_MAX);
995
+ };
996
+
997
+ vector<vector<flat_hash_map<uint32_t, uint64_t>>> left_tokens_submit(
998
+ 2, vector<flat_hash_map<uint32_t, uint64_t>>(n_threads));
999
+ vector<vector<flat_hash_map<uint32_t, uint64_t>>> right_tokens_submit(
1000
+ 2, vector<flat_hash_map<uint32_t, uint64_t>>(n_threads));
1001
+ vector<std::atomic<uint32_t>> results_ready(n_threads);
1002
+ for (uint64_t i = 0; i < n_threads; i++) {
1003
+ results_ready[i] = 0;
1004
+ }
1005
+
1006
+ vector<char> thread_stopped(n_threads);
1007
+ vector<char> thread_ready_to_run(n_threads);
1008
+ vector<std::atomic_bool> thread_use_hs(n_threads);
1009
+ vector<BPE_Rule> task_order(2);
1010
+
1011
+ std::atomic<uint32_t> real_n_tokens(n_tokens);
1012
+
1013
+ std::mutex main_loop_mt;
1014
+ std::condition_variable main_loop_cv;
1015
+
1016
+ vector<std::thread> threads;
1017
+ for (uint64_t i = 0; i < n_threads; i++) {
1018
+ threads.emplace_back(
1019
+ [&](uint64_t thread_id) {
1020
+ // threads are working 1
1021
+
1022
+
1023
+ auto thread_awake_main = [&]() {
1024
+ {
1025
+ std::lock_guard<std::mutex> lk(mt[thread_id]);
1026
+ thread_finished[thread_id] = 1;
1027
+ }
1028
+ cv[thread_id].notify_one();
1029
+ };
1030
+
1031
+ auto thread_wait_main = [&]() {
1032
+ std::unique_lock<std::mutex> lk(mt[thread_id]);
1033
+ cv[thread_id].wait(lk, [&] { return main_finished[thread_id]; });
1034
+ main_finished[thread_id] = 0;
1035
+ };
1036
+
1037
+ flat_hash_map<uint32_t, uint64_t> char_cnt;
1038
+ uint64_t char_count = compute_char_count(char_cnt, &text_utf8[0] + split_pos[thread_id], &text_utf8[0] + split_pos[thread_id + 1]);
1039
+ text_len[thread_id] = char_count;
1040
+ shared_char_cnt[thread_id] = char_cnt;
1041
+
1042
+ thread_awake_main();
1043
+ // main is working 1
1044
+ thread_wait_main();
1045
+ // threads are working 2
1046
+ char* seg_begin = &text_utf8[0] + split_pos[thread_id];
1047
+ char* seg_end = &text_utf8[0] + split_pos[thread_id + 1];
1048
+ char* new_seg_end = remove_rare_chars(seg_begin, seg_end, removed_chars);
1049
+ seg_end = new_seg_end;
1050
+
1051
+ hash2wordcnt[thread_id] = compute_word_count(seg_begin, seg_end, char2id);
1052
+
1053
+ thread_awake_main();
1054
+ // main is working 2
1055
+ thread_wait_main();
1056
+ // threads are working 3
1057
+ if (error_flag != 0) {
1058
+ return;
1059
+ }
1060
+
1061
+ flat_hash_map<uint64_t, vector<Position>> pair2pos;
1062
+ vector<vector<NodeEncoder>> lists_of_tokens;
1063
+ vector<uint64_t> word_freq;
1064
+ build_linked_list(
1065
+ {word_cnt_global.begin() + split_word_cnt[thread_id],
1066
+ word_cnt_global.begin() + split_word_cnt[thread_id + 1]},
1067
+ lists_of_tokens, pair2pos, pair2cnt_g[thread_id]);
1068
+
1069
+ std::transform(
1070
+ word_cnt_global.begin() + split_word_cnt[thread_id],
1071
+ word_cnt_global.begin() + split_word_cnt[thread_id + 1],
1072
+ std::back_inserter(word_freq),
1073
+ [](const WordCount &x) { return x.cnt; });
1074
+
1075
+ thread_awake_main();
1076
+ // main is working 3
1077
+ // threads are working 4
1078
+
1079
+ worker_doing_merge(thread_id, lists_of_tokens, pair2cnt_g, pair2pos,
1080
+ word_freq, mt, cv, task_order, thread_use_hs,
1081
+ char2id, left_tokens_submit, right_tokens_submit,
1082
+ real_n_tokens, results_ready, bpe_config,
1083
+ main_loop_mt, main_loop_cv);
1084
+ },
1085
+ i);
1086
+ }
1087
+
1088
+ auto main_wait_threads = [&]() {
1089
+ for (uint64_t i = 0; i < n_threads; i++) {
1090
+ std::unique_lock<std::mutex> lk(mt[i]);
1091
+ cv[i].wait(lk, [&] { return thread_finished[i]; });
1092
+ thread_finished[i] = 0;
1093
+ }
1094
+ };
1095
+
1096
+ auto main_awake_threads = [&]() {
1097
+ for (uint64_t i = 0; i < n_threads; i++) {
1098
+ std::lock_guard<std::mutex> lk(mt[i]);
1099
+ main_finished[i] = 1;
1100
+ }
1101
+ for (uint64_t i = 0; i < n_threads; i++) {
1102
+ cv[i].notify_one();
1103
+ }
1104
+ };
1105
+
1106
+ main_wait_threads();
1107
+
1108
+ // main is working 1
1109
+ for (uint64_t i = 1; i < n_threads; i++) {
1110
+ for (auto x : shared_char_cnt[i]) {
1111
+ shared_char_cnt[0][x.first] += x.second;
1112
+ }
1113
+ text_len[0] += text_len[i];
1114
+ }
1115
+
1116
+ char2id = compute_alphabet_helper(shared_char_cnt[0], text_len[0],
1117
+ removed_chars, bpe_config);
1118
+
1119
+ main_awake_threads();
1120
+ // threads are working 2
1121
+
1122
+ main_wait_threads();
1123
+ // main is working 2
1124
+
1125
+ for (uint64_t i = 1; i < n_threads; i++) {
1126
+ for (const auto &x : hash2wordcnt[i]) {
1127
+ auto it = hash2wordcnt[0].find(x.first);
1128
+ if (it == hash2wordcnt[0].end()) {
1129
+ hash2wordcnt[0][x.first] = x.second;
1130
+ } else {
1131
+ it->second.cnt += x.second.cnt;
1132
+ }
1133
+ }
1134
+ hash2wordcnt[i].clear();
1135
+ }
1136
+
1137
+ word_cnt_global.resize(hash2wordcnt[0].size());
1138
+ std::transform(
1139
+ hash2wordcnt[0].begin(), hash2wordcnt[0].end(), word_cnt_global.begin(),
1140
+ [](const std::pair<VectorSegment, WordCount> &x) { return x.second; });
1141
+
1142
+ hash2wordcnt.shrink_to_fit();
1143
+ text_utf8.shrink_to_fit();
1144
+
1145
+ merge_order = PriorityQueue(text_len[0]);
1146
+
1147
+ uint64_t used_ids =
1148
+ char2id.size() + bpe_config.special_tokens.n_special_tokens();
1149
+ if (used_ids > (uint64_t) n_tokens) {
1150
+ string error_message = "Incorrect arguments. Vocabulary size too small. Set vocab_size>=";
1151
+ error_message += std::to_string(used_ids) + ". Current value for vocab_size=" + std::to_string(n_tokens);
1152
+ error_flag = 1;
1153
+ main_awake_threads();
1154
+ for (auto &t : threads) {
1155
+ t.join();
1156
+ }
1157
+ return Status(1, error_message);
1158
+ }
1159
+
1160
+ init_recipe(char2id, recipe, recipe_s);
1161
+
1162
+ split_word_cnt.push_back(0);
1163
+ for (uint64_t i = 1; i <= n_threads; i++) {
1164
+ split_word_cnt.push_back(word_cnt_global.size() * i / n_threads);
1165
+ }
1166
+
1167
+ main_awake_threads();
1168
+ // threads are working 3
1169
+
1170
+ main_wait_threads();
1171
+ // main is working 3
1172
+ flat_hash_map<uint64_t, uint64_t> real_pair_cnt;
1173
+
1174
+ for (uint64_t i = 0; i < n_threads; i++) {
1175
+ for (const auto &x : pair2cnt_g[i]) {
1176
+ real_pair_cnt[x.first] += x.second;
1177
+ }
1178
+ }
1179
+
1180
+ for (const auto &x : real_pair_cnt) {
1181
+ uint32_t ka, kb;
1182
+ comb2int(x.first, ka, kb);
1183
+ merge_order.push({x.second, ka, kb});
1184
+ }
1185
+ vector<BPE_Rule> rules;
1186
+
1187
+ auto get_recipe = [&](uint32_t x, uint32_t y) {
1188
+ assert(recipe.count(x));
1189
+ assert(recipe.count(y));
1190
+ vector<uint32_t> new_recipe = recipe[x];
1191
+ new_recipe.insert(new_recipe.end(), recipe[y].begin(), recipe[y].end());
1192
+ return new_recipe;
1193
+ };
1194
+
1195
+ std::function<uint64_t(uint64_t)> check_cnt = [&](uint64_t mask) {
1196
+ uint64_t ret = 0;
1197
+ for (uint64_t i = 0; i < n_threads; i++) {
1198
+ auto it = pair2cnt_g[i].find(mask);
1199
+ if (it != pair2cnt_g[i].end()) {
1200
+ ret += it->second;
1201
+ }
1202
+ }
1203
+ return ret;
1204
+ };
1205
+
1206
+ uint64_t finished_cur = used_ids;
1207
+ uint64_t last_failed_try = 0;
1208
+
1209
+ flat_hash_map<uint32_t, uint64_t> all_res;
1210
+ vector<char> local_check_list(n_threads);
1211
+ flat_hash_map<uint32_t, uint64_t> global_ht_update_left;
1212
+ flat_hash_map<uint32_t, uint64_t> global_ht_update_right;
1213
+
1214
+ int inter_fail = 0;
1215
+ int equal_fail = 0;
1216
+ vector<std::pair<int, double>> progress_debug;
1217
+ while (used_ids < (uint64_t) n_tokens) {
1218
+ uint32_t x, y, z;
1219
+ assert(finished_cur <= used_ids && used_ids <= finished_cur + 2);
1220
+ bool progress = false;
1221
+
1222
+ if (used_ids < (uint64_t) n_tokens && used_ids - finished_cur < 2 &&
1223
+ last_failed_try < finished_cur) {
1224
+ progress = true;
1225
+ for (uint64_t i = 0; i < n_threads; i++) {
1226
+ thread_use_hs[i] = false;
1227
+ }
1228
+ {
1229
+ vector<std::lock_guard<std::mutex>> lg(mt.begin(), mt.end());
1230
+
1231
+ uint64_t real_cnt = 0;
1232
+ while (true) {
1233
+ if (merge_order.empty()) {
1234
+ if (finished_cur == used_ids) {
1235
+ std::cerr << "WARNING merged only: " << used_ids
1236
+ << " pairs of tokens" << std::endl;
1237
+ x = UINT32_MAX;
1238
+ y = UINT32_MAX;
1239
+ z = UINT32_MAX;
1240
+ real_n_tokens = used_ids;
1241
+ break;
1242
+ } else {
1243
+ x = y = z = 0;
1244
+ last_failed_try = finished_cur;
1245
+ break;
1246
+ }
1247
+ }
1248
+ BPE_Rule last_rule = (used_ids - finished_cur == 1)
1249
+ ? rules.back()
1250
+ : BPE_Rule({0, 0, 0});
1251
+
1252
+ auto merge_event = merge_order.top(check_cnt, last_rule);
1253
+ if ((used_ids - finished_cur == 1) &&
1254
+ (merge_event.left_token == rules.back().y ||
1255
+ merge_event.right_token == rules.back().x ||
1256
+ (!rules.empty() && rules.back().x == rules.back().y))) {
1257
+ inter_fail += merge_event.left_token == rules.back().y ||
1258
+ merge_event.right_token == rules.back().x;
1259
+ equal_fail += !rules.empty() && rules.back().x == rules.back().y &&
1260
+ used_ids - finished_cur == 1;
1261
+
1262
+ last_failed_try = finished_cur;
1263
+ x = y = z = 0;
1264
+ break;
1265
+ }
1266
+
1267
+ merge_order.pop();
1268
+ real_cnt = check_cnt(
1269
+ int2comb(merge_event.left_token, merge_event.right_token));
1270
+ assert(real_cnt <= merge_event.count);
1271
+
1272
+ if (real_cnt != merge_event.count) {
1273
+ if (real_cnt > 0) {
1274
+ merge_event.count = real_cnt;
1275
+ merge_order.push(merge_event);
1276
+ }
1277
+ continue;
1278
+ }
1279
+
1280
+ if (real_cnt == 0) {
1281
+ continue;
1282
+ }
1283
+
1284
+ x = merge_event.left_token;
1285
+ y = merge_event.right_token;
1286
+ z = used_ids;
1287
+ break;
1288
+ }
1289
+ if (last_failed_try != finished_cur && x != UINT32_MAX) {
1290
+ task_order[used_ids % 2] = {x, y, z};
1291
+ recipe[z] = get_recipe(x, y);
1292
+ recipe_s[z] = recipe_s[x] + recipe_s[y];
1293
+
1294
+ if (used_ids % 1000 == 0) {
1295
+ int used_symbols = 0;
1296
+ std::cerr << "id: " << z << "=" << x << "+" << y;
1297
+ used_symbols += std::to_string(z).size();
1298
+ used_symbols += 1;
1299
+ used_symbols += std::to_string(x).size();
1300
+ used_symbols += 1;
1301
+ used_symbols += std::to_string(y).size();
1302
+ for (int j = used_symbols; j < 22 + 4; j++) {
1303
+ std::cerr << " ";
1304
+ }
1305
+ used_symbols = 0;
1306
+ std::cerr << "freq: " << real_cnt;
1307
+ used_symbols += 5;
1308
+ used_symbols += std::to_string(real_cnt).size();
1309
+
1310
+ for (int j = used_symbols; j < 15; j++) {
1311
+ std::cerr << " ";
1312
+ }
1313
+ std::cerr << " subword: " << recipe_s[z] << "="
1314
+ << recipe_s[x] + "+" + recipe_s[y] << std::endl;
1315
+ }
1316
+ used_ids++;
1317
+ rules.emplace_back(x, y, z);
1318
+ }
1319
+
1320
+ for (uint64_t i = 0; i < n_threads; i++) {
1321
+ thread_use_hs[i] = true;
1322
+ }
1323
+ }
1324
+ for (auto &cond_value : cv) {
1325
+ cond_value.notify_one();
1326
+ }
1327
+ if (x == UINT32_MAX) {
1328
+ break;
1329
+ }
1330
+ }
1331
+
1332
+ // collect results
1333
+
1334
+ bool full_epoch = true;
1335
+ for (uint64_t i = 0; i < n_threads; i++) {
1336
+ if (!local_check_list[i]) {
1337
+ if (results_ready[i] >= finished_cur) {
1338
+ progress = true;
1339
+ local_check_list[i] = 1;
1340
+
1341
+ for (auto token_cnt : left_tokens_submit[finished_cur % 2][i]) {
1342
+ global_ht_update_left[token_cnt.first] += token_cnt.second;
1343
+ }
1344
+
1345
+ for (auto token_cnt : right_tokens_submit[finished_cur % 2][i]) {
1346
+ global_ht_update_right[token_cnt.first] += token_cnt.second;
1347
+ }
1348
+ } else {
1349
+ full_epoch = false;
1350
+ }
1351
+ }
1352
+ }
1353
+
1354
+ if (full_epoch) {
1355
+ for (auto left_token : global_ht_update_left) {
1356
+ merge_order.push({left_token.second, left_token.first,
1357
+ task_order[finished_cur % 2].z});
1358
+ }
1359
+ for (auto right_token : global_ht_update_right) {
1360
+ merge_order.push({right_token.second, task_order[finished_cur % 2].z,
1361
+ right_token.first});
1362
+ }
1363
+ local_check_list.assign(n_threads, 0);
1364
+ global_ht_update_left.clear();
1365
+ global_ht_update_right.clear();
1366
+ finished_cur++;
1367
+ }
1368
+ if (!progress) {
1369
+ std::unique_lock<std::mutex> ul(main_loop_mt);
1370
+ main_loop_cv.wait(ul, [&] {
1371
+ for (uint64_t i = 0; i < n_threads; i++) {
1372
+ if (!local_check_list[i] && results_ready[i] >= finished_cur)
1373
+ return true;
1374
+ }
1375
+ return false;
1376
+ });
1377
+ }
1378
+ }
1379
+ for (auto &t : threads) {
1380
+ t.join();
1381
+ }
1382
+
1383
+ rename_tokens(char2id, rules, bpe_config.special_tokens, n_tokens);
1384
+
1385
+ *bpe_state = {char2id, rules, bpe_config.special_tokens};
1386
+ bpe_state->dump(output_file);
1387
+ std::cerr << "model saved to: " << output_file << std::endl;
1388
+ return Status();
1389
+ }
1390
+
1391
// Build, for every word, a doubly linked list of its characters with runs of
// equal characters collapsed into one node (seg_len = run length), and fill
// pair2poscnt: packed (left,right) token pair -> all (word, node) positions
// where the pair occurs plus its total weighted count.
// NOTE(review): this 3-arg overload differs from the 4-arg one used by the
// training threads; presumably used by a separate (single-threaded) path.
void build_linked_list(
    const vector<WordCount> &word_cnt, vector<vector<NodeEncoder>> &list,
    flat_hash_map<uint64_t, PositionsCnt> &pair2poscnt) {
  list.resize(word_cnt.size());
  for (uint64_t i = 0; i < word_cnt.size(); i++) {
    // First pass: run-length-encode the word into linked-list nodes.
    for (uint32_t ch : word_cnt[i].word) {
      if (!list[i].empty() && list[i].back().val == ch) {
        // Same character repeated: extend the current run.
        list[i].back().seg_len++;
      } else {
        int list_size = list[i].size();
        // prev/next are simply the neighboring indices at build time.
        list[i].emplace_back(ch, list_size - 1, list_size + 1, 1);
      }
    }

    list[i].back().next = -1;  // terminate the list
    // Second pass: register every adjacent pair of distinct-run nodes, and
    // the intra-run pairs for runs longer than one character.
    for (uint64_t j = 0; j < list[i].size(); j++) {
      if (j + 1 < list[i].size()) {
        uint64_t comb = int2comb(list[i][j].val, list[i][j + 1].val);
        auto it = pair2poscnt.find(comb);
        if (it == pair2poscnt.end()) {
          // First occurrence of this pair anywhere.
          pair2poscnt[comb] = {{{i, j}}, word_cnt[i].cnt};
        } else {
          it->second.positions.emplace_back(i, j);
          it->second.cnt += word_cnt[i].cnt;
        }
      }
      assert(list[i][j].seg_len >= 1);

      if (list[i][j].seg_len > 1) {
        // A run "aaaa" contributes (a,a) pairs; pairsInSeg converts the run
        // length into the number of non-overlapping pairs it yields.
        uint64_t comb = int2comb(list[i][j].val, list[i][j].val);
        auto it = pair2poscnt.find(comb);
        uint64_t cc = word_cnt[i].cnt * pairsInSeg(list[i][j].seg_len);
        if (it == pair2poscnt.end()) {
          pair2poscnt[comb] = {{{i, j}}, cc};
        } else {
          it->second.positions.emplace_back(i, j);
          it->second.cnt += cc;
        }
      }
    }
  }
}
1433
+
1434
+ flat_hash_map<uint32_t, uint32_t> compute_alphabet(
1435
+ const vector<uint32_t> &data, flat_hash_set<uint32_t> &removed_chars,
1436
+ const BpeConfig &bpe_config) {
1437
+ flat_hash_map<uint32_t, uint64_t> char_cnt;
1438
+ for (auto ch : data) {
1439
+ if (!is_space(ch)) {
1440
+ char_cnt[ch]++;
1441
+ }
1442
+ }
1443
+ assert(!char_cnt.empty());
1444
+ return compute_alphabet_helper(char_cnt, data.size(), removed_chars,
1445
+ bpe_config);
1446
+ }
1447
+
1448
+ Status check_config(BpeConfig &bpe_config, int vocab_size) {
1449
+ if (bpe_config.character_coverage <= 0 || bpe_config.character_coverage > 1) {
1450
+ return Status(1, "coverage value must be in the range (0, 1]. Current value of coverage = " +
1451
+ std::to_string(bpe_config.character_coverage));
1452
+ }
1453
+ if (bpe_config.special_tokens.unk_id < 0 ||
1454
+ bpe_config.special_tokens.unk_id >= vocab_size) {
1455
+ return Status(1,
1456
+ "unk_id: must be in the range [0, vocab_size - 1]. Current value of vocab_size = "
1457
+ + std::to_string(vocab_size) + "; unk_id = " + std::to_string(bpe_config.special_tokens.unk_id));
1458
+ }
1459
+
1460
+ if (bpe_config.special_tokens.pad_id < -1 ||
1461
+ bpe_config.special_tokens.pad_id >= vocab_size) {
1462
+ return Status(1, "pad_id must be in the range [-1, vocab_size - 1]. Current value of vocab_size = " +
1463
+ std::to_string(vocab_size) + "; pad_id = " + std::to_string(bpe_config.special_tokens.pad_id));
1464
+ }
1465
+
1466
+ if (bpe_config.special_tokens.bos_id < -1 ||
1467
+ bpe_config.special_tokens.bos_id >= vocab_size) {
1468
+ return Status(1, "bos_id must be in the range [-1, vocab_size - 1]. Current value of vocab_size = " +
1469
+ std::to_string(vocab_size) + "; bos_id = " + std::to_string(bpe_config.special_tokens.bos_id));
1470
+ }
1471
+
1472
+ if (bpe_config.special_tokens.eos_id < -1 ||
1473
+ bpe_config.special_tokens.eos_id >= vocab_size) {
1474
+ return Status(1, "eos_id must be in the range [-1, vocab_size - 1]. Current value of vocab_size = " +
1475
+ std::to_string(vocab_size) + " eos_id = " + std::to_string(bpe_config.special_tokens.eos_id));
1476
+ }
1477
+
1478
+ flat_hash_set<int> ids;
1479
+ uint64_t cnt_add = 0;
1480
+ if (bpe_config.special_tokens.pad_id != -1) {
1481
+ ids.insert(bpe_config.special_tokens.pad_id);
1482
+ cnt_add++;
1483
+ }
1484
+ if (bpe_config.special_tokens.bos_id != -1) {
1485
+ ids.insert(bpe_config.special_tokens.bos_id);
1486
+ cnt_add++;
1487
+ }
1488
+ if (bpe_config.special_tokens.eos_id != -1) {
1489
+ ids.insert(bpe_config.special_tokens.eos_id);
1490
+ cnt_add++;
1491
+ }
1492
+ ids.insert(bpe_config.special_tokens.unk_id);
1493
+ cnt_add++;
1494
+ if (ids.size() != cnt_add) {
1495
+ return Status(1, "All ids of special tokens must be different.");
1496
+ }
1497
+
1498
+ if (bpe_config.n_threads == -1) {
1499
+ bpe_config.n_threads = std::thread::hardware_concurrency();
1500
+ }
1501
+ bpe_config.n_threads = std::min(8, std::max(1, bpe_config.n_threads));
1502
+ return Status();
1503
+ }
1504
+
1505
+ void print_config(const string &input_path, const string &model_path,
1506
+ int vocab_size, BpeConfig bpe_config) {
1507
+ std::cerr << "Training parameters" << std::endl;
1508
+ std::cerr << " input: " << input_path << std::endl;
1509
+ std::cerr << " model: " << model_path << std::endl;
1510
+ std::cerr << " vocab_size: " << vocab_size << std::endl;
1511
+ std::cerr << " n_threads: " << bpe_config.n_threads << std::endl;
1512
+ std::cerr << " character_coverage: " << bpe_config.character_coverage
1513
+ << std::endl;
1514
+ std::cerr << " pad: " << bpe_config.special_tokens.pad_id << std::endl;
1515
+ std::cerr << " unk: " << bpe_config.special_tokens.unk_id << std::endl;
1516
+ std::cerr << " bos: " << bpe_config.special_tokens.bos_id << std::endl;
1517
+ std::cerr << " eos: " << bpe_config.special_tokens.eos_id << std::endl;
1518
+ std::cerr << std::endl;
1519
+ }
1520
+
1521
+ Status train_bpe(const string &input_path, const string &model_path,
1522
+ int vocab_size, BpeConfig bpe_config) {
1523
+ Status status = check_config(bpe_config, vocab_size);
1524
+ if (!status.ok()) {
1525
+ return status;
1526
+ }
1527
+ print_config(input_path, model_path, vocab_size, bpe_config);
1528
+ std::cerr << "reading file..." << std::endl;
1529
+ string data;
1530
+ status = fast_read_file_utf8(input_path, &data);
1531
+ if (!status.ok()) {
1532
+ return status;
1533
+ }
1534
+ std::cerr << "learning bpe..." << std::endl;
1535
+ BPEState bpe_state;
1536
+ status = learn_bpe_from_string(data, vocab_size, model_path, bpe_config, &bpe_state);
1537
+ if (!status.ok()) {
1538
+ return status;
1539
+ }
1540
+ return Status();
1541
+ }
1542
+
1543
+
1544
// Abstract priority-queue interface used by the encoder so the merge loop
// can swap between a plain max-heap (STLQueue) and a randomized BPE-dropout
// queue (DropoutQueue) at runtime.
template<typename T>
class BasePriorityQueue {
 public:
  // Insert an element.
  virtual void push(T x) = 0;
  // Pop the next element into `x`. Returns false when the queue is empty.
  virtual bool pop(T& x) = 0;
  // Defaulted virtual destructor: implementations are owned and destroyed
  // through BasePriorityQueue pointers.
  virtual ~BasePriorityQueue() = default;
};
1551
+
1552
+ template<typename T>
1553
+ class STLQueue : public BasePriorityQueue<T> {
1554
+ std::priority_queue<T> q;
1555
+ void push(T x) override {
1556
+ q.push(x);
1557
+ }
1558
+ bool pop(T& x) override {
1559
+ if (q.empty()) {
1560
+ return false;
1561
+ }
1562
+ x = q.top();
1563
+ q.pop();
1564
+ return true;
1565
+ }
1566
+ };
1567
+
1568
// Global PRNG shared by every DropoutQueue (BPE-dropout sampling).
// NOTE(review): default-seeded and unsynchronized — encoding threads appear
// to draw from it concurrently; confirm whether that is intended.
std::mt19937 rnd;
1569
+
1570
+ template<typename T>
1571
+ class DropoutQueue : public BasePriorityQueue<T> {
1572
+ double skip_prob;
1573
+ std::uniform_real_distribution<> dist;
1574
+ std::priority_queue<T> q;
1575
+ vector<T> skipped_elements;
1576
+ public:
1577
+ explicit DropoutQueue(double _skip_prob):skip_prob(_skip_prob), dist(std::uniform_real_distribution<>(0, 1)) {}
1578
+ void push(T x) override {
1579
+ q.push(x);
1580
+ }
1581
+ bool pop(T& x) override {
1582
+ assert(skipped_elements.empty());
1583
+ while (true) {
1584
+ if (q.empty()) {
1585
+ for (auto y: skipped_elements) {
1586
+ q.push(y);
1587
+ }
1588
+ skipped_elements.clear();
1589
+ return false;
1590
+ }
1591
+ T temp = q.top();
1592
+ q.pop();
1593
+ if (dist(rnd) < skip_prob) {
1594
+ skipped_elements.push_back(temp);
1595
+ }
1596
+ else {
1597
+ for (auto y: skipped_elements) {
1598
+ q.push(y);
1599
+ }
1600
+ skipped_elements.clear();
1601
+ x = temp;
1602
+ return true;
1603
+ }
1604
+ }
1605
+ }
1606
+ };
1607
+
1608
// Encode one UTF-8 sentence into token ids (output_type == ID) or subword
// strings (output_type == SUBWORD). Each whitespace-delimited word is given a
// leading SPACE_TOKEN, turned into a doubly linked list of character tokens,
// and greedily merged following the learned rules (earliest-learned rule
// first). Runs of characters absent from the alphabet become a single <UNK>
// id / raw-text piece. Honors bos/eos/reverse/dropout_prob from
// encoding_config.
DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
                                          const EncodingConfig &encoding_config,
                                          OutputType output_type) const {
  // Linked-list node over the word's tokens; merged-away nodes are marked
  // dead by setting token_id to 0 and prev/next to -1.
  struct NodeDecoder {
    uint32_t token_id;
    int prev, next;

    NodeDecoder(uint32_t _val, uint64_t cur_pos)
        : token_id(_val),
          prev(static_cast<int>(cur_pos) - 1),
          next(static_cast<int>(cur_pos) + 1) {}

    NodeDecoder(uint32_t _val, int _prev, int _next)
        : token_id(_val), prev(_prev), next(_next) {}
  };

  // Candidate merge at position `pos`. Lower `priority` (= rule id, i.e.
  // earlier-learned rule) wins; ties go to the leftmost position. operator<
  // is inverted because std::priority_queue is a max-heap.
  struct MergeEvent2 {
    int priority;
    int pos;

    bool operator<(const MergeEvent2 &other) const {
      return priority > other.priority ||
             (priority == other.priority && pos > other.pos);
    }
  };

  vector<int> output_ids;
  vector<string> output_pieces;

  // Optional <BOS> prefix.
  if (encoding_config.bos) {
    if (output_type == ID) {
      output_ids.push_back(bpe_state.special_tokens.bos_id);
    } else {
      output_pieces.push_back(BOS_TOKEN);
    }
  }

  vector<NodeDecoder> list;
  flat_hash_map<uint32_t, string> unrecognized_tokens;

  auto text = decode_utf8(sentence_utf8.data(),
                          sentence_utf8.data() + sentence_utf8.size());

  assert(bpe_state.char2id.count(SPACE_TOKEN));

  // Strip trailing whitespace so the word loop never sees an empty tail.
  for (; !text.empty() && is_space(text.back()); text.pop_back()) {
  }

  const int new_tokens_start = static_cast<int>(
      1e9); // just some number that bigger than any subword id
  // Process one whitespace-delimited word per iteration.
  for (auto it_text = text.begin(); it_text != text.end();) {
    list.clear();
    unrecognized_tokens.clear();

    auto begin_of_word = std::find_if_not(it_text, text.end(), is_space);
    auto end_of_word = std::find_if(begin_of_word, text.end(), is_space);
    it_text = end_of_word;

    uint32_t new_token_cur = new_tokens_start;
    // Every word starts with the space marker token.
    list.emplace_back(bpe_state.char2id.at(SPACE_TOKEN), 0);

    for (auto it_char_in_word = begin_of_word; it_char_in_word < end_of_word;) {
      if (bpe_state.char2id.count(*it_char_in_word) == 0) {
        // Collapse a maximal run of out-of-alphabet characters into one
        // synthetic id (>= new_tokens_start) and keep its raw UTF-8 text.
        auto it_unrecognized_word = std::find_if(
            it_char_in_word, end_of_word,
            [&](uint32_t ch) { return bpe_state.char2id.count(ch); });

        unrecognized_tokens[new_token_cur] =
            encode_utf8({it_char_in_word, it_unrecognized_word});
        it_char_in_word = it_unrecognized_word;

        list.emplace_back(new_token_cur, list.size());
        new_token_cur++;
      } else {
        list.emplace_back(bpe_state.char2id.at(*it_char_in_word), list.size());
        ++it_char_in_word;
      }
    }
    list.back().next = -1;


    // Pack the (token, next-token) pair at first_pos into one 64-bit key.
    auto pair_code = [&](uint64_t first_pos) {
      auto second_pos = list[first_pos].next;
      return int2comb(list[first_pos].token_id, list[second_pos].token_id);
    };

    // dropout_prob == 0 -> deterministic heap; otherwise BPE-dropout queue.
    std::unique_ptr<BasePriorityQueue<MergeEvent2>> queue(nullptr);
    if (encoding_config.dropout_prob == 0) {
      queue.reset(new STLQueue<MergeEvent2>());
    }
    else {
      queue.reset(new DropoutQueue<MergeEvent2>(encoding_config.dropout_prob));
    }

    auto push_in_queue_if_rule_exist = [&](uint64_t pos) {
      auto it = rule2id.find(pair_code(pos));
      if (it != rule2id.end()) {
        queue->push({it->second, static_cast<int>(pos)});
      }
    };

    // Seed the queue with every adjacent pair that has a merge rule.
    for (uint64_t j = 0; j + 1 < list.size(); j++) {
      push_in_queue_if_rule_exist(j);
    }

    // Repeatedly apply the best (earliest-learned) applicable merge.
    while (true) {
      MergeEvent2 event;
      if (!queue->pop(event)) {
        break;
      }
      int rule_id = event.priority;
      int pos_1 = event.pos;
      int pos_2 = list[pos_1].next;
      assert(pos_1 != pos_2);
      // Stale entry: the pair at this position changed after it was pushed.
      if (list[pos_1].token_id != bpe_state.rules[rule_id].x || pos_2 == -1 ||
          list[pos_2].token_id != bpe_state.rules[rule_id].y) {
        continue;
      }

      int pos_0 = list[pos_1].prev;
      int pos_3 = list[pos_2].next;

      // Merge: pos_1 becomes the merged token z; pos_2 is marked dead.
      list[pos_2] = {0, -1, -1};
      list[pos_1] = {bpe_state.rules[rule_id].z, pos_0, pos_3};
      if (pos_3 != -1) {
        list[pos_3].prev = pos_1;
      }

      // The two pairs adjacent to the merged node are new candidates.
      if (pos_0 != -1) {
        push_in_queue_if_rule_exist(pos_0);
      }
      if (pos_3 != -1) {
        push_in_queue_if_rule_exist(pos_1);
      }
    }

    // Walk the surviving nodes in order and emit output tokens.
    auto it_alive_token = std::find_if(
        list.begin(), list.end(),
        [](const NodeDecoder &node) { return node.token_id != 0; });

    assert(it_alive_token != list.end());
    int alive_token = std::distance(list.begin(), it_alive_token);
    for (; alive_token != -1; alive_token = list[alive_token].next) {
      int token_id = list[alive_token].token_id;
      if (token_id >= new_tokens_start) {
        // Synthetic id for an unknown-character run.
        if (output_type == ID) {
          output_ids.push_back(bpe_state.special_tokens.unk_id);
        } else {
          assert(unrecognized_tokens.count(token_id));
          output_pieces.push_back(unrecognized_tokens[token_id]);
        }
      } else {
        if (output_type == ID) {
          output_ids.push_back(token_id);
        } else {
          assert(recipe.count(token_id));
          output_pieces.push_back(token2word(recipe.at(token_id), id2char));
        }
      }
    }
  }
  // Optional <EOS> suffix.
  if (encoding_config.eos) {
    if (output_type == ID) {
      output_ids.push_back(bpe_state.special_tokens.eos_id);
    } else {
      output_pieces.push_back(EOS_TOKEN);
    }
  }

  if (encoding_config.reverse) {
    if (output_type == ID) {
      std::reverse(output_ids.begin(), output_ids.end());
    } else {
      std::reverse(output_pieces.begin(), output_pieces.end());
    }
  }
  return {output_ids, output_pieces};
}
1786
+
1787
+ BaseEncoder::BaseEncoder(BPEState _bpe_state, int _n_threads)
1788
+ : bpe_state(std::move(_bpe_state)), n_threads(_n_threads) {
1789
+ fill_from_state();
1790
+ assert(n_threads >= 1 || n_threads == -1);
1791
+ if (n_threads == -1) {
1792
+ n_threads = std::max(1, int(std::thread::hardware_concurrency()));
1793
+ }
1794
+ }
1795
+
1796
+ BaseEncoder::BaseEncoder(const string &model_path, int _n_threads, Status *ret_status)
1797
+ : n_threads(_n_threads) {
1798
+ Status status = bpe_state.load(model_path);
1799
+ if (!status.ok()) {
1800
+ *ret_status = status;
1801
+ return;
1802
+ }
1803
+ fill_from_state();
1804
+ assert(n_threads >= 1 || n_threads == -1);
1805
+ if (n_threads == -1) {
1806
+ n_threads = std::max(1, int(std::thread::hardware_concurrency()));
1807
+ }
1808
+ *ret_status = Status();
1809
+ }
1810
+
1811
// Return a new vector holding the elements of `a` followed by those of `b`.
// A single reserve keeps it to one allocation.
template<typename T>
std::vector<T> concat_vectors(const std::vector<T> &a, const std::vector<T> &b) {
  std::vector<T> merged;
  merged.reserve(a.size() + b.size());
  std::copy(a.begin(), a.end(), std::back_inserter(merged));
  std::copy(b.begin(), b.end(), std::back_inserter(merged));
  return merged;
}
1819
+
1820
+ void BaseEncoder::fill_from_state() {
1821
+ for (auto x : bpe_state.char2id) {
1822
+ id2char[x.second] = x.first;
1823
+ }
1824
+
1825
+ for (int i = 0; i < (int) bpe_state.rules.size(); i++) {
1826
+ rule2id[int2comb(bpe_state.rules[i].x, bpe_state.rules[i].y)] = i;
1827
+ }
1828
+
1829
+ for (auto x : id2char) {
1830
+ recipe[x.first] = {x.first};
1831
+ }
1832
+
1833
+ for (auto rule : bpe_state.rules) {
1834
+ recipe[rule.z] = concat_vectors(recipe[rule.x], recipe[rule.y]);
1835
+ }
1836
+
1837
+ for (const auto &id_to_recipe : recipe) {
1838
+ reversed_recipe[token2word(id_to_recipe.second, id2char)] =
1839
+ id_to_recipe.first;
1840
+ }
1841
+ reversed_recipe[BOS_TOKEN] = bpe_state.special_tokens.bos_id;
1842
+ reversed_recipe[EOS_TOKEN] = bpe_state.special_tokens.eos_id;
1843
+ }
1844
+
1845
+ int BaseEncoder::vocab_size() const {
1846
+ return bpe_state.rules.size() + bpe_state.char2id.size() +
1847
+ bpe_state.special_tokens.n_special_tokens();
1848
+ }
1849
+
1850
// Encode a batch of sentences, distributing work over n_threads.
// Fails fast if bos/eos output is requested but the model was trained
// without the corresponding special token. Results are written into
// *decoder_results at the same index as the input sentence.
Status BaseEncoder::encode_parallel(
    const std::vector<std::string> &sentences,
    const EncodingConfig &encoding_config, OutputType output_type,
    std::vector<DecodeResult> *decoder_results
) const {
  if (encoding_config.bos && bpe_state.special_tokens.bos_id == -1) {
    return Status(1, "Can't add <BOS> token. Model was trained without it.");
  }
  if (encoding_config.eos && bpe_state.special_tokens.eos_id == -1) {
    return Status(1, "Can't add <EOS> token. Model was trained without it.");
  }

  decoder_results->assign(sentences.size(), DecodeResult());
  if (sentences.size() <= static_cast<uint64_t>(n_threads) * 3 ||
      n_threads == 1) { // Not too many sentences. It's better to solve it
                        // without threads.
    for (uint64_t i = 0; i < sentences.size(); i++) {
      decoder_results->at(i) = encode_sentence(sentences[i], encoding_config, output_type);
    }
    return Status();
  }
  vector<std::thread> threads;
  for (int i = 0; i < n_threads; i++) {
    threads.emplace_back(
        [&](uint64_t this_thread) {
          // Each thread handles a contiguous, disjoint chunk of sentences,
          // so the writes into *decoder_results never overlap.
          uint64_t tasks_for_thread =
              (sentences.size() + n_threads - 1) / n_threads;
          uint64_t first_task = tasks_for_thread * this_thread;
          uint64_t last_task =
              std::min(tasks_for_thread * (this_thread + 1), static_cast<uint64_t>(sentences.size()));
          for (uint64_t j = first_task; j < last_task; j++) {
            decoder_results->at(j) =
                encode_sentence(sentences[j], encoding_config, output_type);
          }
        },
        i);
  }
  for (auto &thread : threads) {
    thread.join();
  }
  return Status();
}
1892
+
1893
+ Status BaseEncoder::encode_as_ids(const vector<string> &sentences, vector<vector<int>> *ids,
1894
+ bool bos, bool eos,
1895
+ bool reverse, double dropout_prob) const {
1896
+ EncodingConfig encoding_config = {bos, eos, reverse, dropout_prob};
1897
+
1898
+ std::vector<DecodeResult> decode_results;
1899
+ Status status = encode_parallel(sentences, encoding_config, ID, &decode_results);
1900
+ if (!status.ok()) {
1901
+ return status;
1902
+ }
1903
+ ids->assign(decode_results.size(), vector<int>());
1904
+ for (uint64_t i = 0; i < decode_results.size(); i++) {
1905
+ ids->at(i) = move(decode_results[i].ids);
1906
+ }
1907
+ return Status();
1908
+ }
1909
+
1910
+ Status BaseEncoder::encode_as_subwords(
1911
+ const vector<string> &sentences,
1912
+ vector<vector<string>> *subwords,
1913
+ bool bos, bool eos, bool reverse, double dropout_prob) const {
1914
+ time_check("");
1915
+ EncodingConfig encoding_config = {bos, eos, reverse, dropout_prob};
1916
+ std::vector<DecodeResult> decode_results;
1917
+ Status status = encode_parallel(sentences, encoding_config, SUBWORD, &decode_results);
1918
+ if (!status.ok()) {
1919
+ return status;
1920
+ }
1921
+ subwords->assign(decode_results.size(), vector<string>());
1922
+ for (uint64_t i = 0; i < decode_results.size(); i++) {
1923
+ subwords->at(i) = move(decode_results[i].pieces);
1924
+ }
1925
+ return Status();
1926
+ }
1927
+
1928
// Render a token id as its subword string into *subword.
// Special-token ids map to their literal markers (<UNK>, <PAD>, <BOS>,
// <EOS>). When replace_space is true, a token whose recipe starts with the
// space marker is rendered with a real leading " " instead of the marker.
// Returns a non-ok Status if id is outside [0, vocab_size - 1].
Status BaseEncoder::id_to_subword(int id, string *subword, bool replace_space) const {
  if (id < 0 || vocab_size() <= id) {
    return Status(1, "id must be in the range [0, vocab_size - 1]. Current value: vocab_size = " +
                     std::to_string(vocab_size()) +
                     "; id=" + std::to_string(id) + ";");
  }
  if (bpe_state.special_tokens.unk_id == id) {
    *subword = UNK_TOKEN;
    return Status();
  }
  if (bpe_state.special_tokens.pad_id == id) {
    *subword = PAD_TOKEN;
    return Status();
  }
  if (bpe_state.special_tokens.bos_id == id) {
    *subword = BOS_TOKEN;
    return Status();
  }
  if (bpe_state.special_tokens.eos_id == id) {
    *subword = EOS_TOKEN;
    return Status();
  }

  assert(recipe.count(id));
  if (replace_space) {
    auto symbols = recipe.at(id);
    // Recipe begins with the space marker: emit a literal space followed by
    // the rest of the token's characters.
    if (id2char.at(symbols[0]) == SPACE_TOKEN) {
      *subword = " " + token2word({symbols.begin() + 1, symbols.end()}, id2char);
      return Status();
    }
  }
  *subword = token2word(recipe.at(id), id2char);
  return Status();
}
1962
+
1963
+ int BaseEncoder::subword_to_id(const string &token) const {
1964
+ if (UNK_TOKEN == token) {
1965
+ return bpe_state.special_tokens.unk_id;
1966
+ }
1967
+ if (PAD_TOKEN == token) {
1968
+ return bpe_state.special_tokens.pad_id;
1969
+ }
1970
+ if (BOS_TOKEN == token) {
1971
+ return bpe_state.special_tokens.bos_id;
1972
+ }
1973
+ if (EOS_TOKEN == token) {
1974
+ return bpe_state.special_tokens.eos_id;
1975
+ }
1976
+ if (reversed_recipe.count(token)) {
1977
+ return reversed_recipe.at(token);
1978
+ }
1979
+ return bpe_state.special_tokens.unk_id;
1980
+ }
1981
+
1982
+ Status BaseEncoder::decode(const vector<vector<int>> &ids,
1983
+ vector<string> *sentences,
1984
+ const unordered_set<int> *ignore_ids) const {
1985
+ vector<string> ret;
1986
+ for (const auto &sentence : ids) {
1987
+ string decode_output;
1988
+ Status status = decode(sentence, &decode_output, ignore_ids);
1989
+ if (!status.ok()) {
1990
+ return status;
1991
+ }
1992
+ sentences->push_back(move(decode_output));
1993
+ }
1994
+ return Status();
1995
+ }
1996
+
1997
+ Status BaseEncoder::decode(const vector<int> &ids, string *sentence, const unordered_set<int> *ignore_ids) const {
1998
+ bool first_iter = true;
1999
+ for (auto id : ids) {
2000
+ string subword;
2001
+
2002
+ if (!ignore_ids || ignore_ids->count(id) == 0) {
2003
+ Status status = id_to_subword(id, &subword, true);
2004
+ if (!status.ok()) {
2005
+ return status;
2006
+ }
2007
+ *sentence += subword;
2008
+ if (first_iter && sentence->at(0) == ' ') {
2009
+ *sentence = sentence->substr(1);
2010
+ }
2011
+ first_iter = false;
2012
+ }
2013
+ }
2014
+ return Status();
2015
+ }
2016
+
2017
+ Status BaseEncoder::decode(const vector<string> &data,
2018
+ vector<string> *sentences,
2019
+ const std::unordered_set<int> *ignore_ids) const {
2020
+ for (const auto &s : data) {
2021
+ std::stringstream stream;
2022
+ stream << s;
2023
+ vector<int> ids;
2024
+ int x;
2025
+ while (stream >> x) {
2026
+ ids.push_back(x);
2027
+ }
2028
+ string sentence;
2029
+ Status status = decode(ids, &sentence, ignore_ids);
2030
+ if (!status.ok()) {
2031
+ return status;
2032
+ }
2033
+ sentences->push_back(sentence);
2034
+ }
2035
+ return Status();
2036
+ }
2037
+
2038
+ vector<string> BaseEncoder::vocabulary() const {
2039
+ int n = vocab_size();
2040
+ vector<string> vocab(n);
2041
+ for (int i = 0; i < n; i++) {
2042
+ string subword;
2043
+ Status status = id_to_subword(i, &subword);
2044
+ assert(status.ok());
2045
+ vocab[i] = subword;
2046
+ }
2047
+ return vocab;
2048
+ }
2049
+
2050
// Prints the vocabulary to stdout, one "<id>\t<token>" line per id.
// With verbose=true, a token that was produced by a merge rule additionally
// shows its decomposition "=x+y", space-padded into a column, followed by
// the pair of component ids "x_id+y_id".
void BaseEncoder::vocab_cli(bool verbose) const {
  // n_tokens = 1 + the largest id seen either in the recipe table or among
  // the special tokens, so the loop below covers every assigned id.
  uint32_t n_tokens = 0;
  for (const auto &entry : recipe) {
    n_tokens = std::max(entry.first, n_tokens);
  }
  n_tokens = std::max(n_tokens, bpe_state.special_tokens.max_id());
  n_tokens++;

  // Inverse of the merge rules, z -> (x, y); only built for verbose output.
  flat_hash_map<uint32_t, std::pair<uint32_t, uint32_t>> reversed_rules;
  if (verbose) {
    for (auto rule : bpe_state.rules) {
      reversed_rules[rule.z] = {rule.x, rule.y};
    }
  }

  for (uint64_t i = 0; i < n_tokens; i++) {
    string token_z;
    Status status = id_to_subword(i, &token_z);
    assert(status.ok());
    std::cout << i << "\t" << token_z;
    if (verbose) {
      if (reversed_rules.count(i)) {
        // Width used so far on this line, measured in decoded unicode
        // symbols (not bytes) so multibyte tokens align correctly.
        int used_symbols = 0;
        auto comb = reversed_rules[i];
        string token_x;
        string token_y;
        status = id_to_subword(comb.first, &token_x);
        assert(status.ok());
        status = id_to_subword(comb.second, &token_y);
        assert(status.ok());

        used_symbols += decode_utf8(token_z).size() + 1;
        used_symbols +=
            decode_utf8(token_x).size() + 1 + decode_utf8(token_y).size();

        std::cout << "=" << token_x << "+" << token_y;
        // Pad to column 50 (minimum 2 spaces) before the id pair.
        for (int t = 0; t < std::max(2, 50 - used_symbols); t++) {
          std::cout << " ";
        }
        std::cout << comb.first << "+" << comb.second;
      }
    }
    std::cout << std::endl;
  }
}
2095
+
2096
// Command-line encoding driver: reads sentences from stdin and writes
// encoded output to stdout.
//   output_type_str: "id" for numeric ids, "subword" for token strings.
//   stream: true  -> encode line-by-line as input arrives;
//           false -> read and encode batches of up to ~10MB, printing a
//                    progress counter on stderr.
//   bos/eos/reverse/dropout_prob are forwarded unchanged to the encoder.
// Returns the first encoding error, or OK when stdin is exhausted.
Status BaseEncoder::encode_cli(const string &output_type_str, bool stream,
                               bool bos, bool eos, bool reverse, double dropout_prob) const {
  std::ios_base::sync_with_stdio(false);
  OutputType output_type;
  if (output_type_str == "id") {
    output_type = ID;
  } else {
    // Callers must pass one of the two known output types.
    assert(output_type_str == "subword");
    output_type = SUBWORD;
  }
  if (stream) {
    if (output_type == SUBWORD) {
      string sentence;
      while (getline(std::cin, sentence)) {
        vector<vector<string>> subwords;
        Status status = encode_as_subwords({sentence}, &subwords, bos, eos, reverse, dropout_prob);
        if (!status.ok()) {
          return status;
        }
        // NOTE(review): second argument presumably requests per-line
        // flushing in stream mode — confirm in write_to_stdout.
        write_to_stdout(subwords, true);
      }
    } else {
      assert(output_type == ID);
      string sentence;
      while (getline(std::cin, sentence)) {
        vector<vector<int>> ids;
        Status status = encode_as_ids({sentence}, &ids, bos, eos, reverse, dropout_prob);
        if (!status.ok()) {
          return status;
        }
        write_to_stdout(ids, true);
      }
    }
  } else {
    time_check("");
    // Read stdin in chunks of at most batch_limit bytes so memory use stays
    // bounded no matter how large the input is.
    const uint64_t batch_limit = 10 * 1024 * 1024;
    uint64_t total_progress = 0;
    uint64_t processed;
    std::cerr << "n_threads: " << n_threads << std::endl;
    // Length of the previous progress message, erased with backspaces
    // before the updated message is printed in place.
    int chars_remove = 0;
    do {
      processed = 0;
      auto sentences = read_lines_from_stdin(batch_limit, &processed);
      if (output_type == SUBWORD) {
        vector<vector<string>> subwords;
        Status status = encode_as_subwords(sentences, &subwords, bos, eos, reverse, dropout_prob);
        if (!status.ok()) {
          return status;
        }
        write_to_stdout(subwords, false);
      } else {
        assert(output_type == ID);
        vector<vector<int>> ids;
        Status status = encode_as_ids(sentences, &ids, bos, eos, reverse, dropout_prob);
        if (!status.ok()) {
          return status;
        }
        write_to_stdout(ids, false);
      }
      total_progress += processed;

      // Rewrite the in-place progress indicator on stderr.
      for (int i = 0; i < chars_remove; i++) {
        std::cerr << '\b';
      }
      chars_remove = 0;
      string message = "bytes processed: ";
      chars_remove += message.size();
      chars_remove += std::to_string(total_progress).length();
      std::cerr << message << total_progress;
      // A batch shorter than the limit means stdin is exhausted.
    } while (processed >= batch_limit);
    std::cerr << std::endl;
  }
  return Status();
}
2170
+
2171
+ Status BaseEncoder::decode_cli(const std::unordered_set<int> *ignore_ids) const {
2172
+ std::ios_base::sync_with_stdio(false);
2173
+ string sentence;
2174
+ while (getline(std::cin, sentence)) {
2175
+ vector<string> output;
2176
+ Status status = decode({sentence}, &output, ignore_ids);
2177
+ if (!status.ok()) {
2178
+ return status;
2179
+ }
2180
+ std::cout << output[0] << "\n";
2181
+ }
2182
+ return Status();
2183
+ }
2184
+
2185
+ } // namespace vkcom