llama_cpp 0.15.0 → 0.15.2

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
@@ -1,16 +1,20 @@
1
1
  #pragma once
2
2
 
3
3
  #include <cstdint>
4
- #include <map>
5
- #include <utility>
6
4
  #include <vector>
5
+ #include <unordered_map>
6
+ #include <unordered_set>
7
7
 
8
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_digit;
9
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
10
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
11
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
12
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
13
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
14
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
15
- extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
16
- extern const std::map<char32_t, char32_t> unicode_map_lowercase;
8
+ struct range_nfd {
9
+ uint32_t first;
10
+ uint32_t last;
11
+ uint32_t nfd;
12
+ };
13
+
14
+ static const uint32_t MAX_CODEPOINTS = 0x110000;
15
+
16
+ extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
17
+ extern const std::unordered_set<uint32_t> unicode_set_whitespace;
18
+ extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
19
+ extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
20
+ extern const std::vector<range_nfd> unicode_ranges_nfd;
@@ -1,4 +1,4 @@
1
- #include "unicode.h"
1
+ #include "unicode.h"
2
2
  #include "unicode-data.h"
3
3
 
4
4
  #include <cassert>
@@ -9,6 +9,7 @@
9
9
  #include <stdexcept>
10
10
  #include <string>
11
11
  #include <unordered_map>
12
+ #include <unordered_set>
12
13
  #include <utility>
13
14
  #include <vector>
14
15
  #include <locale>
@@ -108,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
108
109
  // return result;
109
110
  //}
110
111
 
