cui-llama.rn 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. package/LICENSE +20 -0
  2. package/README.md +330 -0
  3. package/android/build.gradle +107 -0
  4. package/android/gradle.properties +5 -0
  5. package/android/src/main/AndroidManifest.xml +4 -0
  6. package/android/src/main/CMakeLists.txt +69 -0
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +353 -0
  8. package/android/src/main/java/com/rnllama/RNLlama.java +446 -0
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -0
  10. package/android/src/main/jni.cpp +635 -0
  11. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +94 -0
  12. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +95 -0
  13. package/cpp/README.md +4 -0
  14. package/cpp/common.cpp +3237 -0
  15. package/cpp/common.h +467 -0
  16. package/cpp/ggml-aarch64.c +2193 -0
  17. package/cpp/ggml-aarch64.h +39 -0
  18. package/cpp/ggml-alloc.c +1041 -0
  19. package/cpp/ggml-alloc.h +76 -0
  20. package/cpp/ggml-backend-impl.h +153 -0
  21. package/cpp/ggml-backend.c +2225 -0
  22. package/cpp/ggml-backend.h +236 -0
  23. package/cpp/ggml-common.h +1829 -0
  24. package/cpp/ggml-impl.h +655 -0
  25. package/cpp/ggml-metal.h +65 -0
  26. package/cpp/ggml-metal.m +3273 -0
  27. package/cpp/ggml-quants.c +15022 -0
  28. package/cpp/ggml-quants.h +132 -0
  29. package/cpp/ggml.c +22034 -0
  30. package/cpp/ggml.h +2444 -0
  31. package/cpp/grammar-parser.cpp +536 -0
  32. package/cpp/grammar-parser.h +29 -0
  33. package/cpp/json-schema-to-grammar.cpp +1045 -0
  34. package/cpp/json-schema-to-grammar.h +8 -0
  35. package/cpp/json.hpp +24766 -0
  36. package/cpp/llama.cpp +21789 -0
  37. package/cpp/llama.h +1201 -0
  38. package/cpp/log.h +737 -0
  39. package/cpp/rn-llama.hpp +630 -0
  40. package/cpp/sampling.cpp +460 -0
  41. package/cpp/sampling.h +160 -0
  42. package/cpp/sgemm.cpp +1027 -0
  43. package/cpp/sgemm.h +14 -0
  44. package/cpp/unicode-data.cpp +7032 -0
  45. package/cpp/unicode-data.h +20 -0
  46. package/cpp/unicode.cpp +812 -0
  47. package/cpp/unicode.h +64 -0
  48. package/ios/RNLlama.h +11 -0
  49. package/ios/RNLlama.mm +302 -0
  50. package/ios/RNLlama.xcodeproj/project.pbxproj +278 -0
  51. package/ios/RNLlamaContext.h +39 -0
  52. package/ios/RNLlamaContext.mm +426 -0
  53. package/jest/mock.js +169 -0
  54. package/lib/commonjs/NativeRNLlama.js +10 -0
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -0
  56. package/lib/commonjs/grammar.js +574 -0
  57. package/lib/commonjs/grammar.js.map +1 -0
  58. package/lib/commonjs/index.js +151 -0
  59. package/lib/commonjs/index.js.map +1 -0
  60. package/lib/module/NativeRNLlama.js +3 -0
  61. package/lib/module/NativeRNLlama.js.map +1 -0
  62. package/lib/module/grammar.js +566 -0
  63. package/lib/module/grammar.js.map +1 -0
  64. package/lib/module/index.js +129 -0
  65. package/lib/module/index.js.map +1 -0
  66. package/lib/typescript/NativeRNLlama.d.ts +107 -0
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -0
  68. package/lib/typescript/grammar.d.ts +38 -0
  69. package/lib/typescript/grammar.d.ts.map +1 -0
  70. package/lib/typescript/index.d.ts +46 -0
  71. package/lib/typescript/index.d.ts.map +1 -0
  72. package/llama-rn.podspec +56 -0
  73. package/package.json +230 -0
  74. package/src/NativeRNLlama.ts +132 -0
  75. package/src/grammar.ts +849 -0
  76. package/src/index.ts +182 -0
