llama_cpp 0.15.0 → 0.15.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,16 +1,20 @@
1
1
  #pragma once
2
2
 
3
3
  #include <cstdint>
4
- #include <map>
5
- #include <utility>
6
4
  #include <vector>
5
+ #include <unordered_map>
6
+ #include <unordered_set>
7
7
 
8
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_digit;
9
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
10
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
11
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
12
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
13
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
14
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
15
- extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
16
- extern const std::map<char32_t, char32_t> unicode_map_lowercase;
8
+ struct range_nfd {
9
+ uint32_t first;
10
+ uint32_t last;
11
+ uint32_t nfd;
12
+ };
13
+
14
+ static const uint32_t MAX_CODEPOINTS = 0x110000;
15
+
16
+ extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
17
+ extern const std::unordered_set<uint32_t> unicode_set_whitespace;
18
+ extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
19
+ extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
20
+ extern const std::vector<range_nfd> unicode_ranges_nfd;
@@ -1,4 +1,4 @@
1
- #include "unicode.h"
1
+ #include "unicode.h"
2
2
  #include "unicode-data.h"
3
3
 
4
4
  #include <cassert>
@@ -9,6 +9,7 @@
9
9
  #include <stdexcept>
10
10
  #include <string>
11
11
  #include <unordered_map>
12
+ #include <unordered_set>
12
13
  #include <utility>
13
14
  #include <vector>
14
15
  #include <locale>
@@ -108,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
108
109
  // return result;
109
110
  //}
110
111
 
