llama_cpp 0.15.0 → 0.15.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +6 -7
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +303 -23
- data/vendor/tmp/llama.cpp/ggml-impl.h +84 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +137 -133
- data/vendor/tmp/llama.cpp/ggml-metal.metal +87 -110
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +2220 -28
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1032 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +35 -152
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +953 -268
- data/vendor/tmp/llama.cpp/ggml.c +1762 -681
- data/vendor/tmp/llama.cpp/ggml.h +43 -24
- data/vendor/tmp/llama.cpp/llama.cpp +533 -296
- data/vendor/tmp/llama.cpp/llama.h +10 -1
- data/vendor/tmp/llama.cpp/sgemm.cpp +56 -21
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -1637
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -11
- data/vendor/tmp/llama.cpp/unicode.cpp +286 -176
- data/vendor/tmp/llama.cpp/unicode.h +44 -10
- metadata +4 -2
@@ -1,16 +1,20 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
3
|
#include <cstdint>
|
4
|
-
#include <map>
|
5
|
-
#include <utility>
|
6
4
|
#include <vector>
|
5
|
+
#include <unordered_map>
|
6
|
+
#include <unordered_set>
|
7
7
|
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
// A contiguous run of codepoints [first, last] (both ends inclusive) that all
// share one replacement codepoint, `nfd`. unicode_cpts_normalize_nfd()
// substitutes that codepoint for every member of the run.
struct range_nfd {
    uint32_t first; // first codepoint of the run
    uint32_t last;  // last codepoint of the run (inclusive)
    uint32_t nfd;   // replacement codepoint for every member of the run
};

// One past the largest Unicode codepoint (U+10FFFF). This is also the size of
// the dense per-codepoint flag table.
static constexpr uint32_t MAX_CODEPOINTS = 0x110000;

