cui-llama.rn 1.0.3 → 1.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. package/README.md +35 -39
  2. package/android/src/main/CMakeLists.txt +12 -2
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +29 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +33 -1
  5. package/android/src/main/jni.cpp +62 -8
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  8. package/cpp/common.cpp +3237 -3231
  9. package/cpp/common.h +469 -468
  10. package/cpp/ggml-aarch64.c +2193 -2193
  11. package/cpp/ggml-aarch64.h +39 -39
  12. package/cpp/ggml-alloc.c +1036 -1042
  13. package/cpp/ggml-backend-impl.h +153 -153
  14. package/cpp/ggml-backend.c +2240 -2234
  15. package/cpp/ggml-backend.h +238 -238
  16. package/cpp/ggml-common.h +1833 -1829
  17. package/cpp/ggml-impl.h +755 -655
  18. package/cpp/ggml-metal.h +65 -65
  19. package/cpp/ggml-metal.m +3269 -3269
  20. package/cpp/ggml-quants.c +14872 -14860
  21. package/cpp/ggml-quants.h +132 -132
  22. package/cpp/ggml.c +22055 -22044
  23. package/cpp/ggml.h +2453 -2447
  24. package/cpp/llama-grammar.cpp +539 -0
  25. package/cpp/llama-grammar.h +39 -0
  26. package/cpp/llama-impl.h +26 -0
  27. package/cpp/llama-sampling.cpp +635 -0
  28. package/cpp/llama-sampling.h +56 -0
  29. package/cpp/llama-vocab.cpp +1721 -0
  30. package/cpp/llama-vocab.h +130 -0
  31. package/cpp/llama.cpp +19171 -21892
  32. package/cpp/llama.h +1240 -1217
  33. package/cpp/log.h +737 -737
  34. package/cpp/rn-llama.hpp +207 -29
  35. package/cpp/sampling.cpp +460 -460
  36. package/cpp/sgemm.cpp +1027 -1027
  37. package/cpp/sgemm.h +14 -14
  38. package/cpp/unicode.cpp +6 -0
  39. package/cpp/unicode.h +3 -0
  40. package/ios/RNLlama.mm +15 -6
  41. package/ios/RNLlamaContext.h +2 -8
  42. package/ios/RNLlamaContext.mm +41 -34
  43. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  44. package/lib/commonjs/chat.js +37 -0
  45. package/lib/commonjs/chat.js.map +1 -0
  46. package/lib/commonjs/index.js +14 -1
  47. package/lib/commonjs/index.js.map +1 -1
  48. package/lib/module/NativeRNLlama.js.map +1 -1
  49. package/lib/module/chat.js +31 -0
  50. package/lib/module/chat.js.map +1 -0
  51. package/lib/module/index.js +14 -1
  52. package/lib/module/index.js.map +1 -1
  53. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  54. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  55. package/lib/typescript/chat.d.ts +10 -0
  56. package/lib/typescript/chat.d.ts.map +1 -0
  57. package/lib/typescript/index.d.ts +9 -2
  58. package/lib/typescript/index.d.ts.map +1 -1
  59. package/package.json +1 -1
  60. package/src/NativeRNLlama.ts +10 -1
  61. package/src/chat.ts +44 -0
  62. package/src/index.ts +31 -4
package/cpp/llama-grammar.cpp (new file)
@@ -0,0 +1,539 @@
+ #include "llama-grammar.h"
+
+ #include "llama-vocab.h"
+ #include "llama-sampling.h"
+
+ #include <algorithm>
+
+ // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+ // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
+ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+         const std::string & src,
+         llama_partial_utf8 partial_start) {
+     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+     const char * pos = src.c_str();
+     std::vector<uint32_t> code_points;
+
+     // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
+     code_points.reserve(src.size() + 1);
+     uint32_t value = partial_start.value;
+     int n_remain = partial_start.n_remain;
+
+     // continue previous decode, if applicable
+     while (*pos != 0 && n_remain > 0) {
+         uint8_t next_byte = static_cast<uint8_t>(*pos);
+         if ((next_byte >> 6) != 2) {
+             // invalid sequence, abort
+             code_points.push_back(0);
+             return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
+         }
+         value = (value << 6) + (next_byte & 0x3F);
+         ++pos;
+         --n_remain;
+     }
+
+     if (partial_start.n_remain > 0 && n_remain == 0) {
+         code_points.push_back(value);
+     }
+
+     // decode any subsequent utf-8 sequences, which may end in an incomplete one
+     while (*pos != 0) {
+         uint8_t first_byte = static_cast<uint8_t>(*pos);
+         uint8_t highbits = first_byte >> 4;
+         n_remain = lookup[highbits] - 1;
+
+         if (n_remain < 0) {
+             // invalid sequence, abort
+             code_points.clear();
+             code_points.push_back(0);
+             return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
+         }
+
+         uint8_t mask = (1 << (7 - n_remain)) - 1;
+         value = first_byte & mask;
+
+         ++pos;
+         while (*pos != 0 && n_remain > 0) {
+             value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+             ++pos;
+             --n_remain;
+         }
+         if (n_remain == 0) {
+             code_points.push_back(value);
+         }
+     }
+     code_points.push_back(0);
+
+     return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
+ }
+
+ const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
+     return grammar->rules;
+ }
+
+ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
+     return grammar->stacks;
+ }
+
+ // returns true iff pos points to the end of one of the definitions of a rule
+ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
+     switch (pos->type) {
+         case LLAMA_GRETYPE_END: return true; // NOLINT
+         case LLAMA_GRETYPE_ALT: return true; // NOLINT
+         default:                return false;
+     }
+ }
+
+ // returns true iff chr satisfies the char range at pos (regular or inverse range)
+ // asserts that pos is pointing to a char range element
+ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
+         const llama_grammar_element * pos,
+         const uint32_t chr) {
+
+     bool found = false;
+     bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
+
+     LM_GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
+
+     do {
+         if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+             // inclusive range, e.g. [a-z]
+             found = found || (pos->value <= chr && chr <= pos[1].value);
+             pos += 2;
+         } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+             // Any character matches "."
+             found = true;
+             pos += 1;
+         } else {
+             // exact char match, e.g. [a] or "a"
+             found = found || pos->value == chr;
+             pos += 1;
+         }
+     } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
+
+     return std::make_pair(found == is_positive_char, pos);
+ }
+
+ // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+ // range at pos (regular or inverse range)
+ // asserts that pos is pointing to a char range element
+ static bool llama_grammar_match_partial_char(
+         const llama_grammar_element * pos,
+         const llama_partial_utf8 partial_utf8) {
+     bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
+     LM_GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
+
+     uint32_t partial_value = partial_utf8.value;
+     int n_remain = partial_utf8.n_remain;
+
+     // invalid sequence or 7-bit char split across 2 bytes (overlong)
+     if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+         return false;
+     }
+
+     // range of possible code points this partial UTF-8 sequence could complete to
+     uint32_t low  = partial_value << (n_remain * 6);
+     uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+     if (low == 0) {
+         if (n_remain == 2) {
+             low = 1 << 11;
+         } else if (n_remain == 3) {
+             low = 1 << 16;
+         }
+     }
+
+     do {
+         if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+             // inclusive range, e.g. [a-z]
+             if (pos->value <= high && low <= pos[1].value) {
+                 return is_positive_char;
+             }
+             pos += 2;
+         } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+             // Any character matches "."
+             return true;
+         } else {
+             // exact char match, e.g. [a] or "a"
+             if (low <= pos->value && pos->value <= high) {
+                 return is_positive_char;
+             }
+             pos += 1;
+         }
+     } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
+
+     return !is_positive_char;
+ }
+
+ // transforms a grammar pushdown stack into N possible stacks, all ending
+ // at a character range (terminal element)
+ static void llama_grammar_advance_stack(
+         const llama_grammar_rules & rules,
+         const llama_grammar_stack & stack,
+         llama_grammar_stacks & new_stacks) {
+     if (stack.empty()) {
+         if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+             new_stacks.emplace_back(stack);
+         }
+         return;
+     }
+
+     const llama_grammar_element * pos = stack.back();
+
+     switch (pos->type) {
+         case LLAMA_GRETYPE_RULE_REF: {
+             const size_t rule_id = static_cast<size_t>(pos->value);
+             const llama_grammar_element * subpos = rules[rule_id].data();
+             do {
+                 // init new stack without the top (pos)
+                 llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                 if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                     // if this rule ref is followed by another element, add that to stack
+                     new_stack.push_back(pos + 1);
+                 }
+                 if (!llama_grammar_is_end_of_sequence(subpos)) {
+                     // if alternate is nonempty, add to stack
+                     new_stack.push_back(subpos);
+                 }
+                 llama_grammar_advance_stack(rules, new_stack, new_stacks);
+                 while (!llama_grammar_is_end_of_sequence(subpos)) {
+                     // scan to end of alternate def
+                     subpos++;
+                 }
+                 if (subpos->type == LLAMA_GRETYPE_ALT) {
+                     // there's another alternate def of this rule to process
+                     subpos++;
+                 } else {
+                     break;
+                 }
+             } while (true);
+             break;
+         }
+         case LLAMA_GRETYPE_CHAR:
+         case LLAMA_GRETYPE_CHAR_NOT:
+         case LLAMA_GRETYPE_CHAR_ANY:
+             if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+                 // only add the stack if it's not a duplicate of one we already have
+                 new_stacks.emplace_back(stack);
+             }
+             break;
+         default:
+             // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
+             // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+             // those
+             LM_GGML_ABORT("fatal error");
+     }
+ }
+
+ // takes a set of possible pushdown stacks on a grammar, which are required to
+ // be positioned at a character range (see `llama_grammar_advance_stack`), and
+ // produces the N possible stacks if the given char is accepted at those
+ // positions
+ void llama_grammar_accept(
+         const llama_grammar_rules & rules,
+         const llama_grammar_stacks & stacks,
+         const uint32_t chr,
+         llama_grammar_stacks & new_stacks) {
+     new_stacks.clear();
+
+     for (const auto & stack : stacks) {
+         if (stack.empty()) {
+             continue;
+         }
+
+         auto match = llama_grammar_match_char(stack.back(), chr);
+         if (match.first) {
+             const llama_grammar_element * pos = match.second;
+
+             // update top of stack to next element, if any
+             llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+             if (!llama_grammar_is_end_of_sequence(pos)) {
+                 new_stack.push_back(pos);
+             }
+             llama_grammar_advance_stack(rules, new_stack, new_stacks);
+         }
+     }
+ }
+
+ static llama_grammar_candidates llama_grammar_reject_candidates(
+         const llama_grammar_rules & rules,
+         const llama_grammar_stacks & stacks,
+         const llama_grammar_candidates & candidates) {
+     LM_GGML_ASSERT(!stacks.empty()); // REVIEW
+
+     if (candidates.empty()) {
+         return {};
+     }
+
+     auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+     for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+         rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+     }
+     return rejects;
+ }
+
+ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
+         const llama_grammar_rules & rules,
+         const llama_grammar_stack & stack,
+         const llama_grammar_candidates & candidates) {
+
+     llama_grammar_candidates rejects;
+     rejects.reserve(candidates.size());
+
+     if (stack.empty()) {
+         for (const auto & tok : candidates) {
+             if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+                 rejects.push_back(tok);
+             }
+         }
+         return rejects;
+     }
+
+     const llama_grammar_element * stack_pos = stack.back();
+
+     llama_grammar_candidates next_candidates;
+     next_candidates.reserve(candidates.size());
+
+     for (const auto & tok : candidates) {
+         if (*tok.code_points == 0) {
+             // reached end of full codepoints in token, reject iff it ended in a partial sequence
+             // that cannot satisfy this position in grammar
+             if (tok.partial_utf8.n_remain != 0 &&
+                     !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                 rejects.push_back(tok);
+             }
+         } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
+             next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
+         } else {
+             rejects.push_back(tok);
+         }
+     }
+
+     const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
+
+     // update top of stack to next element, if any
+     llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
+     if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
+         stack_after.push_back(stack_pos_after);
+     }
+     llama_grammar_stacks next_stacks;
+     llama_grammar_advance_stack(rules, stack_after, next_stacks);
+
+     auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
+     for (const auto & tok : next_rejects) {
+         rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
+     }
+
+     return rejects;
+ }
+
+ static bool llama_grammar_detect_left_recursion(
+         const llama_grammar_rules & rules,
+         size_t rule_index,
+         std::vector<bool> * rules_visited,
+         std::vector<bool> * rules_in_progress,
+         std::vector<bool> * rules_may_be_empty) {
+     if ((*rules_in_progress)[rule_index]) {
+         return true;
+     }
+
+     (*rules_in_progress)[rule_index] = true;
+
+     const llama_grammar_rule & rule = rules[rule_index];
+
+     // First check if the rule might produce the empty string. This could be done combined with the second
+     // step but it's more readable as two steps.
+     bool at_rule_start = true;
+     for (size_t i = 0; i < rule.size(); i++) {
+         if (llama_grammar_is_end_of_sequence(&rule[i])) {
+             if (at_rule_start) {
+                 (*rules_may_be_empty)[rule_index] = true;
+                 break;
+             }
+             at_rule_start = true;
+         } else {
+             at_rule_start = false;
+         }
+     }
+
+     // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
+     // be empty)
+     bool recurse_into_nonterminal = true;
+     for (size_t i = 0; i < rule.size(); i++) {
+         if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
+             if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
+                 return true;
+             }
+             if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
+                 recurse_into_nonterminal = false;
+             }
+         } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
+             recurse_into_nonterminal = true;
+         } else {
+             recurse_into_nonterminal = false;
+         }
+     }
+
+     (*rules_in_progress)[rule_index] = false;
+     (*rules_visited)[rule_index] = true;
+     return false;
+ }
+
+ //
+ // grammar - external
+ //
+
+ struct llama_grammar * llama_grammar_init_impl(
+         const llama_grammar_element ** rules,
+         size_t n_rules,
+         size_t start_rule_index) {
+     const llama_grammar_element * pos;
+
+     // copy rule definitions into vectors
+     llama_grammar_rules vec_rules(n_rules);
+     for (size_t i = 0; i < n_rules; i++) {
+         for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+             vec_rules[i].push_back(*pos);
+         }
+         vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
+     }
+
+     // Check for left recursion
+     std::vector<bool> rules_visited(n_rules);
+     std::vector<bool> rules_in_progress(n_rules);
+     std::vector<bool> rules_may_be_empty(n_rules);
+     for (size_t i = 0; i < n_rules; i++) {
+         if (rules_visited[i]) {
+             continue;
+         }
+         if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+             LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+             return nullptr;
+         }
+     }
+
+     // loop over alternates of start rule to build initial stacks
+     llama_grammar_stacks stacks;
+     pos = vec_rules[start_rule_index].data();
+     do {
+         llama_grammar_stack stack;
+         if (!llama_grammar_is_end_of_sequence(pos)) {
+             // if alternate is nonempty, add to stack
+             stack.push_back(pos);
+         }
+         llama_grammar_advance_stack(vec_rules, stack, stacks);
+         while (!llama_grammar_is_end_of_sequence(pos)) {
+             // scan to end of alternate def
+             pos++;
+         }
+         if (pos->type == LLAMA_GRETYPE_ALT) {
+             // there's another alternate def of this rule to process
+             pos++;
+         } else {
+             break;
+         }
+     } while (true);
+
+     // Important: vec_rules has to be moved here, not copied, because stacks contains
+     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+     // then the pointers would be invalidated when the local vec_rules goes out of scope.
+     return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+ }
+
+ void llama_grammar_free_impl(struct llama_grammar * grammar) {
+     delete grammar;
+ }
+
+ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
+     llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+
+     // redirect elements in stacks to point to new rules
+     for (size_t is = 0; is < result->stacks.size(); is++) {
+         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
+             for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
+                 for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
+                     if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+                         result->stacks[is][ie] = &result->rules[ir0][ir1];
+                     }
+                 }
+             }
+         }
+     }
+
+     return result;
+ }
+
+ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
+     LM_GGML_ASSERT(grammar);
+     LM_GGML_ASSERT(vocab);
+
+     int64_t t_start_sample_us = lm_ggml_time_us();
+
+     bool allow_eog = false;
+     for (const auto & stack : grammar->stacks) {
+         if (stack.empty()) {
+             allow_eog = true;
+             break;
+         }
+     }
+
+     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
+     candidates_decoded.reserve(candidates->size);
+
+     llama_grammar_candidates candidates_grammar;
+     candidates_grammar.reserve(candidates->size);
+
+     for (size_t i = 0; i < candidates->size; ++i) {
+         const llama_token id = candidates->data[i].id;
+         const std::string & piece = vocab->cache_token_to_piece.at(id);
+
+         if (llama_token_is_eog_impl(*vocab, id)) {
+             if (!allow_eog) {
+                 candidates->data[i].logit = -INFINITY;
+             }
+         } else if (piece.empty() || piece[0] == 0) {
+             candidates->data[i].logit = -INFINITY;
+         } else {
+             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+         }
+     }
+
+     const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+     for (const auto & reject : rejects) {
+         candidates->data[reject.index].logit = -INFINITY;
+     }
+
+     smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+ }
+
+ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
+     const int64_t t_start_sample_us = lm_ggml_time_us();
+
+     if (llama_token_is_eog_impl(*vocab, token)) {
+         for (const auto & stack : grammar->stacks) {
+             if (stack.empty()) {
+                 return;
+             }
+         }
+         LM_GGML_ABORT("fatal error");
+     }
+
+     const std::string & piece = vocab->cache_token_to_piece.at(token);
+
+     // Note terminating 0 in decoded string
+     const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+     const auto & code_points = decoded.first;
+
+     llama_grammar_stacks tmp_new_stacks;
+     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+         llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
+         grammar->stacks = tmp_new_stacks;
+     }
+
+     grammar->partial_utf8 = decoded.second;
+     LM_GGML_ASSERT(!grammar->stacks.empty());
+
+     smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+ }
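For reference, a minimal sketch (not part of the published diff) of how decode_utf8 threads an incomplete multi-byte sequence across calls, assuming the definitions above are in scope. U+20AC ("€") encodes as the three bytes E2 82 AC:

#include <cassert>
#include <string>

static void demo_decode_utf8_partial() {
    // first chunk stops after two of the three bytes of U+20AC
    auto first = decode_utf8(std::string("a\xE2\x82"), llama_partial_utf8{ 0, 0 });
    assert(first.first[0] == 'a');       // the complete code point was decoded
    assert(first.second.n_remain == 1);  // one continuation byte still owed

    // feeding the carried state plus the final byte completes the code point
    auto second = decode_utf8(std::string("\xAC"), first.second);
    assert(second.first[0] == 0x20AC);
    assert(second.second.n_remain == 0); // sequence fully consumed
}

This is the same mechanism llama_grammar_sample_impl and llama_grammar_accept_token_impl rely on when a token piece ends mid-sequence: the leftover state is carried in grammar->partial_utf8 and resumed on the next token.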
package/cpp/llama-grammar.h (new file)
@@ -0,0 +1,39 @@
+ #pragma once
+
+ #include "llama-impl.h"
+
+ struct llama_vocab;
+ struct llama_sampling;
+
+ struct llama_grammar {
+     const llama_grammar_rules rules;
+     llama_grammar_stacks stacks;
+
+     // buffer for partially generated UTF-8 sequence from accepted tokens
+     llama_partial_utf8 partial_utf8;
+ };
+
+ //
+ // internal API
+ //
+
+ struct llama_grammar * llama_grammar_init_impl(
+         const llama_grammar_element ** rules,
+         size_t n_rules,
+         size_t start_rule_index);
+
+ void llama_grammar_free_impl(struct llama_grammar * grammar);
+
+ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
+
+ void llama_grammar_sample_impl(
+         const struct llama_grammar * grammar,
+         const struct llama_vocab * vocab,
+         const struct llama_sampling * smpl,
+         llama_token_data_array * candidates);
+
+ void llama_grammar_accept_token_impl(
+         struct llama_grammar * grammar,
+         const struct llama_vocab * vocab,
+         const struct llama_sampling * smpl,
+         llama_token token);
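As a reading aid, a minimal sketch (not part of the published diff) of the init/copy/free lifecycle declared above, using a hypothetical one-rule grammar root ::= "hi". It assumes llama.h's llama_grammar_element { type, value } layout:

#include <vector>

static void demo_grammar_lifecycle() {
    std::vector<llama_grammar_element> root = {
        { LLAMA_GRETYPE_CHAR, 'h' },
        { LLAMA_GRETYPE_CHAR, 'i' },
        { LLAMA_GRETYPE_END,  0   },
    };
    const llama_grammar_element * rules[] = { root.data() };

    // returns nullptr if left recursion is detected among the rules
    llama_grammar * grammar = llama_grammar_init_impl(rules, /*n_rules=*/1, /*start_rule_index=*/0);
    if (grammar == nullptr) {
        return;
    }

    // the copy redirects stack pointers into its own copy of the rules,
    // so the clone stays valid after the original is freed
    llama_grammar * clone = llama_grammar_copy_impl(grammar);
    llama_grammar_free_impl(grammar);
    llama_grammar_free_impl(clone);
}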
package/cpp/llama-impl.h (new file)
@@ -0,0 +1,26 @@
+ #pragma once
+
+ #define LLAMA_API_INTERNAL
+ #include "llama.h"
+
+ #ifdef __GNUC__
+ #ifdef __MINGW32__
+ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+ #else
+ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+ #endif
+ #else
+ #define LLAMA_ATTRIBUTE_FORMAT(...)
+ #endif
+
+ //
+ // logging
+ //
+
+ LLAMA_ATTRIBUTE_FORMAT(2, 3)
+ void llama_log_internal        (lm_ggml_log_level level, const char * format, ...);
+ void llama_log_callback_default(lm_ggml_log_level level, const char * text, void * user_data);
+
+ #define LLAMA_LOG_INFO(...)  llama_log_internal(LM_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+ #define LLAMA_LOG_WARN(...)  llama_log_internal(LM_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+ #define LLAMA_LOG_ERROR(...) llama_log_internal(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
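A minimal sketch (not part of the published diff) of a call site for the macros above; LLAMA_ATTRIBUTE_FORMAT(2, 3) lets GCC/Clang type-check the printf-style arguments, which is how the %zu in llama-grammar.cpp's left-recursion error is verified at compile time:

static void report_left_recursion(size_t rule_index) {
    // expands to llama_log_internal(LM_GGML_LOG_LEVEL_ERROR, ...)
    LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", rule_index);
}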