cui-llama.rn 1.1.2 → 1.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,11 +3,31 @@
3
3
  #include "llama-vocab.h"
4
4
  #include "llama-sampling.h"
5
5
 
6
+ #include <cmath>
6
7
  #include <algorithm>
8
+ #include <stdexcept>
7
9
 
8
- // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
9
- // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
10
- std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
10
+ //
11
+ // helpers
12
+ //
13
+
14
// Decodes the single UTF-8 sequence that begins at `src`.
//
// NOTE: the input is assumed to be valid UTF-8; the only safety net is that
// decoding stops early if the terminating NUL is reached before the sequence
// is complete.
//
// Returns the decoded code point together with a pointer to the first byte
// that was not consumed.
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    // sequence length, indexed by the high nibble of the lead byte
    static const int seq_len_by_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

    const uint8_t lead         = static_cast<uint8_t>(*src);
    const int     n_bytes      = seq_len_by_nibble[lead >> 4];
    const uint8_t payload_mask = (1 << (8 - n_bytes)) - 1;

    uint32_t code_point = lead & payload_mask;

    const char *       cur   = src + 1;
    const char * const limit = src + n_bytes; // may lie beyond the terminating NUL

    // fold in the low 6 bits of every following byte, stopping at NUL
    while (cur < limit && *cur) {
        code_point = (code_point << 6) + (static_cast<uint8_t>(*cur) & 0x3F);
        ++cur;
    }

    return std::make_pair(code_point, cur);
}
29
+
30
+ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
11
31
  const std::string & src,
