llama_cpp 0.14.7 → 0.15.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -5,11 +5,15 @@
5
5
  #include <cstddef>
6
6
  #include <cstdint>
7
7
  #include <map>
8
+ #include <regex>
8
9
  #include <stdexcept>
9
10
  #include <string>
10
11
  #include <unordered_map>
12
+ #include <unordered_set>
11
13
  #include <utility>
12
14
  #include <vector>
15
+ #include <locale>
16
+ #include <codecvt>
13
17
 
14
18
  static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
15
19
  std::string result;
@@ -53,23 +57,22 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
53
57
  offset += 4;
54
58
  return result;
55
59
  }
56
- throw std::invalid_argument("invalid string");
60
+ throw std::invalid_argument("failed to convert utf8 to codepoint");
57
61
  }
58
62
 
59
- static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
60
- std::vector<uint16_t> result;
61
- if (/* 0x0000 <= cp && */ cp <= 0xffff) {
62
- result.emplace_back(cp);
63
- }
64
- else if (0x10000 <= cp && cp <= 0x10ffff) {
65
- result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
66
- result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
67
- }
68
- else {
69
- throw std::invalid_argument("invalid cpt");
70
- }
71
- return result;
72
- }
63
+ //static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
64
+ // std::vector<uint16_t> result;
65
+ // if (/* 0x0000 <= cp && */ cp <= 0xffff) {
66
+ // result.emplace_back(cp);
67
+ // return result;
68
+ // }
69
+ // if (0x10000 <= cp && cp <= 0x10ffff) {
70
+ // result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
71
+ // result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
72
+ // return result;
73
+ // }
74
+ // throw std::invalid_argument("failed to convert codepoint to utf16");
75
+ //}
73
76
 
74
77
  //static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) {
75
78
  // std::vector<uint16_t> result;
@@ -80,56 +83,56 @@ static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
80
83
  // return result;
81
84
  //}
82
85
 
83
- static uint32_t cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
84
- assert(offset < utf16.size());
85
- if (((utf16[0] >> 10) << 10) != 0xd800) {
86
- auto result = utf16[offset + 0];
87
- offset += 1;
88
- return result;
89
- }
90
-
91
- if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
92
- throw std::invalid_argument("invalid character");
93
- }
94
-
95
- auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
96
- offset += 2;
97
- return result;
98
- }
86
+ //static uint32_t unicode_cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
87
+ // assert(offset < utf16.size());
88
+ // if (((utf16[0] >> 10) << 10) != 0xd800) {
89
+ // auto result = utf16[offset + 0];
90
+ // offset += 1;
91
+ // return result;
92
+ // }
93
+ //
94
+ // if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
95
+ // throw std::invalid_argument("invalid character");
96
+ // }
97
+ //
98
+ // auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
99
+ // offset += 2;
100
+ // return result;
101
+ //}
99
102
 
100
103
  //static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) {
101
104
  // std::vector<uint32_t> result;
102
105
  // size_t offset = 0;
103
106
  // while (offset < utf16.size()) {
104
- // result.push_back(cpt_from_utf16(utf16, offset));
107
+ // result.push_back(unicode_cpt_from_utf16(utf16, offset));
105
108
  // }
106
109
  // return result;
107
110
  //}
108
111
 
109
112
  static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
110
113
  std::unordered_map<uint32_t, int> cpt_types;