111
- static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
112
- std::unordered_map<uint32_t, int> cpt_types;
113
- for (auto p : unicode_ranges_digit) {
114
- for (auto i = p.first; i <= p.second; ++ i) {
115
- cpt_types[i] = CODEPOINT_TYPE_DIGIT;
116
- }
117
- }
118
- for (auto p : unicode_ranges_letter) {
119
- for (auto i = p.first; i <= p.second; ++ i) {
120
- cpt_types[i] = CODEPOINT_TYPE_LETTER;
121
- }
122
- }
123
- for (auto p : unicode_ranges_whitespace) {
124
- for (auto i = p.first; i <= p.second; ++ i) {
125
- cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
112
+ static std::vector<codepoint_flags> unicode_cpt_flags_array() {
113
+ std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
114
+
115
+ assert (unicode_ranges_flags.front().first == 0);
116
+ assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
117
+ for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
118
+ const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
119
+ const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
120
+ for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
121
+ cpt_flags[cpt] = range_ini.second;
126
122
  }
127
123
  }
128
- for (auto p : unicode_ranges_accent_mark) {
129
- for (auto i = p.first; i <= p.second; ++ i) {
130
- cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
131
- }
124
+
125
+ for (auto cpt : unicode_set_whitespace) {
126
+ cpt_flags[cpt].is_whitespace = true;
132
127
  }
133
- for (auto p : unicode_ranges_punctuation) {
134
- for (auto i = p.first; i <= p.second; ++ i) {
135
- cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
136
- }
128
+
129
+ for (auto p : unicode_map_lowercase) {
130
+ cpt_flags[p.second].is_lowercase = true;
137
131
  }
138
- for (auto p : unicode_ranges_symbol) {
139
- for (auto i = p.first; i <= p.second; ++i) {
140
- cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
141
- }
132
+
133
+ for (auto p : unicode_map_uppercase) {
134
+ cpt_flags[p.second].is_uppercase = true;
142
135
  }
143
- for (auto p : unicode_ranges_control) {
144
- for (auto i = p.first; i <= p.second; ++ i) {
145
- cpt_types[i] = CODEPOINT_TYPE_CONTROL;
146
- }
136
+
137
+ for (auto &range : unicode_ranges_nfd) { // start, last, nfd
138
+ cpt_flags[range.nfd].is_nfd = true;
147
139
  }
148
- return cpt_types;
140
+
141
+ return cpt_flags;
149
142
  }
150
143
 
151
144
  static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
152
145
  std::unordered_map<uint8_t, std::string> map;
153
- for (int ch = u'!'; ch <= u'~'; ++ch) {
146
+ for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
154
147
  assert(0 <= ch && ch < 256);
155
148
  map[ch] = unicode_cpt_to_utf8(ch);
156
149
  }
157
- for (int ch = u'¡'; ch <= u'¬'; ++ch) {
150
+ for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
158
151
  assert(0 <= ch && ch < 256);
159
152
  map[ch] = unicode_cpt_to_utf8(ch);
160
153
  }
161
- for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
154
+ for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
162
155
  assert(0 <= ch && ch < 256);
163
156
  map[ch] = unicode_cpt_to_utf8(ch);
164
157
  }
@@ -174,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
174
167
 
175
168
  static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
176
169
  std::unordered_map<std::string, uint8_t> map;
177
- for (int ch = u'!'; ch <= u'~'; ++ch) {
170
+ for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
178
171
  assert(0 <= ch && ch < 256);
179
172
  map[unicode_cpt_to_utf8(ch)] = ch;
180
173
  }
181
- for (int ch = u'¡'; ch <= u'¬'; ++ch) {
174
+ for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
182
175
  assert(0 <= ch && ch < 256);
183
176
  map[unicode_cpt_to_utf8(ch)] = ch;
184
177
  }
185
- for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
178
+ for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
186
179
  assert(0 <= ch && ch < 256);
187
180
  map[unicode_cpt_to_utf8(ch)] = ch;
188
181
  }
@@ -224,138 +217,255 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
224
217
  std::vector<size_t> bpe_offsets; // store the offset of each word
225
218
  bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
226
219
 
227
- size_t start = 0;
228
-
229
220
  const auto cpts = unicode_cpts_from_utf8(text);
230
221
 
222
+ size_t start = 0;
231
223
  for (auto offset : offsets) {
232
- std::string token;
224
+ const size_t offset_ini = start;
225
+ const size_t offset_end = start + offset;
226
+ assert(offset_end <= cpts.size());
227
+ start = offset_end;
228
+
229
+ auto _get_cpt = [&] (const size_t pos) -> char32_t {
230
+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
231
+ };
232
+
233
+ auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
234
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
235
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
236
+ };
237
+
238
+ size_t _prev_end = offset_ini;
239
+ auto _add_token = [&] (const size_t end) -> size_t {
240
+ assert(_prev_end <= end && end <= offset_end);
241
+ size_t len = end - _prev_end;
242
+ if (len > 0) {
243
+ bpe_offsets.push_back(len);
244
+ }
245
+ _prev_end = end;
246
+ //if (len > 0) {
247
+ // std::string s = "";
248
+ // for(size_t p = end-len; p < end; p++)
249
+ // s += unicode_cpt_to_utf8(cpts[p]);
250
+ // printf(">>> '%s'\n", s.c_str());
251
+ //}
252
+ return len;
253
+ };
254
+
255
+ for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
256
+ const char32_t cpt = _get_cpt(pos);
257
+ const auto flags = _get_flags(pos);
258
+
259
+ // regex: 's|'t|'re|'ve|'m|'ll|'d
260
+ if (cpt == '\'' && pos+1 < offset_end) {
261
+ char32_t cpt_next = _get_cpt(pos+1);
262
+ if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
263
+ pos += _add_token(pos+2);
264
+ continue;
265
+ }
266
+ if (pos+2 < offset_end) {
267
+ char32_t cpt_next_next = _get_cpt(pos+2);
268
+ if ((cpt_next == 'r' && cpt_next_next == 'e') ||
269
+ (cpt_next == 'v' && cpt_next_next == 'e') ||
270
+ (cpt_next == 'l' && cpt_next_next == 'l')) {
271
+ pos += _add_token(pos+3);
272
+ continue;
273
+ }
274
+ }
275
+ }
276
+
277
+ auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
278
+ // regex: <space>?\p{L}+
279
+ if (flags2.is_letter) {
280
+ pos += (cpt == ' ');
281
+ while (flags2.is_letter) {
282
+ flags2 = _get_flags(++pos);
283
+ }
284
+ _add_token(pos);
285
+ continue;
286
+ }
287
+ // regex: <space>?\p{N}+
288
+ if (flags2.is_number) {
289
+ pos += (cpt == ' ');
290
+ while (flags2.is_number) {
291
+ flags2 = _get_flags(++pos);
292
+ }
293
+ _add_token(pos);
294
+ continue;
295
+ }
296
+ // regex: <space>?[^\s\p{L}\p{N}]+
297
+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
298
+ pos += (cpt == ' ');
299
+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
300
+ flags2 = _get_flags(++pos);
301
+ }
302
+ _add_token(pos);
303
+ continue;
304
+ }
305
+
306
+ size_t num_whitespaces = 0;
307
+ while (_get_flags(pos+num_whitespaces).is_whitespace) {
308
+ num_whitespaces++;
309
+ }
233
310
 