@@ -0,0 +1,536 @@
1
+ #include "grammar-parser.h"
2
+ #include <cstdint>
3
+ #include <cwchar>
4
+ #include <string>
5
+ #include <utility>
6
+ #include <stdexcept>
7
+ #include <exception>
8
+
9
+ namespace grammar_parser {
10
// Decodes one UTF-8 code point starting at `src`.
// Returns the code point together with a pointer just past the bytes consumed.
// NOTE: assumes valid UTF-8 -- the sequence length is derived from the lead
// byte alone -- but decoding stops early at a NUL terminator, so a truncated
// sequence at the end of the string is not read past (overrun check).
// copied from llama.cpp
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    // sequence length indexed by the high nibble of the lead byte
    static const int seq_len_by_nibble[16] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    const uint8_t lead = static_cast<uint8_t>(src[0]);
    const int n_bytes = seq_len_by_nibble[lead >> 4];
    // keep only the payload bits of the lead byte
    uint32_t code_point = lead & static_cast<uint8_t>((1 << (8 - n_bytes)) - 1);
    const char * cur = src + 1;
    for (int i = 1; i < n_bytes && *cur != '\0'; ++i, ++cur) {
        // each continuation byte contributes its low 6 bits
        code_point = (code_point << 6) + (static_cast<uint8_t>(*cur) & 0x3F);
    }
    return std::make_pair(code_point, cur);
}
26
+
27
+ static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
28
+ uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
29
+ auto result = state.symbol_ids.emplace(std::string(src, len), next_id);
30
+ return result.first->second;
31
+ }
32
+
33
+ static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
34
+ uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
35
+ state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
36
+ return next_id;
37
+ }
38
+
39
+ static void add_rule(
40
+ parse_state & state,
41
+ uint32_t rule_id,
42
+ const std::vector<llama_grammar_element> & rule) {
43
+ if (state.rules.size() <= rule_id) {
44
+ state.rules.resize(rule_id + 1);
45
+ }
46
+ state.rules[rule_id] = rule;
47
+ }
48
+
49
// True for the ASCII decimal digits '0' through '9'.
static bool is_digit_char(char c) {
    return c >= '0' && c <= '9';
}
52
+
53
+ static bool is_word_char(char c) {
54
+ return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
55
+ }
56
+
57
// Parses exactly `size` hexadecimal digits (either case) starting at `src`.
// Returns the numeric value and a pointer just past the digits.
// Throws std::runtime_error if a non-hex character or the end of the string
// is reached before `size` digits have been read.
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
    uint32_t value = 0;
    int consumed = 0;
    while (consumed < size) {
        const char c = src[consumed];
        int digit;
        if (c >= '0' && c <= '9') {
            digit = c - '0';
        } else if (c >= 'a' && c <= 'f') {
            digit = c - 'a' + 10;
        } else if (c >= 'A' && c <= 'F') {
            digit = c - 'A' + 10;
        } else {
            break;  // non-hex character, including the NUL terminator
        }
        value = (value << 4) + static_cast<uint32_t>(digit);
        ++consumed;
    }
    if (consumed != size) {
        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
    }
    return std::make_pair(value, src + consumed);
}
79
+
80
// Skips blanks and '#' comments starting at `src` and returns the first
// position that is neither. Spaces and tabs are always skipped; '\r' and '\n'
// are skipped only when `newline_ok` is true. A comment extends up to, but
// not including, the next '\r', '\n' or the end of the string.
static const char * parse_space(const char * src, bool newline_ok) {
    const char * cur = src;
    for (;;) {
        const char c = *cur;
        if (c == '#') {
            // skip the comment body; the line break (if any) is examined on the next pass
            while (*cur != '\0' && *cur != '\r' && *cur != '\n') {
                ++cur;
            }
        } else if (c == ' ' || c == '\t' || (newline_ok && (c == '\r' || c == '\n'))) {
            ++cur;
        } else {
            return cur;
        }
    }
}
94
+
95
+ static const char * parse_name(const char * src) {
96
+ const char * pos = src;
97
+ while (is_word_char(*pos)) {
98
+ pos++;
99
+ }
100
+ if (pos == src) {
101
+ throw std::runtime_error(std::string("expecting name at ") + src);
102
+ }
103
+ return pos;
104
+ }
105
+
106
+ static const char * parse_int(const char * src) {
107
+ const char * pos = src;
108
+ while (is_digit_char(*pos)) {
109
+ pos++;
110
+ }
111
+ if (pos == src) {
112
+ throw std::runtime_error(std::string("expecting integer at ") + src);
113
+ }
114
+ return pos;
115
+ }
116
+
117
+ static std::pair<uint32_t, const char *> parse_char(const char * src) {
118
+ if (*src == '\\') {
119
+ switch (src[1]) {
120
+ case 'x': return parse_hex(src + 2, 2);
121
+ case 'u': return parse_hex(src + 2, 4);
122
+ case 'U': return parse_hex(src + 2, 8);
123
+ case 't': return std::make_pair('\t', src + 2);
124
+ case 'r': return std::make_pair('\r', src + 2);
125
+ case 'n': return std::make_pair('\n', src + 2);
126
+ case '\\':
127
+ case '"':
128
+ case '[':
129
+ case ']':
130
+ return std::make_pair(src[1], src + 2);
131
+ default:
132
+ throw std::runtime_error(std::string("unknown escape at ") + src);
133
+ }
134
+ } else if (*src) {
135
+ return decode_utf8(src);
136
+ }
137
+ throw std::runtime_error("unexpected end of input");
138
+ }
139
+
140
+ const char * parse_alternates(
141
+ parse_state & state,
142
+ const char * src,
143
+ const std::string & rule_name,
144
+ uint32_t rule_id,
145
+ bool is_nested);
146
+
147
+ static const char * parse_sequence(
148
+ parse_state & state,
149
+ const char * src,
150
+ const std::string & rule_name,
151
+ std::vector<llama_grammar_element> & out_elements,
152
+ bool is_nested) {
153
+ size_t last_sym_start = out_elements.size();
154
+ const char * pos = src;
155
+
156
+ auto handle_repetitions = [&](int min_times, int max_times) {
157
+
158
+ if (last_sym_start == out_elements.size()) {
159
+ throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
160
+ }
161
+
162
+ // apply transformation to previous symbol (last_sym_start to end) according to
163
+ // the following rewrite rules:
164
+ // S{m,n} --> S S S (m times) S'(n-m)
165
+ // S'(x) ::= S S'(x-1) |
166
+ // (... n-m definitions of these S' rules ...)
167
+ // S'(1) ::= S |
168
+ // S{m,} --> S S S (m times) S'
169
+ // S' ::= S S' |
170
+ // S* --> S{0,}
171
+ // --> S' ::= S S' |
172
+ // S+ --> S{1,}
173
+ // --> S S'
174
+ // S' ::= S S' |
175
+ // S? --> S{0,1}
176
+ // --> S'
177
+ // S' ::= S |
178
+
179
+ std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
180
+ if (min_times == 0) {
181
+ out_elements.resize(last_sym_start);
182
+ } else {
183
+ // Repeat the previous elements (min_times - 1) times
184
+ for (int i = 1; i < min_times; i++) {
185
+ out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
186
+ }
187
+ }
188
+
189
+ uint32_t last_rec_rule_id = 0;
190
+ auto n_opt = max_times < 0 ? 1 : max_times - min_times;
191
+
192
+ std::vector<llama_grammar_element> rec_rule(previous_elements);
193
+ for (int i = 0; i < n_opt; i++) {
194
+ rec_rule.resize(previous_elements.size());
195
+ uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
196
+ if (i > 0 || max_times < 0) {
197
+ rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
198
+ }
199
+ rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
200
+ rec_rule.push_back({LLAMA_GRETYPE_END, 0});
201
+ add_rule(state, rec_rule_id, rec_rule);
202
+ last_rec_rule_id = rec_rule_id;
203
+ }
204
+ if (n_opt > 0) {
205
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
206
+ }
207
+ };
208
+
209
+ while (*pos) {
210
+ if (*pos == '"') { // literal string
211
+ pos++;
212
+ last_sym_start = out_elements.size();
213
+ while (*pos != '"') {
214
+ if (!*pos) {
215
+ throw std::runtime_error("unexpected end of input");
216
+ }
217
+ auto char_pair = parse_char(pos);
218
+ pos = char_pair.second;
219
+ out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
220
+ }
221
+ pos = parse_space(pos + 1, is_nested);
222
+ } else if (*pos == '[') { // char range(s)
223
+ pos++;
224
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
225
+ if (*pos == '^') {
226
+ pos++;
227
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
228
+ }
229
+ last_sym_start = out_elements.size();
230
+ while (*pos != ']') {
231
+ if (!*pos) {
232
+ throw std::runtime_error("unexpected end of input");
233
+ }
234
+ auto char_pair = parse_char(pos);
235
+ pos = char_pair.second;
236
+ enum llama_gretype type = last_sym_start < out_elements.size()
237
+ ? LLAMA_GRETYPE_CHAR_ALT
238
+ : start_type;
239
+
240
+ out_elements.push_back({type, char_pair.first});
241
+ if (pos[0] == '-' && pos[1] != ']') {
242
+ if (!pos[1]) {
243
+ throw std::runtime_error("unexpected end of input");
244
+ }
245
+ auto endchar_pair = parse_char(pos + 1);
246
+ pos = endchar_pair.second;
247
+ out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
248
+ }
249
+ }
250
+ pos = parse_space(pos + 1, is_nested);
251
+ } else if (is_word_char(*pos)) { // rule reference
252
+ const char * name_end = parse_name(pos);
253
+ uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
254
+ pos = parse_space(name_end, is_nested);
255
+ last_sym_start = out_elements.size();
256
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
257
+ } else if (*pos == '(') { // grouping
258
+ // parse nested alternates into synthesized rule
259
+ pos = parse_space(pos + 1, true);
260
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
261
+ pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
262
+ last_sym_start = out_elements.size();
263
+ // output reference to synthesized rule
264
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
265
+ if (*pos != ')') {
266
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
267
+ }
268
+ pos = parse_space(pos + 1, is_nested);
269
+ } else if (*pos == '.') { // any char
270
+ last_sym_start = out_elements.size();
271
+ out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
272
+ pos = parse_space(pos + 1, is_nested);
273
+ } else if (*pos == '*') {
274
+ pos = parse_space(pos + 1, is_nested);
275
+ handle_repetitions(0, -1);
276
+ } else if (*pos == '+') {
277
+ pos = parse_space(pos + 1, is_nested);
278
+ handle_repetitions(1, -1);
279
+ } else if (*pos == '?') {
280
+ pos = parse_space(pos + 1, is_nested);
281
+ handle_repetitions(0, 1);
282
+ } else if (*pos == '{') {
283
+ pos = parse_space(pos + 1, is_nested);
284
+
285
+ if (!is_digit_char(*pos)) {
286
+ throw std::runtime_error(std::string("expecting an int at ") + pos);
287
+ }
288
+ const char * int_end = parse_int(pos);
289
+ int min_times = std::stoul(std::string(pos, int_end - pos));
290
+ pos = parse_space(int_end, is_nested);
291
+
292
+ int max_times = -1;
293
+
294
+ if (*pos == '}') {
295
+ max_times = min_times;
296
+ pos = parse_space(pos + 1, is_nested);
297
+ } else if (*pos == ',') {
298
+ pos = parse_space(pos + 1, is_nested);
299
+
300
+ if (is_digit_char(*pos)) {
301
+ const char * int_end = parse_int(pos);
302
+ max_times = std::stoul(std::string(pos, int_end - pos));
303
+ pos = parse_space(int_end, is_nested);
304
+ }
305
+
306
+ if (*pos != '}') {
307
+ throw std::runtime_error(std::string("expecting '}' at ") + pos);
308
+ }
309
+ pos = parse_space(pos + 1, is_nested);
310
+ } else {
311
+ throw std::runtime_error(std::string("expecting ',' at ") + pos);
312
+ }
313
+ handle_repetitions(min_times, max_times);
314
+ } else {
315
+ break;
316
+ }
317
+ }
318
+ return pos;
319
+ }
320
+
321
+ const char * parse_alternates(
322
+ parse_state & state,
323
+ const char * src,
324
+ const std::string & rule_name,
325
+ uint32_t rule_id,
326
+ bool is_nested) {
327
+ std::vector<llama_grammar_element> rule;
328
+ const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
329
+ while (*pos == '|') {
330
+ rule.push_back({LLAMA_GRETYPE_ALT, 0});
331
+ pos = parse_space(pos + 1, true);
332
+ pos = parse_sequence(state, pos, rule_name, rule, is_nested);
333
+ }
334
+ rule.push_back({LLAMA_GRETYPE_END, 0});
335
+ add_rule(state, rule_id, rule);
336
+ return pos;
337
+ }
338
+
339
+ static const char * parse_rule(parse_state & state, const char * src) {
340
+ const char * name_end = parse_name(src);
341
+ const char * pos = parse_space(name_end, false);
342
+ size_t name_len = name_end - src;
343
+ uint32_t rule_id = get_symbol_id(state, src, name_len);
344
+ const std::string name(src, name_len);
345
+
346
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
347
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
348
+ }
349
+ pos = parse_space(pos + 3, true);
350
+
351
+ pos = parse_alternates(state, pos, name, rule_id, false);
352
+
353
+ if (*pos == '\r') {
354
+ pos += pos[1] == '\n' ? 2 : 1;
355
+ } else if (*pos == '\n') {
356
+ pos++;
357
+ } else if (*pos) {
358
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
359
+ }
360
+ return parse_space(pos, true);
361
+ }
362
+
363
+ parse_state parse(const char * src) {
364
+ try {
365
+ parse_state state;
366
+ const char * pos = parse_space(src, true);
367
+ while (*pos) {
368
+ pos = parse_rule(state, pos);
369
+ }
370
+ // Validate the state to ensure that all rules are defined
371
+ for (const auto & rule : state.rules) {
372
+ for (const auto & elem : rule) {
373
+ if (elem.type == LLAMA_GRETYPE_RULE_REF) {
374
+ // Ensure that the rule at that location exists
375
+ if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
376
+ // Get the name of the rule that is missing
377
+ for (const auto & kv : state.symbol_ids) {
378
+ if (kv.second == elem.value) {
379
+ throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
380
+ }
381
+ }
382
+ }
383
+ }
384
+ }
385
+ }
386
+ return state;
387
+ } catch (const std::exception & err) {
388
+ fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
389
+ return parse_state();
390
+ }
391
+ }
392
+
393
// Writes code point `c` to `file`. Printable ASCII (0x20..0x7E) is written as
// the character itself. Everything else is written as `<U+XXXX>` rather than
// being UTF-8 encoded: control characters, DEL (0x7F) and non-ASCII code points.
static void print_grammar_char(FILE * file, uint32_t c) {
    // DEL (0x7F) is a control character, so it is excluded from the raw range
    if (0x20 <= c && c < 0x7f) {
        fprintf(file, "%c", static_cast<char>(c));
    } else {
        // cop out of encoding UTF-8; %X expects an unsigned int
        fprintf(file, "<U+%04X>", static_cast<unsigned>(c));
    }
}
401
+
402
+ static bool is_char_element(llama_grammar_element elem) {
403
+ switch (elem.type) {
404
+ case LLAMA_GRETYPE_CHAR: return true;
405
+ case LLAMA_GRETYPE_CHAR_NOT: return true;
406
+ case LLAMA_GRETYPE_CHAR_ALT: return true;
407
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
408
+ case LLAMA_GRETYPE_CHAR_ANY: return true;
409
+ default: return false;
410
+ }
411
+ }
412
+
413
+ static void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
414
+ for (auto elem : rule) {
415
+ switch (elem.type) {
416
+ case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
417
+ case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
418
+ case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
419
+ case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
420
+ case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
421
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
422
+ case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
423
+ case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
424
+ }
425
+ switch (elem.type) {
426
+ case LLAMA_GRETYPE_END:
427
+ case LLAMA_GRETYPE_ALT:
428
+ case LLAMA_GRETYPE_RULE_REF:
429
+ fprintf(file, "(%u) ", elem.value);
430
+ break;
431
+ case LLAMA_GRETYPE_CHAR:
432
+ case LLAMA_GRETYPE_CHAR_NOT:
433
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
434
+ case LLAMA_GRETYPE_CHAR_ALT:
435
+ case LLAMA_GRETYPE_CHAR_ANY:
436
+ fprintf(file, "(\"");
437
+ print_grammar_char(file, elem.value);
438
+ fprintf(file, "\") ");
439
+ break;
440
+ }
441
+ }
442
+ fprintf(file, "\n");
443
+ }
444
+
445
+ static void print_rule(
446
+ FILE * file,
447
+ uint32_t rule_id,
448
+ const std::vector<llama_grammar_element> & rule,
449
+ const std::map<uint32_t, std::string> & symbol_id_names) {
450
+ if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
451
+ throw std::runtime_error(
452
+ "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
453
+ }
454
+ fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
455
+ for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
456
+ llama_grammar_element elem = rule[i];
457
+ switch (elem.type) {
458
+ case LLAMA_GRETYPE_END:
459
+ throw std::runtime_error(
460
+ "unexpected end of rule: " + std::to_string(rule_id) + "," +
461
+ std::to_string(i));
462
+ case LLAMA_GRETYPE_ALT:
463
+ fprintf(file, "| ");
464
+ break;
465
+ case LLAMA_GRETYPE_RULE_REF:
466
+ fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
467
+ break;
468
+ case LLAMA_GRETYPE_CHAR:
469
+ fprintf(file, "[");
470
+ print_grammar_char(file, elem.value);
471
+ break;
472
+ case LLAMA_GRETYPE_CHAR_NOT:
473
+ fprintf(file, "[^");
474
+ print_grammar_char(file, elem.value);
475
+ break;
476
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
477
+ if (i == 0 || !is_char_element(rule[i - 1])) {
478
+ throw std::runtime_error(
479
+ "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
480
+ std::to_string(rule_id) + "," + std::to_string(i));
481
+ }
482
+ fprintf(file, "-");
483
+ print_grammar_char(file, elem.value);
484
+ break;
485
+ case LLAMA_GRETYPE_CHAR_ALT:
486
+ if (i == 0 || !is_char_element(rule[i - 1])) {
487
+ throw std::runtime_error(
488
+ "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
489
+ std::to_string(rule_id) + "," + std::to_string(i));
490
+ }
491
+ print_grammar_char(file, elem.value);
492
+ break;
493
+ case LLAMA_GRETYPE_CHAR_ANY:
494
+ fprintf(file, ".");
495
+ break;
496
+ }
497
+ if (is_char_element(elem)) {
498
+ switch (rule[i + 1].type) {
499
+ case LLAMA_GRETYPE_CHAR_ALT:
500
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
501
+ case LLAMA_GRETYPE_CHAR_ANY:
502
+ break;
503
+ default:
504
+ fprintf(file, "] ");
505
+ }
506
+ }
507
+ }
508
+ fprintf(file, "\n");
509
+ }
510
+
511
+ void print_grammar(FILE * file, const parse_state & state) {
512
+ try {
513
+ std::map<uint32_t, std::string> symbol_id_names;
514
+ for (const auto & kv : state.symbol_ids) {
515
+ symbol_id_names[kv.second] = kv.first;
516
+ }
517
+ for (size_t i = 0, end = state.rules.size(); i < end; i++) {
518
+ // fprintf(file, "%zu: ", i);
519
+ // print_rule_binary(file, state.rules[i]);
520
+ print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
521
+ // fprintf(file, "\n");
522
+ }
523
+ } catch (const std::exception & err) {
524
+ fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
525
+ }
526
+ }
527
+
528
+ std::vector<const llama_grammar_element *> parse_state::c_rules() {
529
+ std::vector<const llama_grammar_element *> ret;
530
+ ret.reserve(rules.size());
531
+ for (const auto & rule : rules) {
532
+ ret.push_back(rule.data());
533
+ }
534
+ return ret;
535
+ }
536
+ }
@@ -0,0 +1,29 @@
1
+ // Implements a parser for an extended Backus-Naur form (BNF), producing the
2
+ // binary context-free grammar format specified by llama.h. Supports character
3
+ // ranges, grouping, and repetition operators. As an example, a grammar for
4
+ // arithmetic might look like:
5
+ //
6
+ // root ::= expr
7
+ // expr ::= term ([-+*/] term)*
8
+ // term ::= num | "(" space expr ")" space
9
+ // num ::= [0-9]+ space
10
+ // space ::= [ \t\n]*
11
+
12
+ #pragma once
13
+ #include "llama.h"
14
+ #include <vector>
15
+ #include <map>
16
+ #include <cstdint>
17
+ #include <string>
18
+
19
+ namespace grammar_parser {
20
+ struct parse_state {
21
+ std::map<std::string, uint32_t> symbol_ids;
22
+ std::vector<std::vector<llama_grammar_element>> rules;
23
+
24
+ std::vector<const llama_grammar_element *> c_rules();
25
+ };
26
+
27
+ parse_state parse(const char * src);
28
+ void print_grammar(FILE * file, const parse_state & state);
29
+ }