12
32
  llama_partial_utf8 partial_start) {
13
33
  static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -40,7 +60,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
40
60
  while (*pos != 0) {
41
61
  uint8_t first_byte = static_cast<uint8_t>(*pos);
42
62
  uint8_t highbits = first_byte >> 4;
43
- n_remain = lookup[highbits] - 1;
63
+ n_remain = lookup[highbits] - 1;
44
64
 
45
65
  if (n_remain < 0) {
46
66
  // invalid sequence, abort
@@ -50,7 +70,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
50
70
  }
51
71
 
52
72
  uint8_t mask = (1 << (7 - n_remain)) - 1;
53
- value = first_byte & mask;
73
+ value = first_byte & mask;
54
74
 
55
75
  ++pos;
56
76
  while (*pos != 0 && n_remain > 0) {
@@ -67,12 +87,510 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
67
87
  return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
68
88
  }
69
89
 
70
- const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
71
- return grammar->rules;
90
// True when `c` is one of the ASCII decimal digits '0' through '9'.
static bool is_digit_char(char c) {
    const bool outside = c < '0' || c > '9';
    return !outside;
}
73
93
 
74
- llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
75
- return grammar->stacks;
94
// True when `c` may appear in a name: an ASCII letter, an ASCII decimal
// digit, or '-'.
static bool is_word_char(char c) {
    if (c == '-') {
        return true;
    }
    const bool lower = c >= 'a' && c <= 'z';
    const bool upper = c >= 'A' && c <= 'Z';
    const bool digit = c >= '0' && c <= '9';
    return lower || upper || digit;
}
97
+
98
// Reads exactly `size` hexadecimal digits (upper or lower case) starting at
// `src` and returns their value together with a pointer just past them.
//
// Throws std::runtime_error when a non-hex character or the terminating NUL
// is met before `size` digits have been read.
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
    const char * const stop = src + size;
    const char *       cur  = src;
    uint32_t           acc  = 0;

    while (cur < stop && *cur) {
        const char ch = *cur;
        uint32_t nibble;
        if ('0' <= ch && ch <= '9') {
            nibble = ch - '0';
        } else if ('a' <= ch && ch <= 'f') {
            nibble = ch - 'a' + 10;
        } else if ('A' <= ch && ch <= 'F') {
            nibble = ch - 'A' + 10;
        } else {
            break; // not a hex digit: reported below because cur != stop
        }
        acc = (acc << 4) + nibble;
        ++cur;
    }

    if (cur != stop) {
        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
    }
    return std::make_pair(acc, cur);
}
120
+
121
// Skips blanks (' ' and '\t') and '#' comments starting at `src`. A comment
// runs up to, but not including, the next '\r', '\n' or NUL. Line breaks are
// skipped as well only when `newline_ok` is true.
//
// Returns a pointer to the first character that was not skipped.
static const char * parse_space(const char * src, bool newline_ok) {
    const char * cur = src;
    for (;;) {
        const char ch = *cur;

        if (ch == '#') {
            // consume the comment body; the line break (if any) is handled
            // by the next iteration
            while (*cur && *cur != '\r' && *cur != '\n') {
                ++cur;
            }
            continue;
        }

        const bool is_blank = ch == ' ' || ch == '\t';
        const bool is_eol   = ch == '\r' || ch == '\n';
        if (is_blank || (newline_ok && is_eol)) {
            ++cur;
            continue;
        }

        return cur;
    }
}
135
+
136
+ static const char * parse_name(const char * src) {
137
+ const char * pos = src;
138
+ while (is_word_char(*pos)) {
139
+ pos++;
140
+ }
141
+ if (pos == src) {
142
+ throw std::runtime_error(std::string("expecting name at ") + src);
143
+ }
144
+ return pos;
145
+ }
146
+
147
+ static const char * parse_int(const char * src) {
148
+ const char * pos = src;
149
+ while (is_digit_char(*pos)) {
150
+ pos++;
151
+ }
152
+ if (pos == src) {
153
+ throw std::runtime_error(std::string("expecting integer at ") + src);
154
+ }
155
+ return pos;
156
+ }
157
+
158
+ static std::pair<uint32_t, const char *> parse_char(const char * src) {
159
+ if (*src == '\\') {
160
+ switch (src[1]) {
161
+ case 'x': return parse_hex(src + 2, 2);
162
+ case 'u': return parse_hex(src + 2, 4);
163
+ case 'U': return parse_hex(src + 2, 8);
164
+ case 't': return std::make_pair('\t', src + 2);
165
+ case 'r': return std::make_pair('\r', src + 2);
166
+ case 'n': return std::make_pair('\n', src + 2);
167
+ case '\\':
168
+ case '"':
169
+ case '[':
170
+ case ']':
171
+ return std::make_pair(src[1], src + 2);
172
+ default:
173
+ throw std::runtime_error(std::string("unknown escape at ") + src);
174
+ }
175
+ } else if (*src) {
176
+ return decode_utf8(src);
177
+ }
178
+ throw std::runtime_error("unexpected end of input");
179
+ }
180
+
181
// Writes the code point `c` to `file`: values in [0x20, 0x7f] are written as
// a single character, everything else as "<U+XXXX>" (no UTF-8 encoding is
// attempted).
static void print_grammar_char(FILE * file, uint32_t c) {
    const bool in_ascii_range = c >= 0x20 && c <= 0x7f;
    if (!in_ascii_range) {
        // cop out of encoding UTF-8
        fprintf(file, "<U+%04X>", c);
        return;
    }
    fprintf(file, "%c", static_cast<char>(c));
}
189
+
190
+ static bool is_char_element(llama_grammar_element elem) {
191
+ switch (elem.type) {
192
+ case LLAMA_GRETYPE_CHAR: return true;
193
+ case LLAMA_GRETYPE_CHAR_NOT: return true;
194
+ case LLAMA_GRETYPE_CHAR_ALT: return true;
195
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
196
+ case LLAMA_GRETYPE_CHAR_ANY: return true;
197
+ default: return false;
198
+ }
199
+ }
200
+
201
+ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
202
+ for (auto elem : rule) {
203
+ switch (elem.type) {
204
+ case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
205
+ case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
206
+ case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
207
+ case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
208
+ case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
209
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
210
+ case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
211
+ case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
212
+ }
213
+ switch (elem.type) {
214
+ case LLAMA_GRETYPE_END:
215
+ case LLAMA_GRETYPE_ALT:
216
+ case LLAMA_GRETYPE_RULE_REF:
217
+ fprintf(file, "(%u) ", elem.value);
218
+ break;
219
+ case LLAMA_GRETYPE_CHAR:
220
+ case LLAMA_GRETYPE_CHAR_NOT:
221
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
222
+ case LLAMA_GRETYPE_CHAR_ALT:
223
+ case LLAMA_GRETYPE_CHAR_ANY:
224
+ fprintf(file, "(\"");
225
+ print_grammar_char(file, elem.value);
226
+ fprintf(file, "\") ");
227
+ break;
228
+ }
229
+ }
230
+ fprintf(file, "\n");
231
+ }
232
+
233
+ static void print_rule(
234
+ FILE * file,
235
+ uint32_t rule_id,
236
+ const llama_grammar_rule & rule,
237
+ const std::map<uint32_t, std::string> & symbol_id_names) {
238
+ if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
239
+ throw std::runtime_error(
240
+ "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
241
+ }
242
+ fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
243
+ for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
244
+ llama_grammar_element elem = rule[i];
245
+ switch (elem.type) {
246
+ case LLAMA_GRETYPE_END:
247
+ throw std::runtime_error(
248
+ "unexpected end of rule: " + std::to_string(rule_id) + "," +
249
+ std::to_string(i));
250
+ case LLAMA_GRETYPE_ALT:
251
+ fprintf(file, "| ");
252
+ break;
253
+ case LLAMA_GRETYPE_RULE_REF:
254
+ fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
255
+ break;
256
+ case LLAMA_GRETYPE_CHAR:
257
+ fprintf(file, "[");
258
+ print_grammar_char(file, elem.value);
259
+ break;
260
+ case LLAMA_GRETYPE_CHAR_NOT:
261
+ fprintf(file, "[^");
262
+ print_grammar_char(file, elem.value);
263
+ break;
264
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
265
+ if (i == 0 || !is_char_element(rule[i - 1])) {
266
+ throw std::runtime_error(
267
+ "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
268
+ std::to_string(rule_id) + "," + std::to_string(i));
269
+ }
270
+ fprintf(file, "-");
271
+ print_grammar_char(file, elem.value);
272
+ break;
273
+ case LLAMA_GRETYPE_CHAR_ALT:
274
+ if (i == 0 || !is_char_element(rule[i - 1])) {
275
+ throw std::runtime_error(
276
+ "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
277
+ std::to_string(rule_id) + "," + std::to_string(i));
278
+ }
279
+ print_grammar_char(file, elem.value);
280
+ break;
281
+ case LLAMA_GRETYPE_CHAR_ANY:
282
+ fprintf(file, ".");
283
+ break;
284
+ }
285
+ if (is_char_element(elem)) {
286
+ switch (rule[i + 1].type) {
287
+ case LLAMA_GRETYPE_CHAR_ALT:
288
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
289
+ case LLAMA_GRETYPE_CHAR_ANY:
290
+ break;
291
+ default:
292
+ fprintf(file, "] ");
293
+ }
294
+ }
295
+ }
296
+ fprintf(file, "\n");
297
+ }
298
+
299
+ //
300
+ // implementation
301
+ //
302
+
303
+ uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
304
+ uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
305
+ auto result = symbol_ids.emplace(std::string(src, len), next_id);
306
+ return result.first->second;
307
+ }
308
+
309
+ uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
310
+ uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
311
+ symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
312
+ return next_id;
313
+ }
314
+
315
+ void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
316
+ if (rules.size() <= rule_id) {
317
+ rules.resize(rule_id + 1);
318
+ }
319
+ rules[rule_id] = rule;
320
+ }
321
+
322
+ const char * llama_grammar_parser::parse_alternates(
323
+ const char * src,
324
+ const std::string & rule_name,
325
+ uint32_t rule_id,
326
+ bool is_nested) {
327
+ llama_grammar_rule rule;
328
+ const char * pos = parse_sequence(src, rule_name, rule, is_nested);
329
+ while (*pos == '|') {
330
+ rule.push_back({LLAMA_GRETYPE_ALT, 0});
331
+ pos = parse_space(pos + 1, true);
332
+ pos = parse_sequence(pos, rule_name, rule, is_nested);
333
+ }
334
+ rule.push_back({LLAMA_GRETYPE_END, 0});
335
+ add_rule(rule_id, rule);
336
+ return pos;
337
+ }
338
+
339
+ const char * llama_grammar_parser::parse_sequence(
340
+ const char * src,
341
+ const std::string & rule_name,
342
+ llama_grammar_rule & rule,
343
+ bool is_nested) {
344
+ size_t last_sym_start = rule.size();
345
+ const char * pos = src;
346
+
347
+ auto handle_repetitions = [&](int min_times, int max_times) {
348
+
349
+ if (last_sym_start == rule.size()) {
350
+ throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
351
+ }
352
+
353
+ // apply transformation to previous symbol (last_sym_start to end) according to
354
+ // the following rewrite rules:
355
+ // S{m,n} --> S S S (m times) S'(n-m)
356
+ // S'(x) ::= S S'(x-1) |
357
+ // (... n-m definitions of these S' rules ...)
358
+ // S'(1) ::= S |
359
+ // S{m,} --> S S S (m times) S'
360
+ // S' ::= S S' |
361
+ // S* --> S{0,}
362
+ // --> S' ::= S S' |
363
+ // S+ --> S{1,}
364
+ // --> S S'
365
+ // S' ::= S S' |
366
+ // S? --> S{0,1}
367
+ // --> S'
368
+ // S' ::= S |
369
+
370
+ llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
371
+ if (min_times == 0) {
372
+ rule.resize(last_sym_start);
373
+ } else {
374
+ // Repeat the previous elements (min_times - 1) times
375
+ for (int i = 1; i < min_times; i++) {
376
+ rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
377
+ }
378
+ }
379
+
380
+ uint32_t last_rec_rule_id = 0;
381
+ auto n_opt = max_times < 0 ? 1 : max_times - min_times;
382
+
383
+ llama_grammar_rule rec_rule(prev_rule);
384
+ for (int i = 0; i < n_opt; i++) {
385
+ rec_rule.resize(prev_rule.size());
386
+ uint32_t rec_rule_id = generate_symbol_id( rule_name);
387
+ if (i > 0 || max_times < 0) {
388
+ rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
389
+ }
390
+ rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
391
+ rec_rule.push_back({LLAMA_GRETYPE_END, 0});
392
+ add_rule( rec_rule_id, rec_rule);
393
+ last_rec_rule_id = rec_rule_id;
394
+ }
395
+ if (n_opt > 0) {
396
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
397
+ }
398
+ };
399
+
400
+ while (*pos) {
401
+ if (*pos == '"') { // literal string
402
+ pos++;
403
+ last_sym_start = rule.size();
404
+ while (*pos != '"') {
405
+ if (!*pos) {
406
+ throw std::runtime_error("unexpected end of input");
407
+ }
408
+ auto char_pair = parse_char(pos);
409
+ pos = char_pair.second;
410
+ rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
411
+ }
412
+ pos = parse_space(pos + 1, is_nested);
413
+ } else if (*pos == '[') { // char range(s)
414
+ pos++;
415
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
416
+ if (*pos == '^') {
417
+ pos++;
418
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
419
+ }
420
+ last_sym_start = rule.size();
421
+ while (*pos != ']') {
422
+ if (!*pos) {
423
+ throw std::runtime_error("unexpected end of input");
424
+ }
425
+ auto char_pair = parse_char(pos);
426
+ pos = char_pair.second;
427
+ enum llama_gretype type = last_sym_start < rule.size()
428
+ ? LLAMA_GRETYPE_CHAR_ALT
429
+ : start_type;
430
+
431
+ rule.push_back({type, char_pair.first});
432
+ if (pos[0] == '-' && pos[1] != ']') {
433
+ if (!pos[1]) {
434
+ throw std::runtime_error("unexpected end of input");
435
+ }
436
+ auto endchar_pair = parse_char(pos + 1);
437
+ pos = endchar_pair.second;
438
+ rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
439
+ }
440
+ }
441
+ pos = parse_space(pos + 1, is_nested);
442
+ } else if (is_word_char(*pos)) { // rule reference
443
+ const char * name_end = parse_name(pos);
444
+ uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
445
+ pos = parse_space(name_end, is_nested);
446
+ last_sym_start = rule.size();
447
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
448
+ } else if (*pos == '(') { // grouping
449
+ // parse nested alternates into synthesized rule
450
+ pos = parse_space(pos + 1, true);
451
+ uint32_t sub_rule_id = generate_symbol_id(rule_name);
452
+ pos = parse_alternates(pos, rule_name, sub_rule_id, true);
453
+ last_sym_start = rule.size();
454
+ // output reference to synthesized rule
455
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
456
+ if (*pos != ')') {
457
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
458
+ }
459
+ pos = parse_space(pos + 1, is_nested);
460
+ } else if (*pos == '.') { // any char
461
+ last_sym_start = rule.size();
462
+ rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
463
+ pos = parse_space(pos + 1, is_nested);
464
+ } else if (*pos == '*') {
465
+ pos = parse_space(pos + 1, is_nested);
466
+ handle_repetitions(0, -1);
467
+ } else if (*pos == '+') {
468
+ pos = parse_space(pos + 1, is_nested);
469
+ handle_repetitions(1, -1);
470
+ } else if (*pos == '?') {
471
+ pos = parse_space(pos + 1, is_nested);
472
+ handle_repetitions(0, 1);
473
+ } else if (*pos == '{') {
474
+ pos = parse_space(pos + 1, is_nested);
475
+
476
+ if (!is_digit_char(*pos)) {
477
+ throw std::runtime_error(std::string("expecting an int at ") + pos);
478
+ }
479
+ const char * int_end = parse_int(pos);
480
+ int min_times = std::stoul(std::string(pos, int_end - pos));
481
+ pos = parse_space(int_end, is_nested);
482
+
483
+ int max_times = -1;
484
+
485
+ if (*pos == '}') {
486
+ max_times = min_times;
487
+ pos = parse_space(pos + 1, is_nested);
488
+ } else if (*pos == ',') {
489
+ pos = parse_space(pos + 1, is_nested);
490
+
491
+ if (is_digit_char(*pos)) {
492
+ const char * int_end = parse_int(pos);
493
+ max_times = std::stoul(std::string(pos, int_end - pos));
494
+ pos = parse_space(int_end, is_nested);
495
+ }
496
+
497
+ if (*pos != '}') {
498
+ throw std::runtime_error(std::string("expecting '}' at ") + pos);
499
+ }
500
+ pos = parse_space(pos + 1, is_nested);
501
+ } else {
502
+ throw std::runtime_error(std::string("expecting ',' at ") + pos);
503
+ }
504
+ handle_repetitions(min_times, max_times);
505
+ } else {
506
+ break;
507
+ }
508
+ }
509
+ return pos;
510
+ }
511
+
512
+ const char * llama_grammar_parser::parse_rule(const char * src) {
513
+ const char * name_end = parse_name(src);
514
+ const char * pos = parse_space(name_end, false);
515
+ size_t name_len = name_end - src;
516
+ uint32_t rule_id = get_symbol_id(src, name_len);
517
+ const std::string name(src, name_len);
518
+
519
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
520
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
521
+ }
522
+ pos = parse_space(pos + 3, true);
523
+
524
+ pos = parse_alternates(pos, name, rule_id, false);
525
+
526
+ if (*pos == '\r') {
527
+ pos += pos[1] == '\n' ? 2 : 1;
528
+ } else if (*pos == '\n') {
529
+ pos++;
530
+ } else if (*pos) {
531
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
532
+ }
533
+ return parse_space(pos, true);
534
+ }
535
+
536
+ bool llama_grammar_parser::parse(const char * src) {
537
+ try {
538
+ const char * pos = parse_space(src, true);
539
+ while (*pos) {
540
+ pos = parse_rule(pos);
541
+ }
542
+ // Validate the state to ensure that all rules are defined
543
+ for (const auto & rule : rules) {
544
+ if (rule.empty()) {
545
+ throw std::runtime_error("Undefined rule");
546
+ }
547
+ for (const auto & elem : rule) {
548
+ if (elem.type == LLAMA_GRETYPE_RULE_REF) {
549
+ // Ensure that the rule at that location exists
550
+ if (elem.value >= rules.size() || rules[elem.value].empty()) {
551
+ // Get the name of the rule that is missing
552
+ for (const auto & kv : symbol_ids) {
553
+ if (kv.second == elem.value) {
554
+ throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
555
+ }
556
+ }
557
+ }
558
+ }
559
+ }
560
+ }
561
+ } catch (const std::exception & err) {
562
+ fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
563
+ rules.clear();
564
+ return false;
565
+ }
566
+
567
+ return true;
568
+ }
569
+
570
+ void llama_grammar_parser::print(FILE * file) {
571
+ try {
572
+ std::map<uint32_t, std::string> symbol_id_names;
573
+ for (const auto & kv : symbol_ids) {
574
+ symbol_id_names[kv.second] = kv.first;
575
+ }
576
+ for (size_t i = 0, end = rules.size(); i < end; i++) {
577
+ // fprintf(file, "%zu: ", i);
578
+ // print_rule_binary(file, rules[i]);
579
+ print_rule(file, uint32_t(i), rules[i], symbol_id_names);
580
+ // fprintf(file, "\n");
581
+ }
582
+ } catch (const std::exception & err) {
583
+ fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
584
+ }
585
+ }
586
+
587
+ llama_grammar_stack llama_grammar_parser::c_rules() const {
588
+ llama_grammar_stack ret;
589
+ ret.reserve(rules.size());
590
+ for (const auto & rule : rules) {
591
+ ret.push_back(rule.data());
592
+ }
593
+ return ret;
76
594
  }