111
- static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
112
- std::unordered_map<uint32_t, int> cpt_types;
113
- for (auto p : unicode_ranges_digit) {
114
- for (auto i = p.first; i <= p.second; ++ i) {
115
- cpt_types[i] = CODEPOINT_TYPE_DIGIT;
116
- }
117
- }
118
- for (auto p : unicode_ranges_letter) {
119
- for (auto i = p.first; i <= p.second; ++ i) {
120
- cpt_types[i] = CODEPOINT_TYPE_LETTER;
121
- }
122
- }
123
- for (auto p : unicode_ranges_whitespace) {
124
- for (auto i = p.first; i <= p.second; ++ i) {
125
- cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
112
+ static std::vector<codepoint_flags> unicode_cpt_flags_array() {
113
+ std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
114
+
115
+ assert (unicode_ranges_flags.front().first == 0);
116
+ assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
117
+ for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
118
+ const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
119
+ const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
120
+ for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
121
+ cpt_flags[cpt] = range_ini.second;
126
122
  }
127
123
  }
128
- for (auto p : unicode_ranges_accent_mark) {
129
- for (auto i = p.first; i <= p.second; ++ i) {
130
- cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
131
- }
124
+
125
+ for (auto cpt : unicode_set_whitespace) {
126
+ cpt_flags[cpt].is_whitespace = true;
132
127
  }
133
- for (auto p : unicode_ranges_punctuation) {
134
- for (auto i = p.first; i <= p.second; ++ i) {
135
- cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
136
- }
128
+
129
+ for (auto p : unicode_map_lowercase) {
130
+ cpt_flags[p.second].is_lowercase = true;
137
131
  }
138
- for (auto p : unicode_ranges_symbol) {
139
- for (auto i = p.first; i <= p.second; ++i) {
140
- cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
141
- }
132
+
133
+ for (auto p : unicode_map_uppercase) {
134
+ cpt_flags[p.second].is_uppercase = true;
142
135
  }
143
- for (auto p : unicode_ranges_control) {
144
- for (auto i = p.first; i <= p.second; ++ i) {
145
- cpt_types[i] = CODEPOINT_TYPE_CONTROL;
146
- }
136
+
137
+ for (auto &range : unicode_ranges_nfd) { // start, last, nfd
138
+ cpt_flags[range.nfd].is_nfd = true;
147
139
  }
148
- return cpt_types;
140
+
141
+ return cpt_flags;
149
142
  }
150
143
 
151
144
  static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
152
145
  std::unordered_map<uint8_t, std::string> map;
153
- for (int ch = u'!'; ch <= u'~'; ++ch) {
146
+ for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
154
147
  assert(0 <= ch && ch < 256);
155
148
  map[ch] = unicode_cpt_to_utf8(ch);
156
149
  }
157
- for (int ch = u'¡'; ch <= u'¬'; ++ch) {
150
+ for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
158
151
  assert(0 <= ch && ch < 256);
159
152
  map[ch] = unicode_cpt_to_utf8(ch);
160
153
  }
161
- for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
154
+ for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
162
155
  assert(0 <= ch && ch < 256);
163
156
  map[ch] = unicode_cpt_to_utf8(ch);
164
157
  }
@@ -174,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
174
167
 
175
168
  static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
176
169
  std::unordered_map<std::string, uint8_t> map;
177
- for (int ch = u'!'; ch <= u'~'; ++ch) {
170
+ for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
178
171
  assert(0 <= ch && ch < 256);
179
172
  map[unicode_cpt_to_utf8(ch)] = ch;
180
173
  }
181
- for (int ch = u'¡'; ch <= u'¬'; ++ch) {
174
+ for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
182
175
  assert(0 <= ch && ch < 256);
183
176
  map[unicode_cpt_to_utf8(ch)] = ch;
184
177
  }
185
- for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
178
+ for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
186
179
  assert(0 <= ch && ch < 256);
187
180
  map[unicode_cpt_to_utf8(ch)] = ch;
188
181
  }
@@ -224,138 +217,255 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
224
217
  std::vector<size_t> bpe_offsets; // store the offset of each word
225
218
  bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
226
219
 
227
- size_t start = 0;
228
-
229
220
  const auto cpts = unicode_cpts_from_utf8(text);
230
221
 
222
+ size_t start = 0;
231
223
  for (auto offset : offsets) {
232
- std::string token;
224
+ const size_t offset_ini = start;
225
+ const size_t offset_end = start + offset;
226
+ assert(offset_end <= cpts.size());
227
+ start = offset_end;
228
+
229
+ auto _get_cpt = [&] (const size_t pos) -> char32_t {
230
+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
231
+ };
232
+
233
+ auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
234
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
235
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
236
+ };
237
+
238
+ size_t _prev_end = offset_ini;
239
+ auto _add_token = [&] (const size_t end) -> size_t {
240
+ assert(_prev_end <= end && end <= offset_end);
241
+ size_t len = end - _prev_end;
242
+ if (len > 0) {
243
+ bpe_offsets.push_back(len);
244
+ }
245
+ _prev_end = end;
246
+ //if (len > 0) {
247
+ // std::string s = "";
248
+ // for(size_t p = end-len; p < end; p++)
249
+ // s += unicode_cpt_to_utf8(cpts[p]);
250
+ // printf(">>> '%s'\n", s.c_str());
251
+ //}
252
+ return len;
253
+ };
254
+
255
+ for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
256
+ const char32_t cpt = _get_cpt(pos);
257
+ const auto flags = _get_flags(pos);
258
+
259
+ // regex: 's|'t|'re|'ve|'m|'ll|'d
260
+ if (cpt == '\'' && pos+1 < offset_end) {
261
+ char32_t cpt_next = _get_cpt(pos+1);
262
+ if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
263
+ pos += _add_token(pos+2);
264
+ continue;
265
+ }
266
+ if (pos+2 < offset_end) {
267
+ char32_t cpt_next_next = _get_cpt(pos+2);
268
+ if ((cpt_next == 'r' && cpt_next_next == 'e') ||
269
+ (cpt_next == 'v' && cpt_next_next == 'e') ||
270
+ (cpt_next == 'l' && cpt_next_next == 'l')) {
271
+ pos += _add_token(pos+3);
272
+ continue;
273
+ }
274
+ }
275
+ }
276
+
277
+ auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
278
+ // regex: <space>?\p{L}+
279
+ if (flags2.is_letter) {
280
+ pos += (cpt == ' ');
281
+ while (flags2.is_letter) {
282
+ flags2 = _get_flags(++pos);
283
+ }
284
+ _add_token(pos);
285
+ continue;
286
+ }
287
+ // regex: <space>?\p{N}+
288
+ if (flags2.is_number) {
289
+ pos += (cpt == ' ');
290
+ while (flags2.is_number) {
291
+ flags2 = _get_flags(++pos);
292
+ }
293
+ _add_token(pos);
294
+ continue;
295
+ }
296
+ // regex: <space>?[^\s\p{L}\p{N}]+
297
+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
298
+ pos += (cpt == ' ');
299
+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
300
+ flags2 = _get_flags(++pos);
301
+ }
302
+ _add_token(pos);
303
+ continue;
304
+ }
305
+
306
+ size_t num_whitespaces = 0;
307
+ while (_get_flags(pos+num_whitespaces).is_whitespace) {
308
+ num_whitespaces++;
309
+ }
233
310
 
