cui-llama.rn 1.0.3 → 1.0.6
- package/README.md +35 -39
- package/android/src/main/CMakeLists.txt +12 -2
- package/android/src/main/java/com/rnllama/LlamaContext.java +29 -9
- package/android/src/main/java/com/rnllama/RNLlama.java +33 -1
- package/android/src/main/jni.cpp +62 -8
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
- package/cpp/common.cpp +3237 -3231
- package/cpp/common.h +469 -468
- package/cpp/ggml-aarch64.c +2193 -2193
- package/cpp/ggml-aarch64.h +39 -39
- package/cpp/ggml-alloc.c +1036 -1042
- package/cpp/ggml-backend-impl.h +153 -153
- package/cpp/ggml-backend.c +2240 -2234
- package/cpp/ggml-backend.h +238 -238
- package/cpp/ggml-common.h +1833 -1829
- package/cpp/ggml-impl.h +755 -655
- package/cpp/ggml-metal.h +65 -65
- package/cpp/ggml-metal.m +3269 -3269
- package/cpp/ggml-quants.c +14872 -14860
- package/cpp/ggml-quants.h +132 -132
- package/cpp/ggml.c +22055 -22044
- package/cpp/ggml.h +2453 -2447
- package/cpp/llama-grammar.cpp +539 -0
- package/cpp/llama-grammar.h +39 -0
- package/cpp/llama-impl.h +26 -0
- package/cpp/llama-sampling.cpp +635 -0
- package/cpp/llama-sampling.h +56 -0
- package/cpp/llama-vocab.cpp +1721 -0
- package/cpp/llama-vocab.h +130 -0
- package/cpp/llama.cpp +19171 -21892
- package/cpp/llama.h +1240 -1217
- package/cpp/log.h +737 -737
- package/cpp/rn-llama.hpp +207 -29
- package/cpp/sampling.cpp +460 -460
- package/cpp/sgemm.cpp +1027 -1027
- package/cpp/sgemm.h +14 -14
- package/cpp/unicode.cpp +6 -0
- package/cpp/unicode.h +3 -0
- package/ios/RNLlama.mm +15 -6
- package/ios/RNLlamaContext.h +2 -8
- package/ios/RNLlamaContext.mm +41 -34
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/chat.js +37 -0
- package/lib/commonjs/chat.js.map +1 -0
- package/lib/commonjs/index.js +14 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/chat.js +31 -0
- package/lib/module/chat.js.map +1 -0
- package/lib/module/index.js +14 -1
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +5 -1
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/chat.d.ts +10 -0
- package/lib/typescript/chat.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +9 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +10 -1
- package/src/chat.ts +44 -0
- package/src/index.ts +31 -4
package/cpp/llama-grammar.cpp
ADDED
@@ -0,0 +1,539 @@
+#include "llama-grammar.h"
+
+#include "llama-vocab.h"
+#include "llama-sampling.h"
+
+#include <algorithm>
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+        const std::string & src,
+        llama_partial_utf8 partial_start) {
+    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+    const char * pos = src.c_str();
+    std::vector<uint32_t> code_points;
+
+    // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
+    code_points.reserve(src.size() + 1);
+    uint32_t value = partial_start.value;
+    int n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0) {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2) {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0) {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
+    while (*pos != 0) {
+        uint8_t first_byte = static_cast<uint8_t>(*pos);
+        uint8_t highbits = first_byte >> 4;
+        n_remain = lookup[highbits] - 1;
+
+        if (n_remain < 0) {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
+        }
+
+        uint8_t mask = (1 << (7 - n_remain)) - 1;
+        value = first_byte & mask;
+
+        ++pos;
+        while (*pos != 0 && n_remain > 0) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0) {
+            code_points.push_back(value);
+        }
+    }
+    code_points.push_back(0);
+
+    return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
+}
+
+const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
+    return grammar->rules;
+}
+
+llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
+    return grammar->stacks;
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
+    switch (pos->type) {
+        case LLAMA_GRETYPE_END: return true; // NOLINT
+        case LLAMA_GRETYPE_ALT: return true; // NOLINT
+        default:                return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
+        const llama_grammar_element * pos,
+        const uint32_t chr) {
+
+    bool found = false;
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
+
+    LM_GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
+
+    do {
+        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+            // Any character matches "."
+            found = true;
+            pos += 1;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            found = found || pos->value == chr;
+            pos += 1;
+        }
+    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
+
+    return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool llama_grammar_match_partial_char(
+        const llama_grammar_element * pos,
+        const llama_partial_utf8 partial_utf8) {
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
+    LM_GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
+
+    uint32_t partial_value = partial_utf8.value;
+    int n_remain = partial_utf8.n_remain;
+
+    // invalid sequence or 7-bit char split across 2 bytes (overlong)
+    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+        return false;
+    }
+
+    // range of possible code points this partial UTF-8 sequence could complete to
+    uint32_t low = partial_value << (n_remain * 6);
+    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+    if (low == 0) {
+        if (n_remain == 2) {
+            low = 1 << 11;
+        } else if (n_remain == 3) {
+            low = 1 << 16;
+        }
+    }
+
+    do {
+        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            if (pos->value <= high && low <= pos[1].value) {
+                return is_positive_char;
+            }
+            pos += 2;
+        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+            // Any character matches "."
+            return true;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            if (low <= pos->value && pos->value <= high) {
+                return is_positive_char;
+            }
+            pos += 1;
+        }
+    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
+
+    return !is_positive_char;
+}
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void llama_grammar_advance_stack(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        llama_grammar_stacks & new_stacks) {
+    if (stack.empty()) {
+        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+            new_stacks.emplace_back(stack);
+        }
+        return;
+    }
+
+    const llama_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case LLAMA_GRETYPE_RULE_REF: {
+            const size_t rule_id = static_cast<size_t>(pos->value);
+            const llama_grammar_element * subpos = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                llama_grammar_advance_stack(rules, new_stack, new_stacks);
+                while (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == LLAMA_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case LLAMA_GRETYPE_CHAR:
+        case LLAMA_GRETYPE_CHAR_NOT:
+        case LLAMA_GRETYPE_CHAR_ANY:
+            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+                // only add the stack if it's not a duplicate of one we already have
+                new_stacks.emplace_back(stack);
+            }
+            break;
+        default:
+            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
+            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            LM_GGML_ABORT("fatal error");
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `llama_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+void llama_grammar_accept(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t chr,
+        llama_grammar_stacks & new_stacks) {
+    new_stacks.clear();
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        auto match = llama_grammar_match_char(stack.back(), chr);
+        if (match.first) {
+            const llama_grammar_element * pos = match.second;
+
+            // update top of stack to next element, if any
+            llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+            if (!llama_grammar_is_end_of_sequence(pos)) {
+                new_stack.push_back(pos);
+            }
+            llama_grammar_advance_stack(rules, new_stack, new_stacks);
+        }
+    }
+}
+
+static llama_grammar_candidates llama_grammar_reject_candidates(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stacks & stacks,
+        const llama_grammar_candidates & candidates) {
+    LM_GGML_ASSERT(!stacks.empty()); // REVIEW
+
+    if (candidates.empty()) {
+        return {};
+    }
+
+    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
+
+llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        const llama_grammar_candidates & candidates) {
+
+    llama_grammar_candidates rejects;
+    rejects.reserve(candidates.size());
+
+    if (stack.empty()) {
+        for (const auto & tok : candidates) {
+            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
+    const llama_grammar_element * stack_pos = stack.back();
+
+    llama_grammar_candidates next_candidates;
+    next_candidates.reserve(candidates.size());
+
+    for (const auto & tok : candidates) {
+        if (*tok.code_points == 0) {
+            // reached end of full codepoints in token, reject iff it ended in a partial sequence
+            // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 &&
+                    !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                rejects.push_back(tok);
+            }
+        } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
+            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
+        } else {
+            rejects.push_back(tok);
+        }
+    }
+
+    const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
+    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    llama_grammar_stacks next_stacks;
+    llama_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (const auto & tok : next_rejects) {
+        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
+    }
+
+    return rejects;
+}
+
+static bool llama_grammar_detect_left_recursion(
+        const llama_grammar_rules & rules,
+        size_t rule_index,
+        std::vector<bool> * rules_visited,
+        std::vector<bool> * rules_in_progress,
+        std::vector<bool> * rules_may_be_empty) {
+    if ((*rules_in_progress)[rule_index]) {
+        return true;
+    }
+
+    (*rules_in_progress)[rule_index] = true;
+
+    const llama_grammar_rule & rule = rules[rule_index];
+
+    // First check if the rule might produce the empty string. This could be done combined with the second
+    // step but it's more readable as two steps.
+    bool at_rule_start = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            if (at_rule_start) {
+                (*rules_may_be_empty)[rule_index] = true;
+                break;
+            }
+            at_rule_start = true;
+        } else {
+            at_rule_start = false;
+        }
+    }
+
+    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
+    // be empty)
+    bool recurse_into_nonterminal = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
+            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
+                return true;
+            }
+            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
+                recurse_into_nonterminal = false;
+            }
+        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            recurse_into_nonterminal = true;
+        } else {
+            recurse_into_nonterminal = false;
+        }
+    }
+
+    (*rules_in_progress)[rule_index] = false;
+    (*rules_visited)[rule_index] = true;
+    return false;
+}
+
+//
+// grammar - external
+//
+
+struct llama_grammar * llama_grammar_init_impl(
+        const llama_grammar_element ** rules,
+        size_t n_rules,
+        size_t start_rule_index) {
+    const llama_grammar_element * pos;
+
+    // copy rule definitions into vectors
+    llama_grammar_rules vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
+    }
+
+    // Check for left recursion
+    std::vector<bool> rules_visited(n_rules);
+    std::vector<bool> rules_in_progress(n_rules);
+    std::vector<bool> rules_may_be_empty(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        if (rules_visited[i]) {
+            continue;
+        }
+        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            return nullptr;
+        }
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    llama_grammar_stacks stacks;
+    pos = vec_rules[start_rule_index].data();
+    do {
+        llama_grammar_stack stack;
+        if (!llama_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        llama_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!llama_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == LLAMA_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
+        } else {
+            break;
+        }
+    } while (true);
+
+    // Important: vec_rules has to be moved here, not copied, because stacks contains
+    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+    // then the pointers would be invalidated when the local vec_rules goes out of scope.
+    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+}
+
+void llama_grammar_free_impl(struct llama_grammar * grammar) {
+    delete grammar;
+}
+
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
+    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+
+    // redirect elements in stacks to point to new rules
+    for (size_t is = 0; is < result->stacks.size(); is++) {
+        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
+            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
+                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+                        result->stacks[is][ie] = &result->rules[ir0][ir1];
+                    }
+                }
+            }
+        }
+    }
+
+    return result;
+}
+
+void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    LM_GGML_ASSERT(grammar);
+    LM_GGML_ASSERT(vocab);
+
+    int64_t t_start_sample_us = lm_ggml_time_us();
+
+    bool allow_eog = false;
+    for (const auto & stack : grammar->stacks) {
+        if (stack.empty()) {
+            allow_eog = true;
+            break;
+        }
+    }
+
+    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
+    candidates_decoded.reserve(candidates->size);
+
+    llama_grammar_candidates candidates_grammar;
+    candidates_grammar.reserve(candidates->size);
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        const llama_token id = candidates->data[i].id;
+        const std::string & piece = vocab->cache_token_to_piece.at(id);
+
+        if (llama_token_is_eog_impl(*vocab, id)) {
+            if (!allow_eog) {
+                candidates->data[i].logit = -INFINITY;
+            }
+        } else if (piece.empty() || piece[0] == 0) {
+            candidates->data[i].logit = -INFINITY;
+        } else {
+            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+        }
+    }
+
+    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    for (const auto & reject : rejects) {
+        candidates->data[reject.index].logit = -INFINITY;
+    }
+
+    smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+}
+
+void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
+    const int64_t t_start_sample_us = lm_ggml_time_us();
+
+    if (llama_token_is_eog_impl(*vocab, token)) {
+        for (const auto & stack : grammar->stacks) {
+            if (stack.empty()) {
+                return;
+            }
+        }
+        LM_GGML_ABORT("fatal error");
+    }
+
+    const std::string & piece = vocab->cache_token_to_piece.at(token);
+
+    // Note terminating 0 in decoded string
+    const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+    const auto & code_points = decoded.first;
+
+    llama_grammar_stacks tmp_new_stacks;
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
+        grammar->stacks = tmp_new_stacks;
+    }
+
+    grammar->partial_utf8 = decoded.second;
+    LM_GGML_ASSERT(!grammar->stacks.empty());
+
+    smpl->t_sample_us += lm_ggml_time_us() - t_start_sample_us;
+}
package/cpp/llama-grammar.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+#include "llama-impl.h"
+
+struct llama_vocab;
+struct llama_sampling;
+
+struct llama_grammar {
+    const llama_grammar_rules rules;
+    llama_grammar_stacks stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    llama_partial_utf8 partial_utf8;
+};
+
+//
+// internal API
+//
+
+struct llama_grammar * llama_grammar_init_impl(
+        const llama_grammar_element ** rules,
+        size_t n_rules,
+        size_t start_rule_index);
+
+void llama_grammar_free_impl(struct llama_grammar * grammar);
+
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
+
+void llama_grammar_sample_impl(
+        const struct llama_grammar * grammar,
+        const struct llama_vocab * vocab,
+        const struct llama_sampling * smpl,
+        llama_token_data_array * candidates);
+
+void llama_grammar_accept_token_impl(
+        struct llama_grammar * grammar,
+        const struct llama_vocab * vocab,
+        const struct llama_sampling * smpl,
+        llama_token token);
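For orientation, here is a hedged sketch of how these internal entry points fit together for the trivial grammar `root ::= "a"`. Rules are passed to `llama_grammar_init_impl` as flat `llama_grammar_element` arrays terminated by `LLAMA_GRETYPE_END` (the element type and `LLAMA_GRETYPE_*` constants come from `llama.h`); this example is illustrative and not shipped in the package:

```cpp
#include "llama-grammar.h"

int main() {
    // root ::= "a" — one exact-char terminal, then the end-of-rule marker.
    const llama_grammar_element rule_root[] = {
        { LLAMA_GRETYPE_CHAR, 'a' },
        { LLAMA_GRETYPE_END,  0   },
    };
    const llama_grammar_element * rules[] = { rule_root };

    // Builds the initial pushdown stacks for start rule 0;
    // returns nullptr if left recursion is detected.
    llama_grammar * grammar = llama_grammar_init_impl(rules, /*n_rules=*/1, /*start_rule_index=*/0);
    if (grammar == nullptr) {
        return 1;
    }

    // The copy re-points every stack element at the copied rules, so the
    // clone stays valid after the original is freed.
    llama_grammar * clone = llama_grammar_copy_impl(grammar);

    llama_grammar_free_impl(grammar);
    llama_grammar_free_impl(clone);
    return 0;
}
```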
package/cpp/llama-impl.h
ADDED
@@ -0,0 +1,26 @@
+#pragma once
+
+#define LLAMA_API_INTERNAL
+#include "llama.h"
+
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define LLAMA_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+LLAMA_ATTRIBUTE_FORMAT(2, 3)
+void llama_log_internal        (lm_ggml_log_level level, const char * format, ...);
+void llama_log_callback_default(lm_ggml_log_level level, const char * text, void * user_data);
+
+#define LLAMA_LOG_INFO(...)  llama_log_internal(LM_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define LLAMA_LOG_WARN(...)  llama_log_internal(LM_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define LLAMA_LOG_ERROR(...) llama_log_internal(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)