77
595
 
78
596
  // returns true iff pos points to the end of one of the definitions of a rule
@@ -89,7 +607,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
89
607
  static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
90
608
  const llama_grammar_element * pos,
91
609
  const uint32_t chr) {
92
-
93
610
  bool found = false;
94
611
  bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
95
612
 
@@ -225,16 +742,93 @@ static void llama_grammar_advance_stack(
225
742
  }
226
743
  }
227
744
 
228
- // takes a set of possible pushdown stacks on a grammar, which are required to
229
- // be positioned at a character range (see `llama_grammar_advance_stack`), and
230
- // produces the N possible stacks if the given char is accepted at those
231
- // positions
745
+ static llama_grammar_candidates llama_grammar_reject_candidates(
746
+ const llama_grammar_rules & rules,
747
+ const llama_grammar_stacks & stacks,
748
+ const llama_grammar_candidates & candidates) {
749
+ LM_GGML_ASSERT(!stacks.empty()); // REVIEW
750
+
751
+ if (candidates.empty()) {
752
+ return {};
753
+ }
754
+
755
+ auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
756
+
757
+ for (size_t i = 1, size = stacks.size(); i < size; ++i) {
758
+ rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
759
+ }
760
+
761
+ return rejects;
762
+ }
763
+
764
+ static bool llama_grammar_detect_left_recursion(
765
+ const llama_grammar_rules & rules,
766
+ size_t rule_index,
767
+ std::vector<bool> * rules_visited,
768
+ std::vector<bool> * rules_in_progress,
769
+ std::vector<bool> * rules_may_be_empty) {
770
+ if ((*rules_in_progress)[rule_index]) {
771
+ return true;
772
+ }
773
+
774
+ (*rules_in_progress)[rule_index] = true;
775
+
776
+ const llama_grammar_rule & rule = rules[rule_index];
777
+
778
+ // First check if the rule might produce the empty string. This could be done combined with the second
779
+ // step but it's more readable as two steps.
780
+ bool at_rule_start = true;
781
+ for (size_t i = 0; i < rule.size(); i++) {
782
+ if (llama_grammar_is_end_of_sequence(&rule[i])) {
783
+ if (at_rule_start) {
784
+ (*rules_may_be_empty)[rule_index] = true;
785
+ break;
786
+ }
787
+ at_rule_start = true;
788
+ } else {
789
+ at_rule_start = false;
790
+ }
791
+ }
792
+
793
+ // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
794
+ // be empty)
795
+ bool recurse_into_nonterminal = true;
796
+ for (size_t i = 0; i < rule.size(); i++) {
797
+ if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
798
+ if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
799
+ return true;
800
+ }
801
+ if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
802
+ recurse_into_nonterminal = false;
803
+ }
804
+ } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
805
+ recurse_into_nonterminal = true;
806
+ } else {
807
+ recurse_into_nonterminal = false;
808
+ }
809
+ }
810
+
811
+ (*rules_in_progress)[rule_index] = false;
812
+ (*rules_visited)[rule_index] = true;
813
+
814
+ return false;
815
+ }
816
+
817
+ const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
818
+ return grammar->rules;
819
+ }
820
+
821
+ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
822
+ return grammar->stacks;
823
+ }
824
+
232
825
  void llama_grammar_accept(
233
826
  const llama_grammar_rules & rules,
234
827
  const llama_grammar_stacks & stacks,
235
828
  const uint32_t chr,
236
- llama_grammar_stacks & new_stacks) {
237
- new_stacks.clear();
829
+ llama_grammar_stacks & stacks_new) {
830
+ stacks_new.clear();
831
+ stacks_new.reserve(stacks.size());
238
832
 
239
833
  for (const auto & stack : stacks) {
240
834
  if (stack.empty()) {
@@ -250,29 +844,11 @@ void llama_grammar_accept(
250
844
  if (!llama_grammar_is_end_of_sequence(pos)) {
251
845
  new_stack.push_back(pos);
252
846
  }
253
- llama_grammar_advance_stack(rules, new_stack, new_stacks);
847
+ llama_grammar_advance_stack(rules, new_stack, stacks_new);
254
848
  }
255
849
  }
256
850
  }
257
851
 
258
- static llama_grammar_candidates llama_grammar_reject_candidates(
259
- const llama_grammar_rules & rules,
260
- const llama_grammar_stacks & stacks,
261
- const llama_grammar_candidates & candidates) {
262
- LM_GGML_ASSERT(!stacks.empty()); // REVIEW
263
-
264
- if (candidates.empty()) {
265
- return {};
266
- }
267
-
268
- auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
269
-
270
- for (size_t i = 1, size = stacks.size(); i < size; ++i) {
271
- rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
272
- }
273
- return rejects;
274
- }
275
-
276
852
  llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
277
853
  const llama_grammar_rules & rules,
278
854
  const llama_grammar_stack & stack,
@@ -328,72 +904,97 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
328
904
  return rejects;
329
905
  }