234
- bool collecting_numeric = false;
235
- bool collecting_letter = false;
236
- bool collecting_special = false;
237
- bool collecting_whitespace_lookahead = false;
238
- bool collecting = false;
311
+ // regex: \s+(?!\S)
312
+ if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
313
+ pos += num_whitespaces - 1;
314
+ _add_token(pos);
315
+ continue;
316
+ }
239
317
 
240
- std::vector<std::string> text_utf;
241
- text_utf.reserve(offset);
318
+ // regex: \s+
319
+ if (num_whitespaces > 0) {
320
+ pos += num_whitespaces;
321
+ _add_token(pos);
322
+ continue;
323
+ }
242
324
 
243
- for (size_t i = start; i < start + offset; ++i) {
244
- text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
325
+ // no matches
326
+ _add_token(++pos);
245
327
  }
328
+ }
246
329
 
247
- for (int i = 0; i < (int)text_utf.size(); i++) {
248
- const std::string & utf_char = text_utf[i];
249
- bool split_condition = false;
250
- int bytes_remain = text_utf.size() - i;
330
+ return bpe_offsets;
331
+ }
332
+
333
+ // LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
334
+ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
335
+ std::vector<size_t> bpe_offsets; // store the offset of each word
336
+ bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
251
337
 
252
- // forward backward lookups
253
- const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
254
- const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
338
+ const auto cpts = unicode_cpts_from_utf8(text);
255
339
 
256
- // handling contractions
257
- if (!split_condition && bytes_remain >= 2) {
258
- // 's|'t|'m|'d
259
- if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
260
- split_condition = true;
340
+ size_t start = 0;
341
+ for (auto offset : offsets) {
342
+ const size_t offset_ini = start;
343
+ const size_t offset_end = start + offset;
344
+ assert(offset_end <= cpts.size());
345
+ start = offset_end;
346
+
347
+ auto _get_cpt = [&] (const size_t pos) -> char32_t {
348
+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
349
+ };
350
+
351
+ auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
352
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
353
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
354
+ };
355
+
356
+ size_t _prev_end = offset_ini;
357
+ auto _add_token = [&] (const size_t end) -> size_t {
358
+ assert(_prev_end <= end && end <= offset_end);
359
+ size_t len = end - _prev_end;
360
+ if (len > 0) {
361
+ bpe_offsets.push_back(len);
362
+ }
363
+ _prev_end = end;
364
+ //if (len > 0) {
365
+ // std::string s = "";
366
+ // for(size_t p = end-len; p < end; p++)
367
+ // s += unicode_cpt_to_utf8(cpts[p]);
368
+ // printf(">>> '%s'\n", s.c_str());
369
+ //}
370
+ return len;
371
+ };
372
+
373
+ for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
374
+ const char32_t cpt = _get_cpt(pos);
375
+ const auto flags = _get_flags(pos);
376
+
377
+ // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
378
+ if (cpt == '\'' && pos+1 < offset_end) {
379
+ char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
380
+ if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
381
+ pos += _add_token(pos+2);
382
+ continue;
261
383
  }
262
- if (split_condition) {
263
- if (token.size()) {
264
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
384
+ if (pos+2 < offset_end) {
385
+ char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
386
+ if ((cpt_next == 'r' && cpt_next_next == 'e') ||
387
+ (cpt_next == 'v' && cpt_next_next == 'e') ||
388
+ (cpt_next == 'l' && cpt_next_next == 'l')) {
389
+ pos += _add_token(pos+3);
390
+ continue;
265
391
  }
266
- token = utf_char + utf_char_next;
267
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
268
- token = "";
269
- i++;
270
- continue;
271
392
  }
272
393
  }
273
- if (!split_condition && bytes_remain >= 3) {
274
- // 're|'ve|'ll
275
- if (utf_char == "\'" && (
276
- (utf_char_next == "r" && utf_char_next_next == "e") ||
277
- (utf_char_next == "v" && utf_char_next_next == "e") ||
278
- (utf_char_next == "l" && utf_char_next_next == "l"))
279
- ) {
280
- split_condition = true;
281
- }
282
- if (split_condition) {
283
- // current token + next token can be defined
284
- if (token.size()) {
285
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
286
- }
287
- token = utf_char;
288
- token += utf_char_next;
289
- token += utf_char_next_next;
290
394
 
291
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
292
- token = "";
293
- i += 2;
395
+ // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
396
+ if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
397
+ if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
398
+ pos++;
399
+ while (_get_flags(pos).is_letter) {
400
+ pos++;
401
+ }
402
+ _add_token(pos);
294
403
  continue;
295
404
  }
296
405
  }
297
406
 
298
- if (!split_condition && !collecting) {
299
- if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
300
- collecting_letter = true;
301
- collecting = true;
302
- }
303
- else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
304
- collecting_numeric = true;
305
- collecting = true;
306
- }
307
- else if (
308
- ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
309
- (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
310
- ) {
311
- collecting_special = true;
312
- collecting = true;
313
- }
314
- else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
315
- collecting_whitespace_lookahead = true;
316
- collecting = true;
317
- }
318
- else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
319
- split_condition = true;
407
+ // regex: \p{N}{1,3}
408
+ if (flags.is_number) {
409
+ size_t ini = pos;
410
+ while (_get_flags(pos).is_number) {
411
+ if (++pos - ini >= 3 ) {
412
+ _add_token(pos);
413
+ ini = pos;
414
+ }
320
415
  }
416
+ _add_token(pos);
417
+ continue;
321
418
  }
322
- else if (!split_condition && collecting) {
323
- if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
324
- split_condition = true;
325
- }
326
- else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
327
- split_condition = true;
419
+
420
+ // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
421
+ auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
422
+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
423
+ pos += (cpt == ' ');
424
+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
425
+ flags2 = _get_flags(++pos);
328
426
  }