111
- for (auto p : unicode_ranges_digit) {
112
- for (auto i = p.first; i <= p.second; ++ i) {
113
- cpt_types[i] = CODEPOINT_TYPE_DIGIT;
114
+ for (auto p : unicode_ranges_number) {
115
+ for (auto i = p.first; i <= p.second; ++i) {
116
+ cpt_types[i] = CODEPOINT_TYPE_NUMBER;
114
117
  }
115
118
  }
116
119
  for (auto p : unicode_ranges_letter) {
117
- for (auto i = p.first; i <= p.second; ++ i) {
120
+ for (auto i = p.first; i <= p.second; ++i) {
118
121
  cpt_types[i] = CODEPOINT_TYPE_LETTER;
119
122
  }
120
123
  }
121
- for (auto p : unicode_ranges_whitespace) {
122
- for (auto i = p.first; i <= p.second; ++ i) {
123
- cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
124
+ for (auto p : unicode_ranges_separator) {
125
+ for (auto i = p.first; i <= p.second; ++i) {
126
+ cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
124
127
  }
125
128
  }
126
129
  for (auto p : unicode_ranges_accent_mark) {
127
- for (auto i = p.first; i <= p.second; ++ i) {
130
+ for (auto i = p.first; i <= p.second; ++i) {
128
131
  cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
129
132
  }
130
133
  }
131
134
  for (auto p : unicode_ranges_punctuation) {
132
- for (auto i = p.first; i <= p.second; ++ i) {
135
+ for (auto i = p.first; i <= p.second; ++i) {
133
136
  cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
134
137
  }
135
138
  }
@@ -139,7 +142,7 @@ static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
139
142
  }
140
143
  }
141
144
  for (auto p : unicode_ranges_control) {
142
- for (auto i = p.first; i <= p.second; ++ i) {
145
+ for (auto i = p.first; i <= p.second; ++i) {
143
146
  cpt_types[i] = CODEPOINT_TYPE_CONTROL;
144
147
  }
145
148
  }
@@ -194,34 +197,395 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
194
197
  return map;
195
198
  }
196
199
 
200
// Convert a UTF-8 encoded std::string into a std::wstring.
// The conversion is delegated to std::wstring_convert with the
// std::codecvt_utf8<wchar_t> facet; any exception it raises propagates.
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
    using utf8_facet = std::codecvt_utf8<wchar_t>;
    return std::wstring_convert<utf8_facet>{}.from_bytes(s);
}
204
+
205
+ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
206
+ std::vector<std::string> bpe_encoded_words;
207
+ for (const auto & word : bpe_words) {
208
+ std::string text_utf;
209
+ auto utf_word = unicode_cpts_from_utf8(word);
210
+ for (size_t i = 0; i < utf_word.size(); ++i) {
211
+ text_utf += unicode_cpt_to_utf8(utf_word[i]);
212
+ }
213
+
214
+ std::string encoded_token;
215
+ for (char & c : text_utf) {
216
+ encoded_token += unicode_byte_to_utf8(c);
217
+ }
218
+ bpe_encoded_words.emplace_back(encoded_token);
219
+ }
220
+ return bpe_encoded_words;
221
+ }
222
+
223
+ // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
224
+ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
225
+ std::vector<size_t> bpe_offsets; // store the offset of each word
226
+ bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
227
+
228
+ const auto cpts = unicode_cpts_from_utf8(text);
229
+
230
+ size_t start = 0;
231
+ for (auto offset : offsets) {
232
+ const size_t offset_ini = start;
233
+ const size_t offset_end = start + offset;
234
+ assert(offset_end <= cpts.size());
235
+ start = offset_end;
236
+
237
+ auto _get_cpt = [&] (const size_t pos) -> char32_t {
238
+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
239
+ };
240
+
241
+ auto _get_cpt_type = [&] (const size_t pos) -> int {
242
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
243
+ };
244
+
245
+ size_t _prev_end = offset_ini;
246
+ auto _add_token = [&] (const size_t end) -> size_t {
247
+ assert(_prev_end <= end && end <= offset_end);
248
+ size_t len = end - _prev_end;
249
+ if (len > 0) {
250
+ bpe_offsets.push_back(len);
251
+ }
252
+ _prev_end = end;
253
+ //if (len > 0) {
254
+ // std::string s = "";
255
+ // for(size_t p = end-len; p < end; p++)
256
+ // s += unicode_cpt_to_utf8(cpts[p]);
257
+ // printf(">>> '%s'\n", s.c_str());
258
+ //}
259
+ return len;
260
+ };
261
+
262
+ for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
263
+ const char32_t cpt = _get_cpt(pos);
264
+ const int cpt_type = _get_cpt_type(pos);
265
+
266
+ // regex: 's|'t|'re|'ve|'m|'ll|'d
267
+ if (cpt == '\'' && pos+1 < offset_end) {
268
+ char32_t cpt_next = _get_cpt(pos+1);
269
+ if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
270
+ pos += _add_token(pos+2);
271
+ continue;
272
+ }
273
+ if (pos+2 < offset_end) {
274
+ char32_t cpt_next_next = _get_cpt(pos+2);
275
+ if ((cpt_next == 'r' && cpt_next_next == 'e') ||
276
+ (cpt_next == 'v' && cpt_next_next == 'e') ||
277
+ (cpt_next == 'l' && cpt_next_next == 'l')) {
278
+ pos += _add_token(pos+3);
279
+ continue;
280
+ }
281
+ }
282
+ }
283
+
284
+ char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
285
+ int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
286
+ // regex: <space>?\p{L}+
287
+ if (cpt2_type == CODEPOINT_TYPE_LETTER) {
288
+ pos += (cpt == ' ');
289
+ while (cpt2_type == CODEPOINT_TYPE_LETTER) {
290
+ cpt2_type = _get_cpt_type(++pos);
291
+ }
292
+ _add_token(pos);
293
+ continue;
294
+ }
295
+ // regex: <space>?\p{N}+
296
+ if (cpt2_type == CODEPOINT_TYPE_NUMBER) {
297
+ pos += (cpt == ' ');
298
+ while (cpt2_type == CODEPOINT_TYPE_NUMBER) {
299
+ cpt2_type = _get_cpt_type(++pos);
300
+ }
301
+ _add_token(pos);
302
+ continue;
303
+ }
304
+ // regex: <space>?[^\s\p{L}\p{N}]+
305
+ if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
306
+ pos += (cpt == ' ');
307
+ while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
308
+ cpt2_type = _get_cpt_type(++pos);
309
+ cpt2 = _get_cpt(pos);
310
+ }
311
+ _add_token(pos);
312
+ continue;
313
+ }
314
+
315
+ size_t num_whitespaces = 0;
316
+ while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
317
+ num_whitespaces++;
318
+ }
319
+
320
+ // regex: \s+(?!\S)
321
+ if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
322
+ pos += num_whitespaces - 1;
323
+ _add_token(pos);
324
+ continue;
325
+ }
326
+
327
+ // regex: \s+
328
+ if (num_whitespaces > 0) {
329
+ pos += num_whitespaces;
330
+ _add_token(pos);
331
+ continue;
332
+ }
333
+
334
+ // no matches
335
+ _add_token(++pos);
336
+ }
337
+ }
338
+
339
+ return bpe_offsets;
340
+ }
341
+
342
+ // LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
343
+ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
344
+ std::vector<size_t> bpe_offsets; // store the offset of each word
345
+ bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
346
+
347
+ const auto cpts = unicode_cpts_from_utf8(text);
348
+
349
+ size_t start = 0;
350
+ for (auto offset : offsets) {
351
+ const size_t offset_ini = start;
352
+ const size_t offset_end = start + offset;
353
+ assert(offset_end <= cpts.size());
354
+ start = offset_end;
355
+
356
+ auto _get_cpt = [&] (const size_t pos) -> char32_t {
357
+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
358
+ };
359
+
360
+ auto _get_cpt_type = [&] (const size_t pos) -> int {
361
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
362
+ };
363
+
364
+ size_t _prev_end = offset_ini;
365
+ auto _add_token = [&] (const size_t end) -> size_t {
366
+ assert(_prev_end <= end && end <= offset_end);
367
+ size_t len = end - _prev_end;
368
+ if (len > 0) {
369
+ bpe_offsets.push_back(len);
370
+ }
371
+ _prev_end = end;
372
+ //if (len > 0) {
373
+ // std::string s = "";
374
+ // for(size_t p = end-len; p < end; p++)
375
+ // s += unicode_cpt_to_utf8(cpts[p]);
376
+ // printf(">>> '%s'\n", s.c_str());
377
+ //}
378
+ return len;
379
+ };
380
+
381
+ for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
382
+ const char32_t cpt = _get_cpt(pos);
383
+ const int cpt_type = _get_cpt_type(pos);
384
+
385
+ // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
386
+ if (cpt == '\'' && pos+1 < offset_end) {
387
+ char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
388
+ if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
389
+ pos += _add_token(pos+2);
390
+ continue;
391
+ }
392
+ if (pos+2 < offset_end) {
393
+ char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
394
+ if ((cpt_next == 'r' && cpt_next_next == 'e') ||
395
+ (cpt_next == 'v' && cpt_next_next == 'e') ||
396
+ (cpt_next == 'l' && cpt_next_next == 'l')) {
397
+ pos += _add_token(pos+3);
398
+ continue;
399
+ }
400
+ }
401
+ }
402
+
403
+ // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
404
+ if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
405
+ if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters
406
+ pos++;
407
+ while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
408
+ pos++;
409
+ }
410
+ _add_token(pos);
411
+ continue;
412
+ }
413
+ }
414
+
415
+ // regex: \p{N}{1,3}
416
+ if (cpt_type == CODEPOINT_TYPE_NUMBER) {
417
+ size_t ini = pos;
418
+ while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
419
+ if (++pos - ini >= 3 ) {
420
+ _add_token(pos);
421
+ ini = pos;
422
+ }
423
+ }
424
+ _add_token(pos);
425
+ continue;
426
+ }
427
+
428
+ // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
429
+ char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
430
+ int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
431
+ if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
432
+ pos += (cpt == ' ');
433
+ while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
434
+ cpt2_type = _get_cpt_type(++pos);
435
+ cpt2 = _get_cpt(pos);
436
+ }
437
+ while (cpt2 == '\r' || cpt2 == '\n') {
438
+ cpt2 = _get_cpt(++pos);
439
+ }
440
+ _add_token(pos);
441
+ continue;
442
+ }
443
+
444
+ size_t num_whitespaces = 0;
445
+ size_t last_end_r_or_n = 0;
446
+ while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
447
+ char32_t cpt2 = _get_cpt(pos+num_whitespaces);
448
+ if (cpt2 == '\r' || cpt2 == '\n') {
449
+ last_end_r_or_n = pos + num_whitespaces + 1;
450
+ }
451
+ num_whitespaces++;
452
+ }
453
+
454
+ // regex: \s*[\r\n]+
455
+ if (last_end_r_or_n > 0) {
456
+ pos = last_end_r_or_n;
457
+ _add_token(pos);
458
+ continue;
459
+ }
460
+
461
+ // regex: \s+(?!\S)
462
+ if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
463
+ pos += num_whitespaces - 1;
464
+ _add_token(pos);
465
+ continue;
466
+ }
467
+
468
+ // regex: \s+
469
+ if (num_whitespaces > 0) {
470
+ pos += num_whitespaces;
471
+ _add_token(pos);
472
+ continue;
473
+ }
474
+
475
+ // no matches
476
+ _add_token(++pos);
477
+ }
478
+ }
479
+
480
+ return bpe_offsets;
481
+ }
482
+
483
// use std::wregex to split the text
//
// `offsets` holds the lengths (in wchar_t elements) of consecutive chunks of
// `wtext`. Every chunk is searched independently with `regex_expr`; the
// returned vector lists, in order, the lengths of the matches and of the
// unmatched gaps between them, so the lengths of one chunk add up to the
// chunk length.
//
// Zero-length matches are not reported: like the custom splitters, this
// function never emits an empty piece.
//
// Throws std::regex_error when `regex_expr` is not a valid expression.
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
    const std::wregex expr(regex_expr);
    std::vector<size_t> bpe_offsets; // store the offset of each word
    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
    size_t start = 0;
    for (const size_t offset : offsets) {
        const wchar_t * chunk = wtext.data() + start;
        const std::wcregex_iterator end;

        int64_t start_idx = 0; // end of the previous match, relative to the chunk
        for (std::wcregex_iterator it(chunk, chunk + offset, expr); it != end; ++it) {
            const std::wcmatch & match = *it; // bind by reference, no copy of the match results
            const int64_t match_pos = match.position();
            const int64_t match_len = match.length();

            // unmatched text in front of this match
            if (match_pos > start_idx) {
                bpe_offsets.emplace_back(static_cast<size_t>(match_pos - start_idx));
            }
            // the match itself, unless it is empty
            if (match_len > 0) {
                bpe_offsets.emplace_back(static_cast<size_t>(match_len));
            }
            start_idx = match_pos + match_len;
        }

        // unmatched text after the last match
        if (start_idx < static_cast<int64_t>(offset)) {
            bpe_offsets.emplace_back(offset - static_cast<size_t>(start_idx));
        }
        start += offset;
    }

    return bpe_offsets;
}
512
+
513
// use std::regex to split the text
//
// `offsets` holds the lengths (in char elements) of consecutive chunks of
// `text`. Every chunk is searched independently with `regex_expr`; the
// returned vector lists, in order, the lengths of the matches and of the
// unmatched gaps between them, so the lengths of one chunk add up to the
// chunk length.
//
// Zero-length matches are not reported: like the custom splitters, this
// function never emits an empty piece.
//
// Throws std::regex_error when `regex_expr` is not a valid expression.
static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
    const std::regex expr(regex_expr);
    std::vector<size_t> bpe_offsets; // store the offset of each word
    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
    size_t start = 0;
    for (const size_t offset : offsets) {
        const char * chunk = text.data() + start;
        const std::cregex_iterator end;

        int64_t start_idx = 0; // end of the previous match, relative to the chunk
        for (std::cregex_iterator it(chunk, chunk + offset, expr); it != end; ++it) {
            const std::cmatch & match = *it; // bind by reference, no copy of the match results
            const int64_t match_pos = match.position();
            const int64_t match_len = match.length();

            // unmatched text in front of this match
            if (match_pos > start_idx) {
                bpe_offsets.emplace_back(static_cast<size_t>(match_pos - start_idx));
            }
            // the match itself, unless it is empty
            if (match_len > 0) {
                bpe_offsets.emplace_back(static_cast<size_t>(match_len));
            }
            start_idx = match_pos + match_len;
        }

        // unmatched text after the last match
        if (start_idx < static_cast<int64_t>(offset)) {
            bpe_offsets.emplace_back(offset - static_cast<size_t>(start_idx));
        }
        start += offset;
    }

    return bpe_offsets;
}
542
+
543
+ static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
544
+ std::vector<size_t> bpe_offsets;
545
+
546
+ if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
547
+ bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
548
+ } else if (
549
+ regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
550
+ regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
551
+
552
+ bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
553
+ }
554
+
555
+ return bpe_offsets;
556
+ }
557
+
197
558
  //