330
906
 
331
- static bool llama_grammar_detect_left_recursion(
332
- const llama_grammar_rules & rules,
333
- size_t rule_index,
334
- std::vector<bool> * rules_visited,
335
- std::vector<bool> * rules_in_progress,
336
- std::vector<bool> * rules_may_be_empty) {
337
- if ((*rules_in_progress)[rule_index]) {
338
- return true;
339
- }
907
+ ////////////////////
340
908
 
341
- (*rules_in_progress)[rule_index] = true;
909
+ struct llama_grammar * llama_grammar_init_impl(
910
+ const struct llama_vocab * vocab,
911
+ const llama_grammar_element ** rules,
912
+ size_t n_rules,
913
+ size_t start_rule_index) {
914
+ const llama_grammar_element * pos;
342
915
 
343
- const llama_grammar_rule & rule = rules[rule_index];
916
+ // copy rule definitions into vectors
917
+ llama_grammar_rules vec_rules(n_rules);
918
+ for (size_t i = 0; i < n_rules; i++) {
919
+ for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
920
+ vec_rules[i].push_back(*pos);
921
+ }
922
+ vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
923
+ }
344
924
 
345
- // First check if the rule might produce the empty string. This could be done combined with the second
346
- // step but it's more readable as two steps.
347
- bool at_rule_start = true;
348
- for (size_t i = 0; i < rule.size(); i++) {
349
- if (llama_grammar_is_end_of_sequence(&rule[i])) {
350
- if (at_rule_start) {
351
- (*rules_may_be_empty)[rule_index] = true;
352
- break;
353
- }
354
- at_rule_start = true;
355
- } else {
356
- at_rule_start = false;
925
+ // Check for left recursion
926
+ std::vector<bool> rules_visited(n_rules);
927
+ std::vector<bool> rules_in_progress(n_rules);
928
+ std::vector<bool> rules_may_be_empty(n_rules);
929
+ for (size_t i = 0; i < n_rules; i++) {
930
+ if (rules_visited[i]) {
931
+ continue;
932
+ }
933
+ if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
934
+ LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
935
+ return nullptr;
357
936
  }
358
937
  }
359
938
 
360
- // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
361
- // be empty)
362
- bool recurse_into_nonterminal = true;
363
- for (size_t i = 0; i < rule.size(); i++) {
364
- if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
365
- if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
366
- return true;
367
- }
368
- if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
369
- recurse_into_nonterminal = false;
370
- }
371
- } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
372
- recurse_into_nonterminal = true;
939
+ // loop over alternates of start rule to build initial stacks
940
+ llama_grammar_stacks stacks;
941
+ pos = vec_rules[start_rule_index].data();
942
+ do {
943
+ llama_grammar_stack stack;
944
+ if (!llama_grammar_is_end_of_sequence(pos)) {
945
+ // if alternate is nonempty, add to stack
946
+ stack.push_back(pos);
947
+ }
948
+ llama_grammar_advance_stack(vec_rules, stack, stacks);
949
+ while (!llama_grammar_is_end_of_sequence(pos)) {
950
+ // scan to end of alternate def
951
+ pos++;
952
+ }
953
+ if (pos->type == LLAMA_GRETYPE_ALT) {
954
+ // there's another alternate def of this rule to process
955
+ pos++;
373
956
  } else {
374
- recurse_into_nonterminal = false;
957
+ break;
375
958
  }
376
- }
959
+ } while (true);
377
960
 