329
- else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
330
- split_condition = true;
427
+ char32_t cpt2 = _get_cpt(pos);
428
+ while (cpt2 == '\r' || cpt2 == '\n') {
429
+ cpt2 = _get_cpt(++pos);
331
430
  }
332
- else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
333
- split_condition = true;
431
+ _add_token(pos);
432
+ continue;
433
+ }
434
+
435
+ size_t num_whitespaces = 0;
436
+ size_t last_end_r_or_n = 0;
437
+ while (_get_flags(pos+num_whitespaces).is_whitespace) {
438
+ char32_t cpt2 = _get_cpt(pos+num_whitespaces);
439
+ if (cpt2 == '\r' || cpt2 == '\n') {
440
+ last_end_r_or_n = pos + num_whitespaces + 1;
334
441
  }
442
+ num_whitespaces++;
335
443
  }
336
444
 
337
- if (utf_char_next == "") {
338
- split_condition = true; // final
339
- token += utf_char;
445
+ // regex: \s*[\r\n]+
446
+ if (last_end_r_or_n > 0) {
447
+ pos = last_end_r_or_n;
448
+ _add_token(pos);
449
+ continue;
340
450
  }
341
451
 
342
- if (split_condition) {
343
- if (token.size()) {
344
- bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
345
- }
346
- token = utf_char;
347
- collecting = false;
348
- collecting_letter = false;
349
- collecting_numeric = false;
350
- collecting_special = false;
351
- collecting_whitespace_lookahead = false;
452
+ // regex: \s+(?!\S)
453
+ if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
454
+ pos += num_whitespaces - 1;
455
+ _add_token(pos);
456
+ continue;
352
457
  }
353
- else {
354
- token += utf_char;
458
+
459
+ // regex: \s+
460
+ if (num_whitespaces > 0) {
461
+ pos += num_whitespaces;
462
+ _add_token(pos);
463
+ continue;
355
464
  }
356
- }
357
465
 
358
- start += offset;
466
+ // no matches
467
+ _add_token(++pos);
468
+ }
359
469
  }
360
470
 
361
471
  return bpe_offsets;
@@ -424,14 +534,14 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
424
534
  static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
