llama_cpp 0.15.0 → 0.15.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +6 -7
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +303 -23
- data/vendor/tmp/llama.cpp/ggml-impl.h +84 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +137 -133
- data/vendor/tmp/llama.cpp/ggml-metal.metal +87 -110
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +2220 -28
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1032 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +35 -152
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +953 -268
- data/vendor/tmp/llama.cpp/ggml.c +1762 -681
- data/vendor/tmp/llama.cpp/ggml.h +43 -24
- data/vendor/tmp/llama.cpp/llama.cpp +533 -296
- data/vendor/tmp/llama.cpp/llama.h +10 -1
- data/vendor/tmp/llama.cpp/sgemm.cpp +56 -21
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -1637
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -11
- data/vendor/tmp/llama.cpp/unicode.cpp +286 -176
- data/vendor/tmp/llama.cpp/unicode.h +44 -10
- metadata +4 -2
@@ -1,16 +1,20 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
3
|
#include <cstdint>
|
4
|
-
#include <map>
|
5
|
-
#include <utility>
|
6
4
|
#include <vector>
|
5
|
+
#include <unordered_map>
|
6
|
+
#include <unordered_set>
|
7
7
|
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
extern const std::
|
8
|
+
struct range_nfd {
|
9
|
+
uint32_t first;
|
10
|
+
uint32_t last;
|
11
|
+
uint32_t nfd;
|
12
|
+
};
|
13
|
+
|
14
|
+
static const uint32_t MAX_CODEPOINTS = 0x110000;
|
15
|
+
|
16
|
+
extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
|
17
|
+
extern const std::unordered_set<uint32_t> unicode_set_whitespace;
|
18
|
+
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
|
19
|
+
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
|
20
|
+
extern const std::vector<range_nfd> unicode_ranges_nfd;
|
@@ -1,4 +1,4 @@
|
|
1
|
-
|
1
|
+
#include "unicode.h"
|
2
2
|
#include "unicode-data.h"
|
3
3
|
|
4
4
|
#include <cassert>
|
@@ -9,6 +9,7 @@
|
|
9
9
|
#include <stdexcept>
|
10
10
|
#include <string>
|
11
11
|
#include <unordered_map>
|
12
|
+
#include <unordered_set>
|
12
13
|
#include <utility>
|
13
14
|
#include <vector>
|
14
15
|
#include <locale>
|
@@ -108,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
|
|
108
109
|
// return result;
|
109
110
|
//}
|
110
111
|
|
111
|
-
static std::
|
112
|
-
std::
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
for (
|
120
|
-
|
121
|
-
}
|
122
|
-
}
|
123
|
-
for (auto p : unicode_ranges_whitespace) {
|
124
|
-
for (auto i = p.first; i <= p.second; ++ i) {
|
125
|
-
cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
|
112
|
+
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
|
113
|
+
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
|
114
|
+
|
115
|
+
assert (unicode_ranges_flags.front().first == 0);
|
116
|
+
assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
|
117
|
+
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
|
118
|
+
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
|
119
|
+
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
|
120
|
+
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
|
121
|
+
cpt_flags[cpt] = range_ini.second;
|
126
122
|
}
|
127
123
|
}
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
}
|
124
|
+
|
125
|
+
for (auto cpt : unicode_set_whitespace) {
|
126
|
+
cpt_flags[cpt].is_whitespace = true;
|
132
127
|
}
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
}
|
128
|
+
|
129
|
+
for (auto p : unicode_map_lowercase) {
|
130
|
+
cpt_flags[p.second].is_lowercase = true;
|
137
131
|
}
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
}
|
132
|
+
|
133
|
+
for (auto p : unicode_map_uppercase) {
|
134
|
+
cpt_flags[p.second].is_uppercase = true;
|
142
135
|
}
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
}
|
136
|
+
|
137
|
+
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
|
138
|
+
cpt_flags[range.nfd].is_nfd = true;
|
147
139
|
}
|
148
|
-
|
140
|
+
|
141
|
+
return cpt_flags;
|
149
142
|
}
|
150
143
|
|
151
144
|
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
152
145
|
std::unordered_map<uint8_t, std::string> map;
|
153
|
-
for (int ch =
|
146
|
+
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
154
147
|
assert(0 <= ch && ch < 256);
|
155
148
|
map[ch] = unicode_cpt_to_utf8(ch);
|
156
149
|
}
|
157
|
-
for (int ch =
|
150
|
+
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
158
151
|
assert(0 <= ch && ch < 256);
|
159
152
|
map[ch] = unicode_cpt_to_utf8(ch);
|
160
153
|
}
|
161
|
-
for (int ch =
|
154
|
+
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
162
155
|
assert(0 <= ch && ch < 256);
|
163
156
|
map[ch] = unicode_cpt_to_utf8(ch);
|
164
157
|
}
|
@@ -174,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
|
174
167
|
|
175
168
|
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
176
169
|
std::unordered_map<std::string, uint8_t> map;
|
177
|
-
for (int ch =
|
170
|
+
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
178
171
|
assert(0 <= ch && ch < 256);
|
179
172
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
180
173
|
}
|
181
|
-
for (int ch =
|
174
|
+
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
182
175
|
assert(0 <= ch && ch < 256);
|
183
176
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
184
177
|
}
|
185
|
-
for (int ch =
|
178
|
+
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
186
179
|
assert(0 <= ch && ch < 256);
|
187
180
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
188
181
|
}
|
@@ -224,138 +217,255 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|
224
217
|
std::vector<size_t> bpe_offsets; // store the offset of each word
|
225
218
|
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
226
219
|
|
227
|
-
size_t start = 0;
|
228
|
-
|
229
220
|
const auto cpts = unicode_cpts_from_utf8(text);
|
230
221
|
|
222
|
+
size_t start = 0;
|
231
223
|
for (auto offset : offsets) {
|
232
|
-
|
224
|
+
const size_t offset_ini = start;
|
225
|
+
const size_t offset_end = start + offset;
|
226
|
+
assert(offset_end <= cpts.size());
|
227
|
+
start = offset_end;
|
228
|
+
|
229
|
+
auto _get_cpt = [&] (const size_t pos) -> char32_t {
|
230
|
+
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
231
|
+
};
|
232
|
+
|
233
|
+
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
234
|
+
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
235
|
+
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
|
236
|
+
};
|
237
|
+
|
238
|
+
size_t _prev_end = offset_ini;
|
239
|
+
auto _add_token = [&] (const size_t end) -> size_t {
|
240
|
+
assert(_prev_end <= end && end <= offset_end);
|
241
|
+
size_t len = end - _prev_end;
|
242
|
+
if (len > 0) {
|
243
|
+
bpe_offsets.push_back(len);
|
244
|
+
}
|
245
|
+
_prev_end = end;
|
246
|
+
//if (len > 0) {
|
247
|
+
// std::string s = "";
|
248
|
+
// for(size_t p = end-len; p < end; p++)
|
249
|
+
// s += unicode_cpt_to_utf8(cpts[p]);
|
250
|
+
// printf(">>> '%s'\n", s.c_str());
|
251
|
+
//}
|
252
|
+
return len;
|
253
|
+
};
|
254
|
+
|
255
|
+
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
256
|
+
const char32_t cpt = _get_cpt(pos);
|
257
|
+
const auto flags = _get_flags(pos);
|
258
|
+
|
259
|
+
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
260
|
+
if (cpt == '\'' && pos+1 < offset_end) {
|
261
|
+
char32_t cpt_next = _get_cpt(pos+1);
|
262
|
+
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
263
|
+
pos += _add_token(pos+2);
|
264
|
+
continue;
|
265
|
+
}
|
266
|
+
if (pos+2 < offset_end) {
|
267
|
+
char32_t cpt_next_next = _get_cpt(pos+2);
|
268
|
+
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
269
|
+
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
270
|
+
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
271
|
+
pos += _add_token(pos+3);
|
272
|
+
continue;
|
273
|
+
}
|
274
|
+
}
|
275
|
+
}
|
276
|
+
|
277
|
+
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
278
|
+
// regex: <space>?\p{L}+
|
279
|
+
if (flags2.is_letter) {
|
280
|
+
pos += (cpt == ' ');
|
281
|
+
while (flags2.is_letter) {
|
282
|
+
flags2 = _get_flags(++pos);
|
283
|
+
}
|
284
|
+
_add_token(pos);
|
285
|
+
continue;
|
286
|
+
}
|
287
|
+
// regex: <space>?\p{N}+
|
288
|
+
if (flags2.is_number) {
|
289
|
+
pos += (cpt == ' ');
|
290
|
+
while (flags2.is_number) {
|
291
|
+
flags2 = _get_flags(++pos);
|
292
|
+
}
|
293
|
+
_add_token(pos);
|
294
|
+
continue;
|
295
|
+
}
|
296
|
+
// regex: <space>?[^\s\p{L}\p{N}]+
|
297
|
+
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
298
|
+
pos += (cpt == ' ');
|
299
|
+
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
300
|
+
flags2 = _get_flags(++pos);
|
301
|
+
}
|
302
|
+
_add_token(pos);
|
303
|
+
continue;
|
304
|
+
}
|
305
|
+
|
306
|
+
size_t num_whitespaces = 0;
|
307
|
+
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
308
|
+
num_whitespaces++;
|
309
|
+
}
|
233
310
|
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
311
|
+
// regex: \s+(?!\S)
|
312
|
+
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
|
313
|
+
pos += num_whitespaces - 1;
|
314
|
+
_add_token(pos);
|
315
|
+
continue;
|
316
|
+
}
|
239
317
|
|
240
|
-
|
241
|
-
|
318
|
+
// regex: \s+
|
319
|
+
if (num_whitespaces > 0) {
|
320
|
+
pos += num_whitespaces;
|
321
|
+
_add_token(pos);
|
322
|
+
continue;
|
323
|
+
}
|
242
324
|
|
243
|
-
|
244
|
-
|
325
|
+
// no matches
|
326
|
+
_add_token(++pos);
|
245
327
|
}
|
328
|
+
}
|
246
329
|
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
330
|
+
return bpe_offsets;
|
331
|
+
}
|
332
|
+
|
333
|
+
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
|
334
|
+
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
|
335
|
+
std::vector<size_t> bpe_offsets; // store the offset of each word
|
336
|
+
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
251
337
|
|
252
|
-
|
253
|
-
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
|
254
|
-
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
|
338
|
+
const auto cpts = unicode_cpts_from_utf8(text);
|
255
339
|
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
340
|
+
size_t start = 0;
|
341
|
+
for (auto offset : offsets) {
|
342
|
+
const size_t offset_ini = start;
|
343
|
+
const size_t offset_end = start + offset;
|
344
|
+
assert(offset_end <= cpts.size());
|
345
|
+
start = offset_end;
|
346
|
+
|
347
|
+
auto _get_cpt = [&] (const size_t pos) -> char32_t {
|
348
|
+
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
349
|
+
};
|
350
|
+
|
351
|
+
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
352
|
+
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
353
|
+
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
|
354
|
+
};
|
355
|
+
|
356
|
+
size_t _prev_end = offset_ini;
|
357
|
+
auto _add_token = [&] (const size_t end) -> size_t {
|
358
|
+
assert(_prev_end <= end && end <= offset_end);
|
359
|
+
size_t len = end - _prev_end;
|
360
|
+
if (len > 0) {
|
361
|
+
bpe_offsets.push_back(len);
|
362
|
+
}
|
363
|
+
_prev_end = end;
|
364
|
+
//if (len > 0) {
|
365
|
+
// std::string s = "";
|
366
|
+
// for(size_t p = end-len; p < end; p++)
|
367
|
+
// s += unicode_cpt_to_utf8(cpts[p]);
|
368
|
+
// printf(">>> '%s'\n", s.c_str());
|
369
|
+
//}
|
370
|
+
return len;
|
371
|
+
};
|
372
|
+
|
373
|
+
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
374
|
+
const char32_t cpt = _get_cpt(pos);
|
375
|
+
const auto flags = _get_flags(pos);
|
376
|
+
|
377
|
+
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
378
|
+
if (cpt == '\'' && pos+1 < offset_end) {
|
379
|
+
char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
|
380
|
+
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
381
|
+
pos += _add_token(pos+2);
|
382
|
+
continue;
|
261
383
|
}
|
262
|
-
if (
|
263
|
-
|
264
|
-
|
384
|
+
if (pos+2 < offset_end) {
|
385
|
+
char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
|
386
|
+
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
387
|
+
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
388
|
+
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
389
|
+
pos += _add_token(pos+3);
|
390
|
+
continue;
|
265
391
|
}
|
266
|
-
token = utf_char + utf_char_next;
|
267
|
-
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
268
|
-
token = "";
|
269
|
-
i++;
|
270
|
-
continue;
|
271
392
|
}
|
272
393
|
}
|
273
|
-
if (!split_condition && bytes_remain >= 3) {
|
274
|
-
// 're|'ve|'ll
|
275
|
-
if (utf_char == "\'" && (
|
276
|
-
(utf_char_next == "r" && utf_char_next_next == "e") ||
|
277
|
-
(utf_char_next == "v" && utf_char_next_next == "e") ||
|
278
|
-
(utf_char_next == "l" && utf_char_next_next == "l"))
|
279
|
-
) {
|
280
|
-
split_condition = true;
|
281
|
-
}
|
282
|
-
if (split_condition) {
|
283
|
-
// current token + next token can be defined
|
284
|
-
if (token.size()) {
|
285
|
-
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
286
|
-
}
|
287
|
-
token = utf_char;
|
288
|
-
token += utf_char_next;
|
289
|
-
token += utf_char_next_next;
|
290
394
|
|
291
|
-
|
292
|
-
|
293
|
-
|
395
|
+
// regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
|
396
|
+
if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
|
397
|
+
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
|
398
|
+
pos++;
|
399
|
+
while (_get_flags(pos).is_letter) {
|
400
|
+
pos++;
|
401
|
+
}
|
402
|
+
_add_token(pos);
|
294
403
|
continue;
|
295
404
|
}
|
296
405
|
}
|
297
406
|
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
}
|
307
|
-
else if (
|
308
|
-
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
|
309
|
-
(token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
|
310
|
-
) {
|
311
|
-
collecting_special = true;
|
312
|
-
collecting = true;
|
313
|
-
}
|
314
|
-
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
|
315
|
-
collecting_whitespace_lookahead = true;
|
316
|
-
collecting = true;
|
317
|
-
}
|
318
|
-
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
|
319
|
-
split_condition = true;
|
407
|
+
// regex: \p{N}{1,3}
|
408
|
+
if (flags.is_number) {
|
409
|
+
size_t ini = pos;
|
410
|
+
while (_get_flags(pos).is_number) {
|
411
|
+
if (++pos - ini >= 3 ) {
|
412
|
+
_add_token(pos);
|
413
|
+
ini = pos;
|
414
|
+
}
|
320
415
|
}
|
416
|
+
_add_token(pos);
|
417
|
+
continue;
|
321
418
|
}
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
419
|
+
|
420
|
+
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
421
|
+
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
422
|
+
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
423
|
+
pos += (cpt == ' ');
|
424
|
+
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
425
|
+
flags2 = _get_flags(++pos);
|
328
426
|
}
|
329
|
-
|
330
|
-
|
427
|
+
char32_t cpt2 = _get_cpt(pos);
|
428
|
+
while (cpt2 == '\r' || cpt2 == '\n') {
|
429
|
+
cpt2 = _get_cpt(++pos);
|
331
430
|
}
|
332
|
-
|
333
|
-
|
431
|
+
_add_token(pos);
|
432
|
+
continue;
|
433
|
+
}
|
434
|
+
|
435
|
+
size_t num_whitespaces = 0;
|
436
|
+
size_t last_end_r_or_n = 0;
|
437
|
+
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
438
|
+
char32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
439
|
+
if (cpt2 == '\r' || cpt2 == '\n') {
|
440
|
+
last_end_r_or_n = pos + num_whitespaces + 1;
|
334
441
|
}
|
442
|
+
num_whitespaces++;
|
335
443
|
}
|
336
444
|
|
337
|
-
|
338
|
-
|
339
|
-
|
445
|
+
// regex: \s*[\r\n]+
|
446
|
+
if (last_end_r_or_n > 0) {
|
447
|
+
pos = last_end_r_or_n;
|
448
|
+
_add_token(pos);
|
449
|
+
continue;
|
340
450
|
}
|
341
451
|
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
collecting = false;
|
348
|
-
collecting_letter = false;
|
349
|
-
collecting_numeric = false;
|
350
|
-
collecting_special = false;
|
351
|
-
collecting_whitespace_lookahead = false;
|
452
|
+
// regex: \s+(?!\S)
|
453
|
+
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
|
454
|
+
pos += num_whitespaces - 1;
|
455
|
+
_add_token(pos);
|
456
|
+
continue;
|
352
457
|
}
|
353
|
-
|
354
|
-
|
458
|
+
|
459
|
+
// regex: \s+
|
460
|
+
if (num_whitespaces > 0) {
|
461
|
+
pos += num_whitespaces;
|
462
|
+
_add_token(pos);
|
463
|
+
continue;
|
355
464
|
}
|
356
|
-
}
|
357
465
|
|
358
|
-
|
466
|
+
// no matches
|
467
|
+
_add_token(++pos);
|
468
|
+
}
|
359
469
|
}
|
360
470
|
|
361
471
|
return bpe_offsets;
|
@@ -424,14 +534,14 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
|
|
424
534
|
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
425
535
|
std::vector<size_t> bpe_offsets;
|
426
536
|
|
427
|
-
(
|
428
|
-
|
429
|
-
(
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
537
|
+
if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
|
538
|
+
bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
|
539
|
+
} else if (
|
540
|
+
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
|
541
|
+
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
|
542
|
+
|
543
|
+
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
|
544
|
+
}
|
435
545
|
|
436
546
|
return bpe_offsets;
|
437
547
|
}
|
@@ -470,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
|
|
470
580
|
}
|
471
581
|
|
472
582
|
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
473
|
-
|
474
|
-
|
583
|
+
auto comp = [] (const uint32_t cpt, const range_nfd & range) {
|
584
|
+
return cpt < range.first;
|
585
|
+
};
|
586
|
+
std::vector<uint32_t> result(cpts.size());
|
475
587
|
for (size_t i = 0; i < cpts.size(); ++i) {
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
} else {
|
480
|
-
result.push_back(it->second);
|
481
|
-
}
|
588
|
+
const uint32_t cpt = cpts[i];
|
589
|
+
auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
|
590
|
+
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
|
482
591
|
}
|
483
592
|
return result;
|
484
593
|
}
|
@@ -492,18 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
|
|
492
601
|
return result;
|
493
602
|
}
|
494
603
|
|
495
|
-
|
496
|
-
static
|
497
|
-
const auto
|
498
|
-
return
|
604
|
+
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
|
605
|
+
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
606
|
+
static const auto cpt_flags = unicode_cpt_flags_array();
|
607
|
+
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
|
499
608
|
}
|
500
609
|
|
501
|
-
|
502
|
-
|
503
|
-
|
610
|
+
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
|
611
|
+
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
612
|
+
if (utf8.empty()) {
|
613
|
+
return undef; // undefined
|
504
614
|
}
|
505
615
|
size_t offset = 0;
|
506
|
-
return
|
616
|
+
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
|
507
617
|
}
|
508
618
|
|
509
619
|
std::string unicode_byte_to_utf8(uint8_t byte) {
|
@@ -524,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
|
|
524
634
|
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
525
635
|
// unicode categories
|
526
636
|
static const std::map<std::string, int> k_ucat_enum = {
|
527
|
-
{ "\\p{N}",
|
528
|
-
{ "\\p{L}",
|
529
|
-
{ "\\p{P}",
|
637
|
+
{ "\\p{N}", codepoint_flags::NUMBER },
|
638
|
+
{ "\\p{L}", codepoint_flags::LETTER },
|
639
|
+
{ "\\p{P}", codepoint_flags::PUNCTUATION },
|
530
640
|
};
|
531
641
|
|
532
642
|
static const std::map<int, int> k_ucat_cpt = {
|
533
|
-
{
|
534
|
-
{
|
535
|
-
{
|
643
|
+
{ codepoint_flags::NUMBER, 0xD1 },
|
644
|
+
{ codepoint_flags::LETTER, 0xD2 },
|
645
|
+
{ codepoint_flags::PUNCTUATION, 0xD3 },
|
536
646
|
};
|
537
647
|
|
538
648
|
static const std::map<int, std::string> k_ucat_map = {
|
539
|
-
{
|
540
|
-
{
|
541
|
-
{
|
649
|
+
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
|
650
|
+
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
651
|
+
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
542
652
|
};
|
543
653
|
|
544
654
|
// compute collapsed codepoints only if needed by at least one regex
|
@@ -569,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|
569
679
|
continue;
|
570
680
|
}
|
571
681
|
|
572
|
-
const int
|
682
|
+
const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
|
573
683
|
|
574
|
-
if (k_ucat_cpt.find(
|
575
|
-
text_collapsed[i] = k_ucat_cpt.at(
|
684
|
+
if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
|
685
|
+
text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
|
576
686
|
} else {
|
577
687
|
text_collapsed[i] = (char) 0xD0; // fallback
|
578
688
|
}
|