378
- (*rules_in_progress)[rule_index] = false;
379
- (*rules_visited)[rule_index] = true;
380
- return false;
961
+ // Important: vec_rules has to be moved here, not copied, because stacks contains
962
+ // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
963
+ // then the pointers would be invalidated when the local vec_rules goes out of scope.
964
+ return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
381
965
  }
382
966
 
383
- //
384
- // grammar - external
385
- //
967
+ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
968
+ llama_grammar_parser parser;
969
+
970
+ // if there is a grammar, parse it
971
+ if (!parser.parse(grammar_str)) {
972
+ return nullptr;
973
+ }
974
+
975
+ // will be empty (default) if there are parse errors
976
+ if (parser.rules.empty()) {
977
+ fprintf(stderr, "%s: failed to parse grammar\n", __func__);
978
+ return nullptr;
979
+ }
980
+
981
+ // Ensure that there is a "root" node.
982
+ if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
983
+ fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
984
+ return nullptr;
985
+ }
986
+
987
+ std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
988
+
989
+ const size_t n_rules = grammar_rules.size();
990
+ const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
386
991
 
387
- struct llama_grammar * llama_grammar_init_impl(
388
- const llama_grammar_element ** rules,
389
- size_t n_rules,
390
- size_t start_rule_index) {
391
992
  const llama_grammar_element * pos;
392
993
 
393
994
  // copy rule definitions into vectors
394
995
  llama_grammar_rules vec_rules(n_rules);
395
996
  for (size_t i = 0; i < n_rules; i++) {
396
- for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
997
+ for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
397
998
  vec_rules[i].push_back(*pos);
398
999
  }
399
1000
  vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
@@ -438,22 +1039,26 @@ struct llama_grammar * llama_grammar_init_impl(
438
1039
  // Important: vec_rules has to be moved here, not copied, because stacks contains
439
1040
  // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
440
1041
  // then the pointers would be invalidated when the local vec_rules goes out of scope.
441
- return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
1042
+ return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
442
1043
  }
443
1044
 
444
1045
  void llama_grammar_free_impl(struct llama_grammar * grammar) {
1046
+ if (grammar == nullptr) {
1047
+ return;
1048
+ }
1049
+
445
1050
  delete grammar;
446
1051
  }
447
1052
 
448
- struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
449
- llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
1053
+ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1054
+ llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
450
1055
 
451
1056
  // redirect elements in stacks to point to new rules
452
1057
  for (size_t is = 0; is < result->stacks.size(); is++) {
453
1058
  for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
454
- for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
455
- for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
456
- if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
1059
+ for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
1060
+ for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
1061
+ if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
457
1062
  result->stacks[is][ie] = &result->rules[ir0][ir1];
458
1063
  }
459
1064
  }