198
559
  // interface
199
560
  //
200
561
 
201
562
  std::string unicode_cpt_to_utf8(uint32_t cp) {
202
563
  std::string result;
564
+
203
565
  if (/* 0x00 <= cp && */ cp <= 0x7f) {
204
566
  result.push_back(cp);
567
+ return result;
205
568
  }
206
- else if (0x80 <= cp && cp <= 0x7ff) {
569
+ if (0x80 <= cp && cp <= 0x7ff) {
207
570
  result.push_back(0xc0 | ((cp >> 6) & 0x1f));
208
571
  result.push_back(0x80 | (cp & 0x3f));
572
+ return result;
209
573
  }
210
- else if (0x800 <= cp && cp <= 0xffff) {
574
+ if (0x800 <= cp && cp <= 0xffff) {
211
575
  result.push_back(0xe0 | ((cp >> 12) & 0x0f));
212
576
  result.push_back(0x80 | ((cp >> 6) & 0x3f));
213
577
  result.push_back(0x80 | (cp & 0x3f));
578
+ return result;
214
579
  }
215
- else if (0x10000 <= cp && cp <= 0x10ffff) {
580
+ if (0x10000 <= cp && cp <= 0x10ffff) {
216
581
  result.push_back(0xf0 | ((cp >> 18) & 0x07));
217
582
  result.push_back(0x80 | ((cp >> 12) & 0x3f));
218
583
  result.push_back(0x80 | ((cp >> 6) & 0x3f));
219
584
  result.push_back(0x80 | (cp & 0x3f));
585
+ return result;
220
586
  }
221
- else {
222
- throw std::invalid_argument("invalid codepoint");
223
- }
224
- return result;
587
+
588
+ throw std::invalid_argument("invalid codepoint");
225
589
  }
226
590
 
227
591
  std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
@@ -261,6 +625,19 @@ int unicode_cpt_type(const std::string & utf8) {
261
625
  return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
262
626
  }
263
627
 
628
+ bool unicode_cpt_is_whitespace(uint32_t cp) {
629
+ static const std::unordered_set<uint32_t> is_whitespace = [] {
630
+ std::unordered_set<uint32_t> is_whitespace;
631
+ for (auto p : unicode_ranges_whitespace) {
632
+ for (auto i = p.first; i <= p.second; ++i) {
633
+ is_whitespace.insert(i);
634
+ }
635
+ }
636
+ return is_whitespace;
637
+ }();
638
+ return (bool)is_whitespace.count(cp);
639
+ }
640
+
264
641
  std::string unicode_byte_to_utf8(uint8_t byte) {
265
642
  static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
266
643
  return map.at(byte);
@@ -275,3 +652,167 @@ char32_t unicode_tolower(char32_t cp) {
275
652
  auto it = unicode_map_lowercase.find(cp);
276
653
  return it == unicode_map_lowercase.end() ? cp : it->second;
277
654
  }
655
+
656
+ std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
657
+ // unicode categories
658
+ static const std::map<std::string, int> k_ucat_enum = {
659
+ { "\\p{N}", CODEPOINT_TYPE_NUMBER },
660
+ { "\\p{L}", CODEPOINT_TYPE_LETTER },
661
+ { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
662
+ };
663
+
664
+ static const std::map<int, int> k_ucat_cpt = {
665
+ { CODEPOINT_TYPE_NUMBER, 0xD1 },
666
+ { CODEPOINT_TYPE_LETTER, 0xD2 },
667
+ { CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
668
+ };
669
+
670
+ static const std::map<int, std::string> k_ucat_map = {
671
+ { CODEPOINT_TYPE_NUMBER, "\x30-\x39" }, // 0-9
672
+ { CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
673
+ { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
674
+ };
675
+
676
+ // compute collapsed codepoints only if needed by at least one regex
677
+ bool need_collapse = false;
678
+ for (auto & regex_expr : regex_exprs) {
679
+ // search for unicode categories
680
+ for (const auto & ucat : k_ucat_enum) {
681
+ if (std::string::npos != regex_expr.find(ucat.first)) {
682
+ need_collapse = true;
683
+ break;
684
+ }
685
+ }
686
+ }
687
+
688
+ const auto cpts = unicode_cpts_from_utf8(text);
689
+
690
+ // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
691
+ // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
692
+ std::string text_collapsed;
693
+ if (need_collapse) {
694
+ // collapse all unicode categories
695
+ text_collapsed.resize(cpts.size());
696
+
697
+ for (size_t i = 0; i < cpts.size(); ++i) {
698
+ // keep single-byte codepoints as is
699
+ if (cpts[i] < 128) {
700
+ text_collapsed[i] = cpts[i];
701
+ continue;
702
+ }
703
+
704
+ const int cpt_type = unicode_cpt_type(cpts[i]);
705
+
706
+ if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
707
+ text_collapsed[i] = k_ucat_cpt.at(cpt_type);
708
+ } else {
709
+ text_collapsed[i] = (char) 0xD0; // fallback
710
+ }
711
+ }
712
+ }
713
+
714
+ std::vector<size_t> bpe_offsets = { cpts.size() };
715
+
716
+ for (auto & regex_expr : regex_exprs) {
717
+ // first, see if we have an efficient custom regex implementation
718
+ auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
719
+
720
+ if (!tmp.empty()) {
721
+ bpe_offsets = std::move(tmp);
722
+ continue;
723
+ }
724
+
725
+ // fallback to general-purpose std::regex / std::wregex
726
+ try {
727
+ // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
728
+ // with the corresponding collapsed representation
729
+ bool use_collapsed = false;
730
+ for (auto & ucat : k_ucat_enum) {
731
+ if (std::string::npos != regex_expr.find(ucat.first)) {
732
+ use_collapsed = true;
733
+ break;
734
+ }
735
+ }
736
+
737
+ if (use_collapsed) {
738
+ // sanity-check that the original regex does not contain any non-ASCII characters
739
+ const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
740
+ for (size_t i = 0; i < cpts_regex.size(); ++i) {
741
+ if (cpts_regex[i] >= 128) {
742
+ throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
743
+ }
744
+ }
745
+
746
+ // generate a collapsed representation of the regex
747
+ std::string regex_expr_collapsed;
748
+
749
+ // track if we are inside [], because nested [] are not allowed
750
+ bool inside = false;
751
+ for (size_t i = 0; i < regex_expr.size(); ++i) {
752
+ if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
753
+ regex_expr_collapsed += '[';
754
+ inside = true;
755
+ continue;
756
+ }
757
+
758
+ if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
759
+ regex_expr_collapsed += ']';
760
+ inside = false;
761
+ continue;
762
+ }
763
+
764
+ if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
765
+ regex_expr[i + 1] == 'p' &&
766
+ regex_expr[i + 2] == '{' &&
767
+ regex_expr[i + 4] == '}') {
768
+ const std::string pat = regex_expr.substr(i, 5);
769
+ if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
770
+ if (!inside) {
771
+ regex_expr_collapsed += '[';
772
+ }
773
+ regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
774
+ regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
775
+ if (!inside) {
776
+ regex_expr_collapsed += ']';
777
+ }
778
+ i += 4;
779
+ continue;
780
+ }
781
+ }
782
+
783
+ regex_expr_collapsed += regex_expr[i];
784
+ }
785
+
786
+ //printf("text_collapsed: %s\n", text_collapsed.c_str());
787
+ //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
788
+ bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
789
+ } else {
790
+ // no unicode category used, we can use std::wregex directly
791
+ const std::wstring wtext = unicode_wstring_from_utf8(text);
792
+ const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
793
+
794
+ //printf("text: %s\n", text.c_str());
795
+ //printf("regex_expr: %s\n", regex_expr.c_str());
796
+ bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
797
+ }
798
+ } catch (std::regex_error & e) {
799
+ fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
800
+ fprintf(stderr, "Regex error: %s\n", e.what());
801
+ throw std::runtime_error("Failed to process regex");
802
+ }
803
+ }
804
+
805
+ std::vector<std::string> bpe_words;
806
+ bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
807
+
808
+ size_t start = 0;
809
+ for (size_t & offset : bpe_offsets) {
810
+ bpe_words.emplace_back();
811
+ for (size_t i = start; i < start + offset; ++i) {
812
+ bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
813
+ }
814
+ start += offset;
815
+ }
816
+
817
+ return unicode_byte_encoding_process(bpe_words);
818
+ }