234
- bool collecting_numeric = false;
235
- bool collecting_letter = false;
236
- bool collecting_special = false;
237
- bool collecting_whitespace_lookahead = false;
238
- bool collecting = false;
311
+ // regex: \s+(?!\S)
312
+ if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
313
+ pos += num_whitespaces - 1;
314
+ _add_token(pos);
315
+ continue;
316
+ }
239
317
 
240
- std::vector<std::string> text_utf;
241
- text_utf.reserve(offset);
318
+ // regex: \s+
319
+ if (num_whitespaces > 0) {
320
+ pos += num_whitespaces;
321
+ _add_token(pos);
322
+ continue;
323
+ }
242
324
 
243
- for (size_t i = start; i < start + offset; ++i) {
244
- text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
325
+ // no matches
326
+ _add_token(++pos);
245
327
  }
328
+ }
246
329
 
247
- for (int i = 0; i < (int)text_utf.size(); i++) {
248
- const std::string & utf_char = text_utf[i];
249
- bool split_condition = false;
250
- int bytes_remain = text_utf.size() - i;
330
+ return bpe_offsets;
331
+ }
332
+
333
+ // LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
334
+ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
335
+ std::vector<size_t> bpe_offsets; // store the offset of each word
336
+ bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
251
337
 
252
- // forward backward lookups
253
- const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
254
- const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
338
+ const auto cpts = unicode_cpts_from_utf8(text);
255
339
 
256
- // handling contractions
257
- if (!split_condition && bytes_remain >= 2) {
258
- // 's|'t|'m|'d
259
- if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
260
- split_condition = true;
340
+ size_t start = 0;
341
+ for (auto offset : offsets) {
342
+ const size_t offset_ini = start;
343
+ const size_t offset_end = start + offset;
344
+ assert(offset_end <= cpts.size());
345
+ start = offset_end;
346
+
347
+ auto _get_cpt = [&] (const size_t pos) -> char32_t {
348
+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
349
+ };
350
+
351
+ auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
352
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
353
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
354
+ };
355
+
356
+ size_t _prev_end = offset_ini;
357
+ auto _add_token = [&] (const size_t end) -> size_t {
358
+ assert(_prev_end <= end && end <= offset_end);
359
+ size_t len = end - _prev_end;
360
+ if (len > 0) {
361
+ bpe_offsets.push_back(len);
362
+ }
363
+ _prev_end = end;
364
+ //if (len > 0) {
365
+ // std::string s = "";
366
+ // for(size_t p = end-len; p < end; p++)
367
+ // s += unicode_cpt_to_utf8(cpts[p]);
368
+ // printf(">>> '%s'\n", s.c_str());
369
+ //}
370
+ return len;
371
+ };
372
+
373
+ for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
374
+ const char32_t cpt = _get_cpt(pos);
375
+ const auto flags = _get_flags(pos);
376
+
377
+ // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
378
+ if (cpt == '\'' && pos+1 < offset_end) {
379
+ char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
380
+ if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
381
+ pos += _add_token(pos+2);
382
+ continue;
261
383
  }
262
- if (split_condition) {
263
- if (token.size()) {
264
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
384
+ if (pos+2 < offset_end) {
385
+ char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
386
+ if ((cpt_next == 'r' && cpt_next_next == 'e') ||
387
+ (cpt_next == 'v' && cpt_next_next == 'e') ||
388
+ (cpt_next == 'l' && cpt_next_next == 'l')) {
389
+ pos += _add_token(pos+3);
390
+ continue;
265
391
  }
266
- token = utf_char + utf_char_next;
267
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
268
- token = "";
269
- i++;
270
- continue;
271
392
  }
272
393
  }
273
- if (!split_condition && bytes_remain >= 3) {
274
- // 're|'ve|'ll
275
- if (utf_char == "\'" && (
276
- (utf_char_next == "r" && utf_char_next_next == "e") ||
277
- (utf_char_next == "v" && utf_char_next_next == "e") ||
278
- (utf_char_next == "l" && utf_char_next_next == "l"))
279
- ) {
280
- split_condition = true;
281
- }
282
- if (split_condition) {
283
- // current token + next token can be defined
284
- if (token.size()) {
285
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
286
- }
287
- token = utf_char;
288
- token += utf_char_next;
289
- token += utf_char_next_next;
290
394
 
291
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
292
- token = "";
293
- i += 2;
395
+ // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
396
+ if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
397
+ if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
398
+ pos++;
399
+ while (_get_flags(pos).is_letter) {
400
+ pos++;
401
+ }
402
+ _add_token(pos);
294
403
  continue;
295
404
  }
296
405
  }
297
406
 
298
- if (!split_condition && !collecting) {
299
- if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
300
- collecting_letter = true;
301
- collecting = true;
302
- }
303
- else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
304
- collecting_numeric = true;
305
- collecting = true;
306
- }
307
- else if (
308
- ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
309
- (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
310
- ) {
311
- collecting_special = true;
312
- collecting = true;
313
- }
314
- else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
315
- collecting_whitespace_lookahead = true;
316
- collecting = true;
317
- }
318
- else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
319
- split_condition = true;
407
+ // regex: \p{N}{1,3}
408
+ if (flags.is_number) {
409
+ size_t ini = pos;
410
+ while (_get_flags(pos).is_number) {
411
+ if (++pos - ini >= 3 ) {
412
+ _add_token(pos);
413
+ ini = pos;
414
+ }
320
415
  }
416
+ _add_token(pos);
417
+ continue;
321
418
  }
322
- else if (!split_condition && collecting) {
323
- if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
324
- split_condition = true;
325
- }
326
- else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
327
- split_condition = true;
419
+
420
+ // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
421
+ auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
422
+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
423
+ pos += (cpt == ' ');
424
+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
425
+ flags2 = _get_flags(++pos);
328
426
  }