@@ -464,14 +1069,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
464
1069
  return result;
465
1070
  }
466
1071
 
467
- void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
468
- LM_GGML_ASSERT(grammar);
469
- LM_GGML_ASSERT(vocab);
470
-
471
- int64_t t_start_sample_us = lm_ggml_time_us();
1072
+ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
1073
+ LM_GGML_ASSERT(grammar.vocab != nullptr);
472
1074
 
473
1075
  bool allow_eog = false;
474
- for (const auto & stack : grammar->stacks) {
1076
+ for (const auto & stack : grammar.stacks) {
475
1077
  if (stack.empty()) {
476
1078
  allow_eog = true;
477
1079
  break;
@@ -479,40 +1081,38 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
479
1081
  }
480
1082
 
481
1083
  std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
482
- candidates_decoded.reserve(candidates->size);
1084
+ candidates_decoded.reserve(cur_p->size);
483
1085
 
484
1086
  llama_grammar_candidates candidates_grammar;
485
- candidates_grammar.reserve(candidates->size);
1087
+ candidates_grammar.reserve(cur_p->size);
486
1088
 
487
- for (size_t i = 0; i < candidates->size; ++i) {
488
- const llama_token id = candidates->data[i].id;
489
- const std::string & piece = vocab->cache_token_to_piece.at(id);
1089
+ for (size_t i = 0; i < cur_p->size; ++i) {
1090
+ const llama_token id = cur_p->data[i].id;
1091
+ const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
490
1092
 
491
- if (llama_token_is_eog_impl(*vocab, id)) {
1093
+ if (llama_token_is_eog_impl(*grammar.vocab, id)) {
492
1094
  if (!allow_eog) {
493
- candidates->data[i].logit = -INFINITY;
1095
+ cur_p->data[i].logit = -INFINITY;
494
1096
  }
495
1097
  } else if (piece.empty() || piece[0] == 0) {
496
- candidates->data[i].logit = -INFINITY;
1098
+ cur_p->data[i].logit = -INFINITY;
497
1099
  } else {
498
- candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
1100
+ candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
499
1101
  candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
500
1102
  }
501
1103
  }
502
1104
 
503
- const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
1105
+ const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
504
1106
  for (const auto & reject : rejects) {
505
- candidates->data[reject.index].logit = -INFINITY;
1107
+ cur_p->data[reject.index].logit = -INFINITY;
506
1108
  }
507
-
508
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
509
1109
  }
510
1110
 
511
- void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
512
- const int64_t t_start_sample_us = lm_ggml_time_us();
1111
+ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
1112
+ LM_GGML_ASSERT(grammar.vocab != nullptr);
513
1113
 
514
- if (llama_token_is_eog_impl(*vocab, token)) {
515
- for (const auto & stack : grammar->stacks) {
1114
+ if (llama_token_is_eog_impl(*grammar.vocab, token)) {
1115
+ for (const auto & stack : grammar.stacks) {
516
1116
  if (stack.empty()) {
517
1117
  return;
518
1118
  }
@@ -520,20 +1120,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
520
1120
  LM_GGML_ABORT("fatal error");
521
1121
  }
522
1122
 
523
- const std::string & piece = vocab->cache_token_to_piece.at(token);
1123
+ const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
524
1124
 
525
1125
  // Note terminating 0 in decoded string
526
- const auto decoded = decode_utf8(piece, grammar->partial_utf8);
1126
+ const auto decoded = decode_utf8(piece, grammar.partial_utf8);
527
1127
  const auto & code_points = decoded.first;
528
1128
 
529
- llama_grammar_stacks tmp_new_stacks;
1129
+ llama_grammar_stacks stacks_new;
1130
+
530
1131
  for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
531
- llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
532
- grammar->stacks = tmp_new_stacks;
1132
+ llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
1133
+ grammar.stacks = std::move(stacks_new);
533
1134
  }
534
1135
 
535
- grammar->partial_utf8 = decoded.second;
536
- LM_GGML_ASSERT(!grammar->stacks.empty());
537
-
538
- smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
1136
+ grammar.partial_utf8 = decoded.second;
1137
+ LM_GGML_ASSERT(!grammar.stacks.empty());
539
1138
  }