// Run-length encoded category flags, as (start codepoint, flags) entries in
// ascending order. Each entry's flags apply up to the next entry's start.
// The list is expected to begin at 0 and to end with a terminating entry
// at MAX_CODEPOINTS.
extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
// Codepoints classified as whitespace.
extern const std::unordered_set<uint32_t> unicode_set_whitespace;
// Codepoint -> lowercase counterpart. The mapped values are flagged as lowercase.
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
// Codepoint -> uppercase counterpart. The mapped values are flagged as uppercase.
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
// NFD replacement runs, ordered by `first` (looked up with binary search).
extern const std::vector<range_nfd> unicode_ranges_nfd;
|
@@ -1,4 +1,4 @@
|
|
1
|
-
|
1
|
+
#include "unicode.h"
|
2
2
|
#include "unicode-data.h"
|
3
3
|
|
4
4
|
#include <cassert>
|
@@ -9,6 +9,7 @@
|
|
9
9
|
#include <stdexcept>
|
10
10
|
#include <string>
|
11
11
|
#include <unordered_map>
|
12
|
+
#include <unordered_set>
|
12
13
|
#include <utility>
|
13
14
|
#include <vector>
|
14
15
|
#include <locale>
|
@@ -108,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
|
|
108
109
|
// return result;
|
109
110
|
//}
|
110
111
|
|
111
|
-
static std::
|
112
|
-
std::
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
for (
|
120
|
-
|
121
|
-
}
|
122
|
-
}
|
123
|
-
for (auto p : unicode_ranges_whitespace) {
|
124
|
-
for (auto i = p.first; i <= p.second; ++ i) {
|
125
|
-
cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
|
112
|
+
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
|
113
|
+
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
|
114
|
+
|
115
|
+
assert (unicode_ranges_flags.front().first == 0);
|
116
|
+
assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
|
117
|
+
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
|
118
|
+
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
|
119
|
+
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
|
120
|
+
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
|
121
|
+
cpt_flags[cpt] = range_ini.second;
|
126
122
|
}
|
127
123
|
}
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
}
|
124
|
+
|
125
|
+
for (auto cpt : unicode_set_whitespace) {
|
126
|
+
cpt_flags[cpt].is_whitespace = true;
|
132
127
|
}
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
}
|
128
|
+
|
129
|
+
for (auto p : unicode_map_lowercase) {
|
130
|
+
cpt_flags[p.second].is_lowercase = true;
|
137
131
|
}
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
}
|
132
|
+
|
133
|
+
for (auto p : unicode_map_uppercase) {
|
134
|
+
cpt_flags[p.second].is_uppercase = true;
|
142
135
|
}
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
}
|
136
|
+
|
137
|
+
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
|
138
|
+
cpt_flags[range.nfd].is_nfd = true;
|
147
139
|
}
|
148
|
-
|
140
|
+
|
141
|
+
return cpt_flags;
|
149
142
|
}
|
150
143
|
|
151
144
|
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
152
145
|
std::unordered_map<uint8_t, std::string> map;
|
153
|
-
for (int ch =
|
146
|
+
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
154
147
|
assert(0 <= ch && ch < 256);
|
155
148
|
map[ch] = unicode_cpt_to_utf8(ch);
|
156
149
|
}
|
157
|
-
for (int ch =
|
150
|
+
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
158
151
|
assert(0 <= ch && ch < 256);
|
159
152
|
map[ch] = unicode_cpt_to_utf8(ch);
|
160
153
|
}
|
161
|
-
for (int ch =
|
154
|
+
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
162
155
|
assert(0 <= ch && ch < 256);
|
163
156
|
map[ch] = unicode_cpt_to_utf8(ch);
|
164
157
|
}
|
@@ -174,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
|
174
167
|
|
175
168
|
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
176
169
|
std::unordered_map<std::string, uint8_t> map;
|
177
|
-
for (int ch =
|
170
|
+
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
178
171
|
assert(0 <= ch && ch < 256);
|
179
172
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
180
173
|
}
|
181
|
-
for (int ch =
|
174
|
+
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
182
175
|
assert(0 <= ch && ch < 256);
|
183
176
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
184
177
|
}
|
185
|
-
for (int ch =
|
178
|
+
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
186
179
|
assert(0 <= ch && ch < 256);
|
187
180
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
188
181
|
}
|
@@ -224,138 +217,255 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|
224
217
|
std::vector<size_t> bpe_offsets; // store the offset of each word
|
225
218
|
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
226
219
|
|
227
|
-
size_t start = 0;
|
228
|
-
|
229
220
|
const auto cpts = unicode_cpts_from_utf8(text);
|
230
221
|
|
222
|
+
size_t start = 0;
|
231
223
|
for (auto offset : offsets) {
|
232
|
-
|
224
|
+
const size_t offset_ini = start;
|
225
|
+
const size_t offset_end = start + offset;
|
226
|
+
assert(offset_end <= cpts.size());
|
227
|
+
start = offset_end;
|
228
|
+
|
229
|
+
auto _get_cpt = [&] (const size_t pos) -> char32_t {
|
230
|
+
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
231
|
+
};
|
232
|
+
|
233
|
+
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
234
|
+
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
235
|
+
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
|
236
|
+
};
|
237
|
+
|
238
|
+
size_t _prev_end = offset_ini;
|
239
|
+
auto _add_token = [&] (const size_t end) -> size_t {
|
240
|
+
assert(_prev_end <= end && end <= offset_end);
|
241
|
+
size_t len = end - _prev_end;
|
242
|
+
if (len > 0) {
|
243
|
+
bpe_offsets.push_back(len);
|
244
|
+
}
|
245
|
+
_prev_end = end;
|
246
|
+
//if (len > 0) {
|
247
|
+
// std::string s = "";
|
248
|
+
// for(size_t p = end-len; p < end; p++)
|
249
|
+
// s += unicode_cpt_to_utf8(cpts[p]);
|
250
|
+
// printf(">>> '%s'\n", s.c_str());
|
251
|
+
//}
|
252
|
+
return len;
|
253
|
+
};
|
254
|
+
|
255
|
+
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
256
|
+
const char32_t cpt = _get_cpt(pos);
|
257
|
+
const auto flags = _get_flags(pos);
|
258
|
+
|
259
|
+
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
260
|
+
if (cpt == '\'' && pos+1 < offset_end) {
|
261
|
+
char32_t cpt_next = _get_cpt(pos+1);
|
262
|
+
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
263
|
+
pos += _add_token(pos+2);
|
264
|
+
continue;
|
265
|
+
}
|
266
|
+
if (pos+2 < offset_end) {
|
267
|
+
char32_t cpt_next_next = _get_cpt(pos+2);
|
268
|
+
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
269
|
+
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
270
|
+
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
271
|
+
pos += _add_token(pos+3);
|
272
|
+
continue;
|
273
|
+
}
|
274
|
+
}
|
275
|
+
}
|
276
|
+
|
277
|
+
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
278
|
+
// regex: <space>?\p{L}+
|
279
|
+
if (flags2.is_letter) {
|
280
|
+
pos += (cpt == ' ');
|
281
|
+
while (flags2.is_letter) {
|
282
|
+
flags2 = _get_flags(++pos);
|
283
|
+
}
|
284
|
+
_add_token(pos);
|
285
|
+
continue;
|
286
|
+
}
|
287
|
+
// regex: <space>?\p{N}+
|
288
|
+
if (flags2.is_number) {
|
289
|
+
pos += (cpt == ' ');
|
290
|
+
while (flags2.is_number) {
|
291
|
+
flags2 = _get_flags(++pos);
|
292
|
+
}
|
293
|
+
_add_token(pos);
|
294
|
+
continue;
|
295
|
+
}
|
296
|
+
// regex: <space>?[^\s\p{L}\p{N}]+
|
297
|
+
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
298
|
+
pos += (cpt == ' ');
|
299
|
+
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
300
|
+
flags2 = _get_flags(++pos);
|
301
|
+
}
|
302
|
+
_add_token(pos);
|
303
|
+
continue;
|
304
|
+
}
|
305
|
+
|
306
|
+
size_t num_whitespaces = 0;
|
307
|
+
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
308
|
+
num_whitespaces++;
|
309
|
+
}
|
233
310
|
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
311
|
+
// regex: \s+(?!\S)
|
312
|
+
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
|
313
|
+
pos += num_whitespaces - 1;
|
314
|
+
_add_token(pos);
|
315
|
+
continue;
|
316
|
+
}
|
239
317
|
|
240
|
-
|
241
|
-
|
318
|
+
// regex: \s+
|
319
|
+
if (num_whitespaces > 0) {
|
320
|
+
pos += num_whitespaces;
|
321
|
+
_add_token(pos);
|
322
|
+
continue;
|
323
|
+
}
|
242
324
|
|
243
|
-
|
244
|
-
|
325
|
+
// no matches
|
326
|
+
_add_token(++pos);
|
245
327
|
}
|
328
|
+
}
|
246
329
|
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
330
|
+
return bpe_offsets;
|
331
|
+
}
|
332
|
+
|
333
|
+
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
|
334
|
+
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
|
335
|
+
std::vector<size_t> bpe_offsets; // store the offset of each word
|
336
|
+
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
251
337
|
|
252
|
-
|
253
|
-
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
|
254
|
-
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
|
338
|
+
const auto cpts = unicode_cpts_from_utf8(text);
|
255
339
|
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
340
|
+
size_t start = 0;
|
341
|
+
for (auto offset : offsets) {
|
342
|
+
const size_t offset_ini = start;
|
343
|
+
const size_t offset_end = start + offset;
|
344
|
+
assert(offset_end <= cpts.size());
|
345
|
+
start = offset_end;
|
346
|
+
|
347
|
+
auto _get_cpt = [&] (const size_t pos) -> char32_t {
|
348
|
+
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
349
|
+
};
|
350
|
+
|
351
|
+
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
352
|
+
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
353
|
+
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
|
354
|
+
};
|
355
|
+
|
356
|
+
size_t _prev_end = offset_ini;
|
357
|
+
auto _add_token = [&] (const size_t end) -> size_t {
|
358
|
+
assert(_prev_end <= end && end <= offset_end);
|
359
|
+
size_t len = end - _prev_end;
|
360
|
+
if (len > 0) {
|
361
|
+
bpe_offsets.push_back(len);
|
362
|
+
}
|
363
|
+
_prev_end = end;
|
364
|
+
//if (len > 0) {
|
365
|
+
// std::string s = "";
|
366
|
+
// for(size_t p = end-len; p < end; p++)
|
367
|
+
// s += unicode_cpt_to_utf8(cpts[p]);
|
368
|
+
// printf(">>> '%s'\n", s.c_str());
|
369
|
+
//}
|
370
|
+
return len;
|
371
|
+
};
|
372
|
+
|
373
|
+
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
374
|
+
const char32_t cpt = _get_cpt(pos);
|
375
|
+
const auto flags = _get_flags(pos);
|
376
|
+
|
377
|
+
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
378
|
+
if (cpt == '\'' && pos+1 < offset_end) {
|
379
|
+
char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
|
380
|
+
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
381
|
+
pos += _add_token(pos+2);
|
382
|
+
continue;
|
261
383
|
}
|
262
|
-
if (
|
263
|
-
|
264
|
-
|
384
|
+
if (pos+2 < offset_end) {
|
385
|
+
char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
|
386
|
+
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
387
|
+
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
388
|
+
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
389
|
+
pos += _add_token(pos+3);
|
390
|
+
continue;
|
265
391
|
}
|
266
|
-
token = utf_char + utf_char_next;
|
267
|
-
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
268
|
-
token = "";
|
269
|
-
i++;
|
270
|
-
continue;
|
271
392
|
}
|
272
393
|
}
|
273
|
-
if (!split_condition && bytes_remain >= 3) {
|
274
|
-
// 're|'ve|'ll
|
275
|
-
if (utf_char == "\'" && (
|
276
|
-
(utf_char_next == "r" && utf_char_next_next == "e") ||
|
277
|
-
(utf_char_next == "v" && utf_char_next_next == "e") ||
|
278
|
-
(utf_char_next == "l" && utf_char_next_next == "l"))
|
279
|
-
) {
|
280
|
-
split_condition = true;
|
281
|
-
}
|
282
|
-
if (split_condition) {
|
283
|
-
// current token + next token can be defined
|
284
|
-
if (token.size()) {
|
285
|
-
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
286
|
-
}
|
287
|
-
token = utf_char;
|
288
|
-
token += utf_char_next;
|
289
|
-
token += utf_char_next_next;
|
290
394
|
|
291
|
-
|
292
|
-
|
293
|
-
|
395
|
+
// regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
|
396
|
+
if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
|
397
|
+
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
|
398
|
+
pos++;
|
399
|
+
while (_get_flags(pos).is_letter) {
|
400
|
+
pos++;
|
401
|
+
}
|
402
|
+
_add_token(pos);
|
294
403
|
continue;
|
295
404
|
}
|
296
405
|
}
|
297
406
|
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
}
|
307
|
-
else if (
|
308
|
-
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
|
309
|
-
(token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
|
310
|
-
) {
|
311
|
-
collecting_special = true;
|
312
|
-
collecting = true;
|
313
|
-
}
|
314
|
-
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
|
315
|
-
collecting_whitespace_lookahead = true;
|
316
|
-
collecting = true;
|
317
|
-
}
|
318
|
-
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
|
319
|
-
split_condition = true;
|
407
|
+
// regex: \p{N}{1,3}
|
408
|
+
if (flags.is_number) {
|
409
|
+
size_t ini = pos;
|
410
|
+
while (_get_flags(pos).is_number) {
|
411
|
+
if (++pos - ini >= 3 ) {
|
412
|
+
_add_token(pos);
|
413
|
+
ini = pos;
|
414
|
+
}
|
320
415
|
}
|
416
|
+
_add_token(pos);
|
417
|
+
continue;
|
321
418
|
}
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
419
|
+
|
420
|
+
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
421
|
+
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
422
|
+
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
423
|
+
pos += (cpt == ' ');
|
424
|
+
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
425
|
+
flags2 = _get_flags(++pos);
|
328
426
|
}
|
329
|
-
|
330
|
-
|
427
|
+
char32_t cpt2 = _get_cpt(pos);
|
428
|
+
while (cpt2 == '\r' || cpt2 == '\n') {
|
429
|
+
cpt2 = _get_cpt(++pos);
|
331
430
|
}
|
332
|
-
|
333
|
-
|
431
|
+
_add_token(pos);
|
432
|
+
continue;
|
433
|
+
}
|
434
|
+
|
435
|
+
size_t num_whitespaces = 0;
|
436
|
+
size_t last_end_r_or_n = 0;
|
437
|
+
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
438
|
+
char32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
439
|
+
if (cpt2 == '\r' || cpt2 == '\n') {
|
440
|
+
last_end_r_or_n = pos + num_whitespaces + 1;
|
334
441
|
}
|
442
|
+
num_whitespaces++;
|
335
443
|
}
|
336
444
|
|
337
|
-
|
338
|
-
|
339
|
-
|
445
|
+
// regex: \s*[\r\n]+
|
446
|
+
if (last_end_r_or_n > 0) {
|
447
|
+
pos = last_end_r_or_n;
|
448
|
+
_add_token(pos);
|
449
|
+
continue;
|
340
450
|
}
|
341
451
|
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
collecting = false;
|
348
|
-
collecting_letter = false;
|
349
|
-
collecting_numeric = false;
|
350
|
-
collecting_special = false;
|
351
|
-
collecting_whitespace_lookahead = false;
|
452
|
+
// regex: \s+(?!\S)
|
453
|
+
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
|
454
|
+
pos += num_whitespaces - 1;
|
455
|
+
_add_token(pos);
|
456
|
+
continue;
|
352
457
|
}
|
353
|
-
|
354
|
-
|
458
|
+
|
459
|
+
// regex: \s+
|
460
|
+
if (num_whitespaces > 0) {
|
461
|
+
pos += num_whitespaces;
|
462
|
+
_add_token(pos);
|
463
|
+
continue;
|
355
464
|
}
|
356
|
-
}
|
357
465
|
|
358
|
-
|
466
|
+
// no matches
|
467
|
+
_add_token(++pos);
|
468
|
+
}
|
359
469
|
}
|
360
470
|
|
361
471
|
return bpe_offsets;
|
@@ -424,14 +534,14 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
|
|
424
534
|
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
425
535
|
std::vector<size_t> bpe_offsets;
|
426
536
|
|
427
|
-
(
|
428
|
-
|
429
|
-
(
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
537
|
+
if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
|
538
|
+
bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
|
539
|
+
} else if (
|
540
|
+
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
|
541
|
+
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
|
542
|
+
|
543
|
+
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
|
544
|
+
}
|
435
545
|
|
436
546
|
return bpe_offsets;
|
437
547
|
}
|
@@ -470,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
|
|
470
580
|
}
|
471
581
|
|
472
582
|
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
473
|
-
|
474
|
-
|
583
|
+
auto comp = [] (const uint32_t cpt, const range_nfd & range) {
|
584
|
+
return cpt < range.first;
|
585
|
+
};
|
586
|
+
std::vector<uint32_t> result(cpts.size());
|
475
587
|
for (size_t i = 0; i < cpts.size(); ++i) {
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
} else {
|
480
|
-
result.push_back(it->second);
|
481
|
-
}
|
588
|
+
const uint32_t cpt = cpts[i];
|
589
|
+
auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
|
590
|
+
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
|
482
591
|
}
|
483
592
|
return result;
|
484
593
|
}
|
@@ -492,18 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
|
|
492
601
|
return result;
|
493
602
|
}
|
494
603
|
|
495
|
-
|
496
|
-
static
|
497
|
-
const auto
|
498
|
-
return
|
604
|
+
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
|
605
|
+
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
606
|
+
static const auto cpt_flags = unicode_cpt_flags_array();
|
607
|
+
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
|
499
608
|
}
|
500
609
|
|
501
|
-
|
502
|
-
|
503
|
-
|
610
|
+
// Returns the flags of the first codepoint encoded in `utf8`. Any bytes after
// that first codepoint are ignored. An empty string yields UNDEFINED flags.
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
    if (!utf8.empty()) {
        size_t offset = 0; // decode position; only the leading codepoint is read
        const uint32_t first_cpt = unicode_cpt_from_utf8(utf8, offset);
        return unicode_cpt_flags(first_cpt);
    }
    return codepoint_flags(codepoint_flags::UNDEFINED); // undefined
}
|
508
618
|
|
509
619
|
std::string unicode_byte_to_utf8(uint8_t byte) {
|
@@ -524,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
|
|
524
634
|
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
525
635
|
// unicode categories
|
526
636
|
static const std::map<std::string, int> k_ucat_enum = {
|
527
|
-
{ "\\p{N}",
|
528
|
-
{ "\\p{L}",
|
529
|
-
{ "\\p{P}",
|
637
|
+
{ "\\p{N}", codepoint_flags::NUMBER },
|
638
|
+
{ "\\p{L}", codepoint_flags::LETTER },
|
639
|
+
{ "\\p{P}", codepoint_flags::PUNCTUATION },
|
530
640
|
};
|
531
641
|
|
532
642
|
static const std::map<int, int> k_ucat_cpt = {
|
533
|
-
{
|
534
|
-
{
|
535
|
-
{
|
643
|
+
{ codepoint_flags::NUMBER, 0xD1 },
|
644
|
+
{ codepoint_flags::LETTER, 0xD2 },
|
645
|
+
{ codepoint_flags::PUNCTUATION, 0xD3 },
|
536
646
|
};
|
537
647
|
|
538
648
|
static const std::map<int, std::string> k_ucat_map = {
|
539
|
-
{
|
540
|
-
{
|
541
|
-
{
|
649
|
+
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
|
650
|
+
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
651
|
+
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
542
652
|
};
|
543
653
|
|
544
654
|
// compute collapsed codepoints only if needed by at least one regex
|
@@ -569,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|
569
679
|
continue;
|
570
680
|
}
|
571
681
|
|
572
|
-
const int
|
682
|
+
const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
|
573
683
|
|
574
|
-
if (k_ucat_cpt.find(
|
575
|
-
text_collapsed[i] = k_ucat_cpt.at(
|
684
|
+
if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
|
685
|
+
text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
|
576
686
|
} else {
|
577
687
|
text_collapsed[i] = (char) 0xD0; // fallback
|
578
688
|
}
|