425
535
  std::vector<size_t> bpe_offsets;
426
536
 
427
- (void)(text);
428
- (void)(regex_expr);
429
- (void)(offsets);
430
- // TODO: this implementation is actually wrong, uncomment and run:
431
- // make -j && ./bin/test-tokenizer-0 ../models/ggml-vocab-gpt-2.gguf
432
- //if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
433
- // bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
434
- //}
537
+ if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
538
+ bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
539
+ } else if (
540
+ regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
541
+ regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
542
+
543
+ bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
544
+ }
435
545
 
436
546
  return bpe_offsets;
437
547
  }
@@ -470,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
470
580
  }
471
581
 
472
582
  std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
473
- std::vector<uint32_t> result;
474
- result.reserve(cpts.size());
583
+ auto comp = [] (const uint32_t cpt, const range_nfd & range) {
584
+ return cpt < range.first;
585
+ };
586
+ std::vector<uint32_t> result(cpts.size());
475
587
  for (size_t i = 0; i < cpts.size(); ++i) {
476
- auto it = unicode_map_nfd.find(cpts[i]);
477
- if (it == unicode_map_nfd.end()) {
478
- result.push_back(cpts[i]);
479
- } else {
480
- result.push_back(it->second);
481
- }
588
+ const uint32_t cpt = cpts[i];
589
+ auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
590
+ result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
482
591
  }
483
592
  return result;
484
593
  }
@@ -492,18 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
492
601
  return result;
493
602
  }
494
603
 
495
- int unicode_cpt_type(uint32_t cp) {
496
- static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map();
497
- const auto it = cpt_types.find(cp);
498
- return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second;
604
+ codepoint_flags unicode_cpt_flags(const uint32_t cp) {
605
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
606
+ static const auto cpt_flags = unicode_cpt_flags_array();
607
+ return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
499
608
  }
500
609
 
501
- int unicode_cpt_type(const std::string & utf8) {
502
- if (utf8.length() == 0) {
503
- return CODEPOINT_TYPE_UNIDENTIFIED;
610
+ codepoint_flags unicode_cpt_flags(const std::string & utf8) {
611
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
612
+ if (utf8.empty()) {
613
+ return undef; // undefined
504
614
  }
505
615
  size_t offset = 0;
506
- return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
616
+ return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
507
617
  }
508
618
 
509
619
  std::string unicode_byte_to_utf8(uint8_t byte) {
@@ -524,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
524
634
  std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
525
635
  // unicode categories
526
636
  static const std::map<std::string, int> k_ucat_enum = {
527
- { "\\p{N}", CODEPOINT_TYPE_DIGIT },
528
- { "\\p{L}", CODEPOINT_TYPE_LETTER },
529
- { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
637
+ { "\\p{N}", codepoint_flags::NUMBER },
638
+ { "\\p{L}", codepoint_flags::LETTER },
639
+ { "\\p{P}", codepoint_flags::PUNCTUATION },
530
640
  };
531
641
 
532
642
  static const std::map<int, int> k_ucat_cpt = {
533
- { CODEPOINT_TYPE_DIGIT, 0xD1 },
534
- { CODEPOINT_TYPE_LETTER, 0xD2 },
535
- { CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
643
+ { codepoint_flags::NUMBER, 0xD1 },
644
+ { codepoint_flags::LETTER, 0xD2 },
645
+ { codepoint_flags::PUNCTUATION, 0xD3 },
536
646
  };
537
647
 
538
648
  static const std::map<int, std::string> k_ucat_map = {
539
- { CODEPOINT_TYPE_DIGIT, "\x30-\x39" }, // 0-9
540
- { CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
541
- { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
649
+ { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
650
+ { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
651
+ { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
542
652
  };
543
653
 
544
654
  // compute collapsed codepoints only if needed by at least one regex
@@ -569,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
569
679
  continue;
570
680
  }
571
681
 
572
- const int cpt_type = unicode_cpt_type(cpts[i]);
682
+ const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
573
683
 
574
- if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
575
- text_collapsed[i] = k_ucat_cpt.at(cpt_type);
684
+ if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
685
+ text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
576
686
  } else {
577
687
  text_collapsed[i] = (char) 0xD0; // fallback
578
688
  }