cui-llama.rn 1.1.2 → 1.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +1 -2
- package/android/src/main/jni.cpp +26 -21
- package/cpp/common.cpp +181 -1584
- package/cpp/common.h +131 -52
- package/cpp/ggml-aarch64.c +612 -0
- package/cpp/ggml-alloc.h +2 -2
- package/cpp/ggml-backend.c +33 -6
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-common.h +20 -0
- package/cpp/ggml-impl.h +36 -7
- package/cpp/ggml-metal.m +68 -8
- package/cpp/ggml-quants.c +932 -50
- package/cpp/ggml-quants.h +15 -0
- package/cpp/ggml.c +1712 -325
- package/cpp/ggml.h +169 -100
- package/cpp/llama-grammar.cpp +721 -122
- package/cpp/llama-grammar.h +120 -15
- package/cpp/llama-impl.h +132 -1
- package/cpp/llama-sampling.cpp +1483 -354
- package/cpp/llama-sampling.h +20 -48
- package/cpp/llama-vocab.cpp +140 -7
- package/cpp/llama-vocab.h +3 -2
- package/cpp/llama.cpp +824 -327
- package/cpp/llama.h +235 -256
- package/cpp/rn-llama.hpp +18 -14
- package/cpp/sampling.cpp +353 -354
- package/cpp/sampling.h +62 -143
- package/cpp/sgemm.cpp +153 -0
- package/package.json +1 -1
- package/cpp/grammar-parser.cpp +0 -539
- package/cpp/grammar-parser.h +0 -29
package/cpp/llama-grammar.cpp
CHANGED
@@ -3,11 +3,31 @@
|
|
3
3
|
#include "llama-vocab.h"
|
4
4
|
#include "llama-sampling.h"
|
5
5
|
|
6
|
+
#include <cmath>
|
6
7
|
#include <algorithm>
|
8
|
+
#include <stdexcept>
|
7
9
|
|
8
|
-
//
|
9
|
-
//
|
10
|
-
|
10
|
+
//
|
11
|
+
// helpers
|
12
|
+
//
|
13
|
+
|
14
|
+
// Decodes a single UTF-8 sequence beginning at `src`.
//
// Returns the decoded code point together with a pointer to the first byte
// after the sequence. The input is assumed to be valid UTF-8; the only
// safeguard is that decoding stops at a NUL byte, so a truncated multi-byte
// sequence never reads past the string terminator.
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    // total sequence length, indexed by the upper four bits of the lead byte
    static const int seq_len_by_high_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

    const uint8_t lead         = static_cast<uint8_t>(*src);
    const int     n_bytes      = seq_len_by_high_nibble[lead >> 4];
    const uint8_t payload_mask = (1 << (8 - n_bytes)) - 1;

    uint32_t     code_point = lead & payload_mask;
    const char * cursor     = src + 1;

    // fold in the continuation bytes, six payload bits each
    int consumed = 1;
    while (consumed < n_bytes && *cursor) {
        code_point = (code_point << 6) + (static_cast<uint8_t>(*cursor) & 0x3F);
        ++cursor;
        ++consumed;
    }

    return std::make_pair(code_point, cursor);
}
|
29
|
+
|
30
|
+
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
11
31
|
const std::string & src,
|
12
32
|
llama_partial_utf8 partial_start) {
|
13
33
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
@@ -40,7 +60,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
|
40
60
|
while (*pos != 0) {
|
41
61
|
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
42
62
|
uint8_t highbits = first_byte >> 4;
|
43
|
-
|
63
|
+
n_remain = lookup[highbits] - 1;
|
44
64
|
|
45
65
|
if (n_remain < 0) {
|
46
66
|
// invalid sequence, abort
|
@@ -50,7 +70,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
|
50
70
|
}
|
51
71
|
|
52
72
|
uint8_t mask = (1 << (7 - n_remain)) - 1;
|
53
|
-
|
73
|
+
value = first_byte & mask;
|
54
74
|
|
55
75
|
++pos;
|
56
76
|
while (*pos != 0 && n_remain > 0) {
|
@@ -67,12 +87,510 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
|
67
87
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
|
68
88
|
}
|
69
89
|
|
70
|
-
|
71
|
-
return
|
90
|
+
// True when `c` is one of the ASCII decimal digits '0'..'9'.
static bool is_digit_char(char c) {
    return c >= '0' && c <= '9';
}
|
73
93
|
|
74
|
-
|
75
|
-
return
|
94
|
+
static bool is_word_char(char c) {
|
95
|
+
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
|
96
|
+
}
|
97
|
+
|
98
|
+
// Parses exactly `size` hexadecimal digits starting at `src`.
//
// Returns the accumulated value and a pointer just past the last digit.
// Throws std::runtime_error when fewer than `size` hex digits are available
// (a non-hex character or the end of the string is reached first).
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
    uint32_t acc    = 0;
    int      n_read = 0;

    while (n_read < size && src[n_read]) {
        const char ch = src[n_read];

        uint32_t digit;
        if ('0' <= ch && ch <= '9') {
            digit = ch - '0';
        } else if ('a' <= ch && ch <= 'f') {
            digit = ch - 'a' + 10;
        } else if ('A' <= ch && ch <= 'F') {
            digit = ch - 'A' + 10;
        } else {
            break; // not a hex digit; reported as an error below
        }

        acc = (acc << 4) + digit;
        ++n_read;
    }

    if (n_read != size) {
        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
    }
    return std::make_pair(acc, src + n_read);
}
|
120
|
+
|
121
|
+
// Skips blanks, tabs and '#' comments starting at `src` and returns a pointer
// to the first character that was not skipped.
//
// Line breaks ('\r', '\n') are skipped only when `newline_ok` is true. A
// comment runs from '#' up to, but not including, the next line break or the
// end of the string.
static const char * parse_space(const char * src, bool newline_ok) {
    const char * cur = src;
    for (;;) {
        switch (*cur) {
            case ' ':
            case '\t':
                ++cur;
                continue;
            case '\r':
            case '\n':
                if (!newline_ok) {
                    return cur;
                }
                ++cur;
                continue;
            case '#':
                // consume the '#' itself, then the rest of the comment line
                do {
                    ++cur;
                } while (*cur && *cur != '\r' && *cur != '\n');
                continue;
            default:
                return cur;
        }
    }
}
|
135
|
+
|
136
|
+
static const char * parse_name(const char * src) {
|
137
|
+
const char * pos = src;
|
138
|
+
while (is_word_char(*pos)) {
|
139
|
+
pos++;
|
140
|
+
}
|
141
|
+
if (pos == src) {
|
142
|
+
throw std::runtime_error(std::string("expecting name at ") + src);
|
143
|
+
}
|
144
|
+
return pos;
|
145
|
+
}
|
146
|
+
|
147
|
+
static const char * parse_int(const char * src) {
|
148
|
+
const char * pos = src;
|
149
|
+
while (is_digit_char(*pos)) {
|
150
|
+
pos++;
|
151
|
+
}
|
152
|
+
if (pos == src) {
|
153
|
+
throw std::runtime_error(std::string("expecting integer at ") + src);
|
154
|
+
}
|
155
|
+
return pos;
|
156
|
+
}
|
157
|
+
|
158
|
+
static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
159
|
+
if (*src == '\\') {
|
160
|
+
switch (src[1]) {
|
161
|
+
case 'x': return parse_hex(src + 2, 2);
|
162
|
+
case 'u': return parse_hex(src + 2, 4);
|
163
|
+
case 'U': return parse_hex(src + 2, 8);
|
164
|
+
case 't': return std::make_pair('\t', src + 2);
|
165
|
+
case 'r': return std::make_pair('\r', src + 2);
|
166
|
+
case 'n': return std::make_pair('\n', src + 2);
|
167
|
+
case '\\':
|
168
|
+
case '"':
|
169
|
+
case '[':
|
170
|
+
case ']':
|
171
|
+
return std::make_pair(src[1], src + 2);
|
172
|
+
default:
|
173
|
+
throw std::runtime_error(std::string("unknown escape at ") + src);
|
174
|
+
}
|
175
|
+
} else if (*src) {
|
176
|
+
return decode_utf8(src);
|
177
|
+
}
|
178
|
+
throw std::runtime_error("unexpected end of input");
|
179
|
+
}
|
180
|
+
|
181
|
+
// Writes code point `c` to `file`: values in 0x20..0x7f are written as a
// single raw character, everything else as a "<U+XXXX>" placeholder (no UTF-8
// encoding is attempted).
static void print_grammar_char(FILE * file, uint32_t c) {
    const bool written_raw = c >= 0x20 && c <= 0x7f;
    if (!written_raw) {
        fprintf(file, "<U+%04X>", c);
        return;
    }
    fprintf(file, "%c", static_cast<char>(c));
}
|
189
|
+
|
190
|
+
// True when `elem` is one of the character-matching element types
// (CHAR, CHAR_NOT, CHAR_ALT, CHAR_RNG_UPPER, CHAR_ANY).
static bool is_char_element(llama_grammar_element elem) {
    const auto kind = elem.type;
    return kind == LLAMA_GRETYPE_CHAR           ||
           kind == LLAMA_GRETYPE_CHAR_NOT       ||
           kind == LLAMA_GRETYPE_CHAR_ALT       ||
           kind == LLAMA_GRETYPE_CHAR_RNG_UPPER ||
           kind == LLAMA_GRETYPE_CHAR_ANY;
}
|
200
|
+
|
201
|
+
static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
202
|
+
for (auto elem : rule) {
|
203
|
+
switch (elem.type) {
|
204
|
+
case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
|
205
|
+
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
206
|
+
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
207
|
+
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
208
|
+
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
209
|
+
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
210
|
+
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
211
|
+
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
212
|
+
}
|
213
|
+
switch (elem.type) {
|
214
|
+
case LLAMA_GRETYPE_END:
|
215
|
+
case LLAMA_GRETYPE_ALT:
|
216
|
+
case LLAMA_GRETYPE_RULE_REF:
|
217
|
+
fprintf(file, "(%u) ", elem.value);
|
218
|
+
break;
|
219
|
+
case LLAMA_GRETYPE_CHAR:
|
220
|
+
case LLAMA_GRETYPE_CHAR_NOT:
|
221
|
+
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
222
|
+
case LLAMA_GRETYPE_CHAR_ALT:
|
223
|
+
case LLAMA_GRETYPE_CHAR_ANY:
|
224
|
+
fprintf(file, "(\"");
|
225
|
+
print_grammar_char(file, elem.value);
|
226
|
+
fprintf(file, "\") ");
|
227
|
+
break;
|
228
|
+
}
|
229
|
+
}
|
230
|
+
fprintf(file, "\n");
|
231
|
+
}
|
232
|
+
|
233
|
+
static void print_rule(
|
234
|
+
FILE * file,
|
235
|
+
uint32_t rule_id,
|
236
|
+
const llama_grammar_rule & rule,
|
237
|
+
const std::map<uint32_t, std::string> & symbol_id_names) {
|
238
|
+
if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
|
239
|
+
throw std::runtime_error(
|
240
|
+
"malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
241
|
+
}
|
242
|
+
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
243
|
+
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
244
|
+
llama_grammar_element elem = rule[i];
|
245
|
+
switch (elem.type) {
|
246
|
+
case LLAMA_GRETYPE_END:
|
247
|
+
throw std::runtime_error(
|
248
|
+
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
249
|
+
std::to_string(i));
|
250
|
+
case LLAMA_GRETYPE_ALT:
|
251
|
+
fprintf(file, "| ");
|
252
|
+
break;
|
253
|
+
case LLAMA_GRETYPE_RULE_REF:
|
254
|
+
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
255
|
+
break;
|
256
|
+
case LLAMA_GRETYPE_CHAR:
|
257
|
+
fprintf(file, "[");
|
258
|
+
print_grammar_char(file, elem.value);
|
259
|
+
break;
|
260
|
+
case LLAMA_GRETYPE_CHAR_NOT:
|
261
|
+
fprintf(file, "[^");
|
262
|
+
print_grammar_char(file, elem.value);
|
263
|
+
break;
|
264
|
+
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
265
|
+
if (i == 0 || !is_char_element(rule[i - 1])) {
|
266
|
+
throw std::runtime_error(
|
267
|
+
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
268
|
+
std::to_string(rule_id) + "," + std::to_string(i));
|
269
|
+
}
|
270
|
+
fprintf(file, "-");
|
271
|
+
print_grammar_char(file, elem.value);
|
272
|
+
break;
|
273
|
+
case LLAMA_GRETYPE_CHAR_ALT:
|
274
|
+
if (i == 0 || !is_char_element(rule[i - 1])) {
|
275
|
+
throw std::runtime_error(
|
276
|
+
"LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
|
277
|
+
std::to_string(rule_id) + "," + std::to_string(i));
|
278
|
+
}
|
279
|
+
print_grammar_char(file, elem.value);
|
280
|
+
break;
|
281
|
+
case LLAMA_GRETYPE_CHAR_ANY:
|
282
|
+
fprintf(file, ".");
|
283
|
+
break;
|
284
|
+
}
|
285
|
+
if (is_char_element(elem)) {
|
286
|
+
switch (rule[i + 1].type) {
|
287
|
+
case LLAMA_GRETYPE_CHAR_ALT:
|
288
|
+
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
289
|
+
case LLAMA_GRETYPE_CHAR_ANY:
|
290
|
+
break;
|
291
|
+
default:
|
292
|
+
fprintf(file, "] ");
|
293
|
+
}
|
294
|
+
}
|
295
|
+
}
|
296
|
+
fprintf(file, "\n");
|
297
|
+
}
|
298
|
+
|
299
|
+
//
|
300
|
+
// implementation
|
301
|
+
//
|
302
|
+
|
303
|
+
uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
|
304
|
+
uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
|
305
|
+
auto result = symbol_ids.emplace(std::string(src, len), next_id);
|
306
|
+
return result.first->second;
|
307
|
+
}
|
308
|
+
|
309
|
+
// Registers a synthesized symbol named "<base_name>_<id>" and returns its id,
// where the id is the number of symbols known before the call.
uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
    const uint32_t fresh_id = static_cast<uint32_t>(symbol_ids.size());

    std::string synthesized_name = base_name;
    synthesized_name += '_';
    synthesized_name += std::to_string(fresh_id);

    symbol_ids[synthesized_name] = fresh_id;
    return fresh_id;
}
|
314
|
+
|
315
|
+
// Stores `rule` at index `rule_id`, growing the rule table first when the
// index lies beyond its current end. An existing rule at that index is
// overwritten.
void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
    const bool slot_exists = rule_id < rules.size();
    if (!slot_exists) {
        rules.resize(rule_id + 1);
    }
    rules[rule_id] = rule;
}
|
321
|
+
|
322
|
+
const char * llama_grammar_parser::parse_alternates(
|
323
|
+
const char * src,
|
324
|
+
const std::string & rule_name,
|
325
|
+
uint32_t rule_id,
|
326
|
+
bool is_nested) {
|
327
|
+
llama_grammar_rule rule;
|
328
|
+
const char * pos = parse_sequence(src, rule_name, rule, is_nested);
|
329
|
+
while (*pos == '|') {
|
330
|
+
rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
331
|
+
pos = parse_space(pos + 1, true);
|
332
|
+
pos = parse_sequence(pos, rule_name, rule, is_nested);
|
333
|
+
}
|
334
|
+
rule.push_back({LLAMA_GRETYPE_END, 0});
|
335
|
+
add_rule(rule_id, rule);
|
336
|
+
return pos;
|
337
|
+
}
|
338
|
+
|
339
|
+
const char * llama_grammar_parser::parse_sequence(
|
340
|
+
const char * src,
|
341
|
+
const std::string & rule_name,
|
342
|
+
llama_grammar_rule & rule,
|
343
|
+
bool is_nested) {
|
344
|
+
size_t last_sym_start = rule.size();
|
345
|
+
const char * pos = src;
|
346
|
+
|
347
|
+
auto handle_repetitions = [&](int min_times, int max_times) {
|
348
|
+
|
349
|
+
if (last_sym_start == rule.size()) {
|
350
|
+
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
351
|
+
}
|
352
|
+
|
353
|
+
// apply transformation to previous symbol (last_sym_start to end) according to
|
354
|
+
// the following rewrite rules:
|
355
|
+
// S{m,n} --> S S S (m times) S'(n-m)
|
356
|
+
// S'(x) ::= S S'(x-1) |
|
357
|
+
// (... n-m definitions of these S' rules ...)
|
358
|
+
// S'(1) ::= S |
|
359
|
+
// S{m,} --> S S S (m times) S'
|
360
|
+
// S' ::= S S' |
|
361
|
+
// S* --> S{0,}
|
362
|
+
// --> S' ::= S S' |
|
363
|
+
// S+ --> S{1,}
|
364
|
+
// --> S S'
|
365
|
+
// S' ::= S S' |
|
366
|
+
// S? --> S{0,1}
|
367
|
+
// --> S'
|
368
|
+
// S' ::= S |
|
369
|
+
|
370
|
+
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
371
|
+
if (min_times == 0) {
|
372
|
+
rule.resize(last_sym_start);
|
373
|
+
} else {
|
374
|
+
// Repeat the previous elements (min_times - 1) times
|
375
|
+
for (int i = 1; i < min_times; i++) {
|
376
|
+
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
|
377
|
+
}
|
378
|
+
}
|
379
|
+
|
380
|
+
uint32_t last_rec_rule_id = 0;
|
381
|
+
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
382
|
+
|
383
|
+
llama_grammar_rule rec_rule(prev_rule);
|
384
|
+
for (int i = 0; i < n_opt; i++) {
|
385
|
+
rec_rule.resize(prev_rule.size());
|
386
|
+
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
387
|
+
if (i > 0 || max_times < 0) {
|
388
|
+
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
389
|
+
}
|
390
|
+
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
391
|
+
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
392
|
+
add_rule( rec_rule_id, rec_rule);
|
393
|
+
last_rec_rule_id = rec_rule_id;
|
394
|
+
}
|
395
|
+
if (n_opt > 0) {
|
396
|
+
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
397
|
+
}
|
398
|
+
};
|
399
|
+
|
400
|
+
while (*pos) {
|
401
|
+
if (*pos == '"') { // literal string
|
402
|
+
pos++;
|
403
|
+
last_sym_start = rule.size();
|
404
|
+
while (*pos != '"') {
|
405
|
+
if (!*pos) {
|
406
|
+
throw std::runtime_error("unexpected end of input");
|
407
|
+
}
|
408
|
+
auto char_pair = parse_char(pos);
|
409
|
+
pos = char_pair.second;
|
410
|
+
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
411
|
+
}
|
412
|
+
pos = parse_space(pos + 1, is_nested);
|
413
|
+
} else if (*pos == '[') { // char range(s)
|
414
|
+
pos++;
|
415
|
+
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
416
|
+
if (*pos == '^') {
|
417
|
+
pos++;
|
418
|
+
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
419
|
+
}
|
420
|
+
last_sym_start = rule.size();
|
421
|
+
while (*pos != ']') {
|
422
|
+
if (!*pos) {
|
423
|
+
throw std::runtime_error("unexpected end of input");
|
424
|
+
}
|
425
|
+
auto char_pair = parse_char(pos);
|
426
|
+
pos = char_pair.second;
|
427
|
+
enum llama_gretype type = last_sym_start < rule.size()
|
428
|
+
? LLAMA_GRETYPE_CHAR_ALT
|
429
|
+
: start_type;
|
430
|
+
|
431
|
+
rule.push_back({type, char_pair.first});
|
432
|
+
if (pos[0] == '-' && pos[1] != ']') {
|
433
|
+
if (!pos[1]) {
|
434
|
+
throw std::runtime_error("unexpected end of input");
|
435
|
+
}
|
436
|
+
auto endchar_pair = parse_char(pos + 1);
|
437
|
+
pos = endchar_pair.second;
|
438
|
+
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
439
|
+
}
|
440
|
+
}
|
441
|
+
pos = parse_space(pos + 1, is_nested);
|
442
|
+
} else if (is_word_char(*pos)) { // rule reference
|
443
|
+
const char * name_end = parse_name(pos);
|
444
|
+
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
445
|
+
pos = parse_space(name_end, is_nested);
|
446
|
+
last_sym_start = rule.size();
|
447
|
+
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
448
|
+
} else if (*pos == '(') { // grouping
|
449
|
+
// parse nested alternates into synthesized rule
|
450
|
+
pos = parse_space(pos + 1, true);
|
451
|
+
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
452
|
+
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
453
|
+
last_sym_start = rule.size();
|
454
|
+
// output reference to synthesized rule
|
455
|
+
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
456
|
+
if (*pos != ')') {
|
457
|
+
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
458
|
+
}
|
459
|
+
pos = parse_space(pos + 1, is_nested);
|
460
|
+
} else if (*pos == '.') { // any char
|
461
|
+
last_sym_start = rule.size();
|
462
|
+
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
463
|
+
pos = parse_space(pos + 1, is_nested);
|
464
|
+
} else if (*pos == '*') {
|
465
|
+
pos = parse_space(pos + 1, is_nested);
|
466
|
+
handle_repetitions(0, -1);
|
467
|
+
} else if (*pos == '+') {
|
468
|
+
pos = parse_space(pos + 1, is_nested);
|
469
|
+
handle_repetitions(1, -1);
|
470
|
+
} else if (*pos == '?') {
|
471
|
+
pos = parse_space(pos + 1, is_nested);
|
472
|
+
handle_repetitions(0, 1);
|
473
|
+
} else if (*pos == '{') {
|
474
|
+
pos = parse_space(pos + 1, is_nested);
|
475
|
+
|
476
|
+
if (!is_digit_char(*pos)) {
|
477
|
+
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
478
|
+
}
|
479
|
+
const char * int_end = parse_int(pos);
|
480
|
+
int min_times = std::stoul(std::string(pos, int_end - pos));
|
481
|
+
pos = parse_space(int_end, is_nested);
|
482
|
+
|
483
|
+
int max_times = -1;
|
484
|
+
|
485
|
+
if (*pos == '}') {
|
486
|
+
max_times = min_times;
|
487
|
+
pos = parse_space(pos + 1, is_nested);
|
488
|
+
} else if (*pos == ',') {
|
489
|
+
pos = parse_space(pos + 1, is_nested);
|
490
|
+
|
491
|
+
if (is_digit_char(*pos)) {
|
492
|
+
const char * int_end = parse_int(pos);
|
493
|
+
max_times = std::stoul(std::string(pos, int_end - pos));
|
494
|
+
pos = parse_space(int_end, is_nested);
|
495
|
+
}
|
496
|
+
|
497
|
+
if (*pos != '}') {
|
498
|
+
throw std::runtime_error(std::string("expecting '}' at ") + pos);
|
499
|
+
}
|
500
|
+
pos = parse_space(pos + 1, is_nested);
|
501
|
+
} else {
|
502
|
+
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
503
|
+
}
|
504
|
+
handle_repetitions(min_times, max_times);
|
505
|
+
} else {
|
506
|
+
break;
|
507
|
+
}
|
508
|
+
}
|
509
|
+
return pos;
|
510
|
+
}
|
511
|
+
|
512
|
+
const char * llama_grammar_parser::parse_rule(const char * src) {
|
513
|
+
const char * name_end = parse_name(src);
|
514
|
+
const char * pos = parse_space(name_end, false);
|
515
|
+
size_t name_len = name_end - src;
|
516
|
+
uint32_t rule_id = get_symbol_id(src, name_len);
|
517
|
+
const std::string name(src, name_len);
|
518
|
+
|
519
|
+
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
520
|
+
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
521
|
+
}
|
522
|
+
pos = parse_space(pos + 3, true);
|
523
|
+
|
524
|
+
pos = parse_alternates(pos, name, rule_id, false);
|
525
|
+
|
526
|
+
if (*pos == '\r') {
|
527
|
+
pos += pos[1] == '\n' ? 2 : 1;
|
528
|
+
} else if (*pos == '\n') {
|
529
|
+
pos++;
|
530
|
+
} else if (*pos) {
|
531
|
+
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
532
|
+
}
|
533
|
+
return parse_space(pos, true);
|
534
|
+
}
|
535
|
+
|
536
|
+
bool llama_grammar_parser::parse(const char * src) {
|
537
|
+
try {
|
538
|
+
const char * pos = parse_space(src, true);
|
539
|
+
while (*pos) {
|
540
|
+
pos = parse_rule(pos);
|
541
|
+
}
|
542
|
+
// Validate the state to ensure that all rules are defined
|
543
|
+
for (const auto & rule : rules) {
|
544
|
+
if (rule.empty()) {
|
545
|
+
throw std::runtime_error("Undefined rule");
|
546
|
+
}
|
547
|
+
for (const auto & elem : rule) {
|
548
|
+
if (elem.type == LLAMA_GRETYPE_RULE_REF) {
|
549
|
+
// Ensure that the rule at that location exists
|
550
|
+
if (elem.value >= rules.size() || rules[elem.value].empty()) {
|
551
|
+
// Get the name of the rule that is missing
|
552
|
+
for (const auto & kv : symbol_ids) {
|
553
|
+
if (kv.second == elem.value) {
|
554
|
+
throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
|
555
|
+
}
|
556
|
+
}
|
557
|
+
}
|
558
|
+
}
|
559
|
+
}
|
560
|
+
}
|
561
|
+
} catch (const std::exception & err) {
|
562
|
+
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
563
|
+
rules.clear();
|
564
|
+
return false;
|
565
|
+
}
|
566
|
+
|
567
|
+
return true;
|
568
|
+
}
|
569
|
+
|
570
|
+
void llama_grammar_parser::print(FILE * file) {
|
571
|
+
try {
|
572
|
+
std::map<uint32_t, std::string> symbol_id_names;
|
573
|
+
for (const auto & kv : symbol_ids) {
|
574
|
+
symbol_id_names[kv.second] = kv.first;
|
575
|
+
}
|
576
|
+
for (size_t i = 0, end = rules.size(); i < end; i++) {
|
577
|
+
// fprintf(file, "%zu: ", i);
|
578
|
+
// print_rule_binary(file, rules[i]);
|
579
|
+
print_rule(file, uint32_t(i), rules[i], symbol_id_names);
|
580
|
+
// fprintf(file, "\n");
|
581
|
+
}
|
582
|
+
} catch (const std::exception & err) {
|
583
|
+
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
584
|
+
}
|
585
|
+
}
|
586
|
+
|
587
|
+
// Returns one pointer per rule, each pointing at the first element of that
// rule's storage. The pointers refer to memory owned by `rules`.
llama_grammar_stack llama_grammar_parser::c_rules() const {
    const size_t n_rules = rules.size();

    llama_grammar_stack rule_ptrs;
    rule_ptrs.reserve(n_rules);
    for (size_t i = 0; i < n_rules; ++i) {
        rule_ptrs.push_back(rules[i].data());
    }
    return rule_ptrs;
}
|
77
595
|
|
78
596
|
// returns true iff pos points to the end of one of the definitions of a rule
|
@@ -89,7 +607,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
|
|
89
607
|
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
|
90
608
|
const llama_grammar_element * pos,
|
91
609
|
const uint32_t chr) {
|
92
|
-
|
93
610
|
bool found = false;
|
94
611
|
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
|
95
612
|
|
@@ -225,16 +742,93 @@ static void llama_grammar_advance_stack(
|
|
225
742
|
}
|
226
743
|
}
|
227
744
|
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
745
|
+
// Returns the candidates rejected by every stack in `stacks`: the reject list
// of the first stack is filtered again through each following stack.
// `stacks` must not be empty.
static llama_grammar_candidates llama_grammar_reject_candidates(
        const llama_grammar_rules      & rules,
        const llama_grammar_stacks     & stacks,
        const llama_grammar_candidates & candidates) {
    LM_GGML_ASSERT(!stacks.empty()); // REVIEW

    if (candidates.empty()) {
        return llama_grammar_candidates();
    }

    auto stack_it = stacks.begin();
    llama_grammar_candidates still_rejected =
        llama_grammar_reject_candidates_for_stack(rules, *stack_it, candidates);

    for (++stack_it; stack_it != stacks.end(); ++stack_it) {
        still_rejected = llama_grammar_reject_candidates_for_stack(rules, *stack_it, still_rejected);
    }

    return still_rejected;
}
|
763
|
+
|
764
|
+
static bool llama_grammar_detect_left_recursion(
|
765
|
+
const llama_grammar_rules & rules,
|
766
|
+
size_t rule_index,
|
767
|
+
std::vector<bool> * rules_visited,
|
768
|
+
std::vector<bool> * rules_in_progress,
|
769
|
+
std::vector<bool> * rules_may_be_empty) {
|
770
|
+
if ((*rules_in_progress)[rule_index]) {
|
771
|
+
return true;
|
772
|
+
}
|
773
|
+
|
774
|
+
(*rules_in_progress)[rule_index] = true;
|
775
|
+
|
776
|
+
const llama_grammar_rule & rule = rules[rule_index];
|
777
|
+
|
778
|
+
// First check if the rule might produce the empty string. This could be done combined with the second
|
779
|
+
// step but it's more readable as two steps.
|
780
|
+
bool at_rule_start = true;
|
781
|
+
for (size_t i = 0; i < rule.size(); i++) {
|
782
|
+
if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
783
|
+
if (at_rule_start) {
|
784
|
+
(*rules_may_be_empty)[rule_index] = true;
|
785
|
+
break;
|
786
|
+
}
|
787
|
+
at_rule_start = true;
|
788
|
+
} else {
|
789
|
+
at_rule_start = false;
|
790
|
+
}
|
791
|
+
}
|
792
|
+
|
793
|
+
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
|
794
|
+
// be empty)
|
795
|
+
bool recurse_into_nonterminal = true;
|
796
|
+
for (size_t i = 0; i < rule.size(); i++) {
|
797
|
+
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
|
798
|
+
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
|
799
|
+
return true;
|
800
|
+
}
|
801
|
+
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
|
802
|
+
recurse_into_nonterminal = false;
|
803
|
+
}
|
804
|
+
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
805
|
+
recurse_into_nonterminal = true;
|
806
|
+
} else {
|
807
|
+
recurse_into_nonterminal = false;
|
808
|
+
}
|
809
|
+
}
|
810
|
+
|
811
|
+
(*rules_in_progress)[rule_index] = false;
|
812
|
+
(*rules_visited)[rule_index] = true;
|
813
|
+
|
814
|
+
return false;
|
815
|
+
}
|
816
|
+
|
817
|
+
// Read-only access to the rule table stored in `grammar`.
const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
    const struct llama_grammar & self = *grammar;
    return self.rules;
}
|
820
|
+
|
821
|
+
// Mutable access to the parse stacks stored in `grammar`.
llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
    struct llama_grammar & self = *grammar;
    return self.stacks;
}
|
824
|
+
|
232
825
|
void llama_grammar_accept(
|
233
826
|
const llama_grammar_rules & rules,
|
234
827
|
const llama_grammar_stacks & stacks,
|
235
828
|
const uint32_t chr,
|
236
|
-
llama_grammar_stacks &
|
237
|
-
|
829
|
+
llama_grammar_stacks & stacks_new) {
|
830
|
+
stacks_new.clear();
|
831
|
+
stacks_new.reserve(stacks.size());
|
238
832
|
|
239
833
|
for (const auto & stack : stacks) {
|
240
834
|
if (stack.empty()) {
|
@@ -250,29 +844,11 @@ void llama_grammar_accept(
|
|
250
844
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
251
845
|
new_stack.push_back(pos);
|
252
846
|
}
|
253
|
-
llama_grammar_advance_stack(rules, new_stack,
|
847
|
+
llama_grammar_advance_stack(rules, new_stack, stacks_new);
|
254
848
|
}
|
255
849
|
}
|
256
850
|
}
|
257
851
|
|
258
|
-
static llama_grammar_candidates llama_grammar_reject_candidates(
|
259
|
-
const llama_grammar_rules & rules,
|
260
|
-
const llama_grammar_stacks & stacks,
|
261
|
-
const llama_grammar_candidates & candidates) {
|
262
|
-
LM_GGML_ASSERT(!stacks.empty()); // REVIEW
|
263
|
-
|
264
|
-
if (candidates.empty()) {
|
265
|
-
return {};
|
266
|
-
}
|
267
|
-
|
268
|
-
auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
|
269
|
-
|
270
|
-
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
271
|
-
rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
272
|
-
}
|
273
|
-
return rejects;
|
274
|
-
}
|
275
|
-
|
276
852
|
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
277
853
|
const llama_grammar_rules & rules,
|
278
854
|
const llama_grammar_stack & stack,
|
@@ -328,72 +904,97 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|
328
904
|
return rejects;
|
329
905
|
}
|
330
906
|
|
331
|
-
|
332
|
-
const llama_grammar_rules & rules,
|
333
|
-
size_t rule_index,
|
334
|
-
std::vector<bool> * rules_visited,
|
335
|
-
std::vector<bool> * rules_in_progress,
|
336
|
-
std::vector<bool> * rules_may_be_empty) {
|
337
|
-
if ((*rules_in_progress)[rule_index]) {
|
338
|
-
return true;
|
339
|
-
}
|
907
|
+
////////////////////
|
340
908
|
|
341
|
-
|
909
|
+
struct llama_grammar * llama_grammar_init_impl(
|
910
|
+
const struct llama_vocab * vocab,
|
911
|
+
const llama_grammar_element ** rules,
|
912
|
+
size_t n_rules,
|
913
|
+
size_t start_rule_index) {
|
914
|
+
const llama_grammar_element * pos;
|
342
915
|
|
343
|
-
|
916
|
+
// copy rule definitions into vectors
|
917
|
+
llama_grammar_rules vec_rules(n_rules);
|
918
|
+
for (size_t i = 0; i < n_rules; i++) {
|
919
|
+
for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
|
920
|
+
vec_rules[i].push_back(*pos);
|
921
|
+
}
|
922
|
+
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
|
923
|
+
}
|
344
924
|
|
345
|
-
//
|
346
|
-
|
347
|
-
bool
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
at_rule_start = false;
|
925
|
+
// Check for left recursion
|
926
|
+
std::vector<bool> rules_visited(n_rules);
|
927
|
+
std::vector<bool> rules_in_progress(n_rules);
|
928
|
+
std::vector<bool> rules_may_be_empty(n_rules);
|
929
|
+
for (size_t i = 0; i < n_rules; i++) {
|
930
|
+
if (rules_visited[i]) {
|
931
|
+
continue;
|
932
|
+
}
|
933
|
+
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
|
934
|
+
LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
|
935
|
+
return nullptr;
|
357
936
|
}
|
358
937
|
}
|
359
938
|
|
360
|
-
//
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
939
|
+
// loop over alternates of start rule to build initial stacks
|
940
|
+
llama_grammar_stacks stacks;
|
941
|
+
pos = vec_rules[start_rule_index].data();
|
942
|
+
do {
|
943
|
+
llama_grammar_stack stack;
|
944
|
+
if (!llama_grammar_is_end_of_sequence(pos)) {
|
945
|
+
// if alternate is nonempty, add to stack
|
946
|
+
stack.push_back(pos);
|
947
|
+
}
|
948
|
+
llama_grammar_advance_stack(vec_rules, stack, stacks);
|
949
|
+
while (!llama_grammar_is_end_of_sequence(pos)) {
|
950
|
+
// scan to end of alternate def
|
951
|
+
pos++;
|
952
|
+
}
|
953
|
+
if (pos->type == LLAMA_GRETYPE_ALT) {
|
954
|
+
// there's another alternate def of this rule to process
|
955
|
+
pos++;
|
373
956
|
} else {
|
374
|
-
|
957
|
+
break;
|
375
958
|
}
|
376
|
-
}
|
959
|
+
} while (true);
|
377
960
|
|
378
|
-
|
379
|
-
|
380
|
-
|
961
|
+
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
962
|
+
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
963
|
+
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
964
|
+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
381
965
|
}
|
382
966
|
|
383
|
-
|
384
|
-
|
385
|
-
|
967
|
+
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
|
968
|
+
llama_grammar_parser parser;
|
969
|
+
|
970
|
+
// if there is a grammar, parse it
|
971
|
+
if (!parser.parse(grammar_str)) {
|
972
|
+
return nullptr;
|
973
|
+
}
|
974
|
+
|
975
|
+
// will be empty (default) if there are parse errors
|
976
|
+
if (parser.rules.empty()) {
|
977
|
+
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
978
|
+
return nullptr;
|
979
|
+
}
|
980
|
+
|
981
|
+
// Ensure that there is a "root" node.
|
982
|
+
if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
|
983
|
+
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
|
984
|
+
return nullptr;
|
985
|
+
}
|
986
|
+
|
987
|
+
std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
|
988
|
+
|
989
|
+
const size_t n_rules = grammar_rules.size();
|
990
|
+
const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
|
386
991
|
|
387
|
-
struct llama_grammar * llama_grammar_init_impl(
|
388
|
-
const llama_grammar_element ** rules,
|
389
|
-
size_t n_rules,
|
390
|
-
size_t start_rule_index) {
|
391
992
|
const llama_grammar_element * pos;
|
392
993
|
|
393
994
|
// copy rule definitions into vectors
|
394
995
|
llama_grammar_rules vec_rules(n_rules);
|
395
996
|
for (size_t i = 0; i < n_rules; i++) {
|
396
|
-
for (pos =
|
997
|
+
for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
|
397
998
|
vec_rules[i].push_back(*pos);
|
398
999
|
}
|
399
1000
|
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
|
@@ -438,22 +1039,26 @@ struct llama_grammar * llama_grammar_init_impl(
|
|
438
1039
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
439
1040
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
440
1041
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
441
|
-
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
|
1042
|
+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
442
1043
|
}
|
443
1044
|
|
444
1045
|
// Releases a grammar object that was allocated with `new`.
// A null pointer is accepted and ignored, so callers do not need their own check.
void llama_grammar_free_impl(struct llama_grammar * grammar) {
    if (grammar != nullptr) {
        delete grammar;
    }
}
|
447
1052
|
|
448
|
-
struct llama_grammar *
|
449
|
-
llama_grammar * result = new llama_grammar{ grammar
|
1053
|
+
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
1054
|
+
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
|
450
1055
|
|
451
1056
|
// redirect elements in stacks to point to new rules
|
452
1057
|
for (size_t is = 0; is < result->stacks.size(); is++) {
|
453
1058
|
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
|
454
|
-
for (size_t ir0 = 0; ir0 < grammar
|
455
|
-
for (size_t ir1 = 0; ir1 < grammar
|
456
|
-
if (grammar
|
1059
|
+
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
|
1060
|
+
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
|
1061
|
+
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
|
457
1062
|
result->stacks[is][ie] = &result->rules[ir0][ir1];
|
458
1063
|
}
|
459
1064
|
}
|
@@ -464,14 +1069,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
|
|
464
1069
|
return result;
|
465
1070
|
}
|
466
1071
|
|
467
|
-
void
|
468
|
-
LM_GGML_ASSERT(grammar);
|
469
|
-
LM_GGML_ASSERT(vocab);
|
470
|
-
|
471
|
-
int64_t t_start_sample_us = lm_ggml_time_us();
|
1072
|
+
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
|
1073
|
+
LM_GGML_ASSERT(grammar.vocab != nullptr);
|
472
1074
|
|
473
1075
|
bool allow_eog = false;
|
474
|
-
for (const auto & stack : grammar
|
1076
|
+
for (const auto & stack : grammar.stacks) {
|
475
1077
|
if (stack.empty()) {
|
476
1078
|
allow_eog = true;
|
477
1079
|
break;
|
@@ -479,40 +1081,38 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
|
|
479
1081
|
}
|
480
1082
|
|
481
1083
|
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
|
482
|
-
candidates_decoded.reserve(
|
1084
|
+
candidates_decoded.reserve(cur_p->size);
|
483
1085
|
|
484
1086
|
llama_grammar_candidates candidates_grammar;
|
485
|
-
candidates_grammar.reserve(
|
1087
|
+
candidates_grammar.reserve(cur_p->size);
|
486
1088
|
|
487
|
-
for (size_t i = 0; i <
|
488
|
-
const llama_token id =
|
489
|
-
const std::string & piece = vocab->cache_token_to_piece.at(id);
|
1089
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
1090
|
+
const llama_token id = cur_p->data[i].id;
|
1091
|
+
const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
|
490
1092
|
|
491
|
-
if (llama_token_is_eog_impl(*vocab, id)) {
|
1093
|
+
if (llama_token_is_eog_impl(*grammar.vocab, id)) {
|
492
1094
|
if (!allow_eog) {
|
493
|
-
|
1095
|
+
cur_p->data[i].logit = -INFINITY;
|
494
1096
|
}
|
495
1097
|
} else if (piece.empty() || piece[0] == 0) {
|
496
|
-
|
1098
|
+
cur_p->data[i].logit = -INFINITY;
|
497
1099
|
} else {
|
498
|
-
candidates_decoded.push_back(decode_utf8(piece, grammar
|
1100
|
+
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
|
499
1101
|
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
500
1102
|
}
|
501
1103
|
}
|
502
1104
|
|
503
|
-
const auto rejects = llama_grammar_reject_candidates(grammar
|
1105
|
+
const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
|
504
1106
|
for (const auto & reject : rejects) {
|
505
|
-
|
1107
|
+
cur_p->data[reject.index].logit = -INFINITY;
|
506
1108
|
}
|
507
|
-
|
508
|
-
smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
|
509
1109
|
}
|
510
1110
|
|
511
|
-
void
|
512
|
-
|
1111
|
+
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
1112
|
+
LM_GGML_ASSERT(grammar.vocab != nullptr);
|
513
1113
|
|
514
|
-
if (llama_token_is_eog_impl(*vocab, token)) {
|
515
|
-
for (const auto & stack : grammar
|
1114
|
+
if (llama_token_is_eog_impl(*grammar.vocab, token)) {
|
1115
|
+
for (const auto & stack : grammar.stacks) {
|
516
1116
|
if (stack.empty()) {
|
517
1117
|
return;
|
518
1118
|
}
|
@@ -520,20 +1120,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
|
|
520
1120
|
LM_GGML_ABORT("fatal error");
|
521
1121
|
}
|
522
1122
|
|
523
|
-
const std::string & piece = vocab->cache_token_to_piece.at(token);
|
1123
|
+
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
|
524
1124
|
|
525
1125
|
// Note terminating 0 in decoded string
|
526
|
-
const auto decoded = decode_utf8(piece, grammar
|
1126
|
+
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
527
1127
|
const auto & code_points = decoded.first;
|
528
1128
|
|
529
|
-
llama_grammar_stacks
|
1129
|
+
llama_grammar_stacks stacks_new;
|
1130
|
+
|
530
1131
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
531
|
-
llama_grammar_accept(grammar
|
532
|
-
grammar
|
1132
|
+
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
|
1133
|
+
grammar.stacks = std::move(stacks_new);
|
533
1134
|
}
|
534
1135
|
|
535
|
-
grammar
|
536
|
-
LM_GGML_ASSERT(!grammar
|
537
|
-
|
538
|
-
smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
|
1136
|
+
grammar.partial_utf8 = decoded.second;
|
1137
|
+
LM_GGML_ASSERT(!grammar.stacks.empty());
|
539
1138
|
}
|