329
- else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
330
- split_condition = true;
427
+ char32_t cpt2 = _get_cpt(pos);
428
+ while (cpt2 == '\r' || cpt2 == '\n') {
429
+ cpt2 = _get_cpt(++pos);
331
430
  }
332
- else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
333
- split_condition = true;
431
+ _add_token(pos);
432
+ continue;
433
+ }
434
+
435
+ size_t num_whitespaces = 0;
436
+ size_t last_end_r_or_n = 0;
437
+ while (_get_flags(pos+num_whitespaces).is_whitespace) {
438
+ char32_t cpt2 = _get_cpt(pos+num_whitespaces);
439
+ if (cpt2 == '\r' || cpt2 == '\n') {
440
+ last_end_r_or_n = pos + num_whitespaces + 1;
334
441
  }
442
+ num_whitespaces++;
335
443
  }
336
444
 
337
- if (utf_char_next == "") {
338
- split_condition = true; // final
339
- token += utf_char;
445
+ // regex: \s*[\r\n]+
446
+ if (last_end_r_or_n > 0) {
447
+ pos = last_end_r_or_n;
448
+ _add_token(pos);
449
+ continue;
340
450
  }
341
451
 
342
- if (split_condition) {
343
- if (token.size()) {
344
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
345
- }
346
- token = utf_char;
347
- collecting = false;
348
- collecting_letter = false;
349
- collecting_numeric = false;
350
- collecting_special = false;
351
- collecting_whitespace_lookahead = false;
452
+ // regex: \s+(?!\S)
453
+ if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
454
+ pos += num_whitespaces - 1;
455
+ _add_token(pos);
456
+ continue;
352
457
  }
353
- else {
354
- token += utf_char;
458
+
459
+ // regex: \s+
460
+ if (num_whitespaces > 0) {
461
+ pos += num_whitespaces;
462
+ _add_token(pos);
463
+ continue;
355
464
  }
356
- }
357
465
 
358
- start += offset;
466
+ // no matches
467
+ _add_token(++pos);
468
+ }
359
469
  }
360
470
 
361
471
  return bpe_offsets;
@@ -424,14 +534,14 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
424
534
  static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
425
535
  std::vector<size_t> bpe_offsets;
426
536
 
427
- (void)(text);
428
- (void)(regex_expr);
429
- (void)(offsets);
430
- // TODO: this implementation is actually wrong, uncomment and run:
431
- // make -j && ./bin/test-tokenizer-0 ../models/ggml-vocab-gpt-2.gguf
432
- //if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
433
- // bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
434
- //}
537
+ if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
538
+ bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
539
+ } else if (
540
+ regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
541
+ regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
542
+
543
+ bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
544
+ }
435
545
 
436
546
  return bpe_offsets;
437
547
  }
@@ -470,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
470
580
  }
471
581
 
472
582
  std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
473
- std::vector<uint32_t> result;
474
- result.reserve(cpts.size());
583
+ auto comp = [] (const uint32_t cpt, const range_nfd & range) {
584
+ return cpt < range.first;
585
+ };
586
+ std::vector<uint32_t> result(cpts.size());
475
587
  for (size_t i = 0; i < cpts.size(); ++i) {
476
- auto it = unicode_map_nfd.find(cpts[i]);
477
- if (it == unicode_map_nfd.end()) {
478
- result.push_back(cpts[i]);
479
- } else {
480
- result.push_back(it->second);
481
- }
588
+ const uint32_t cpt = cpts[i];
589
+ auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
590
+ result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
482
591
  }
483
592
  return result;
484
593
  }
@@ -492,18 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
492
601
  return result;
493
602
  }
494
603
 
495
- int unicode_cpt_type(uint32_t cp) {
496
- static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map();
497
- const auto it = cpt_types.find(cp);
498
- return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second;
604
+ codepoint_flags unicode_cpt_flags(const uint32_t cp) {
605
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
606
+ static const auto cpt_flags = unicode_cpt_flags_array();
607
+ return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
499
608
  }
500
609
 
501
- int unicode_cpt_type(const std::string & utf8) {
502
- if (utf8.length() == 0) {
503
- return CODEPOINT_TYPE_UNIDENTIFIED;
610
+ codepoint_flags unicode_cpt_flags(const std::string & utf8) {
611
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
612
+ if (utf8.empty()) {
613
+ return undef; // undefined
504
614
  }
505
615
  size_t offset = 0;
506
- return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
616
+ return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
507
617
  }
508
618
 
509
619
  std::string unicode_byte_to_utf8(uint8_t byte) {
@@ -524,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
524
634
  std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
525
635
  // unicode categories
526
636
  static const std::map<std::string, int> k_ucat_enum = {
527
- { "\\p{N}", CODEPOINT_TYPE_DIGIT },
528
- { "\\p{L}", CODEPOINT_TYPE_LETTER },
529
- { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
637
+ { "\\p{N}", codepoint_flags::NUMBER },
638
+ { "\\p{L}", codepoint_flags::LETTER },
639
+ { "\\p{P}", codepoint_flags::PUNCTUATION },
530
640
  };
531
641
 
532
642
  static const std::map<int, int> k_ucat_cpt = {
533
- { CODEPOINT_TYPE_DIGIT, 0xD1 },
534
- { CODEPOINT_TYPE_LETTER, 0xD2 },
535
- { CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
643
+ { codepoint_flags::NUMBER, 0xD1 },
644
+ { codepoint_flags::LETTER, 0xD2 },
645
+ { codepoint_flags::PUNCTUATION, 0xD3 },
536
646
  };
537
647
 
538
648
  static const std::map<int, std::string> k_ucat_map = {
539
- { CODEPOINT_TYPE_DIGIT, "\x30-\x39" }, // 0-9
540
- { CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
541
- { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
649
+ { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
650
+ { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
651
+ { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
542
652
  };
543
653
 
544
654
  // compute collapsed codepoints only if needed by at least one regex
@@ -569,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
569
679
  continue;
570
680
  }
571
681
 
572
- const int cpt_type = unicode_cpt_type(cpts[i]);
682
+ const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
573
683
 
574
- if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
575
- text_collapsed[i] = k_ucat_cpt.at(cpt_type);
684
+ if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
685
+ text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
576
686
  } else {
577
687
  text_collapsed[i] = (char) 0xD0; // fallback
578
688
  }