cui-llama.rn 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +20 -0
- package/README.md +330 -0
- package/android/build.gradle +107 -0
- package/android/gradle.properties +5 -0
- package/android/src/main/AndroidManifest.xml +4 -0
- package/android/src/main/CMakeLists.txt +69 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +353 -0
- package/android/src/main/java/com/rnllama/RNLlama.java +446 -0
- package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -0
- package/android/src/main/jni.cpp +635 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +94 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +95 -0
- package/cpp/README.md +4 -0
- package/cpp/common.cpp +3237 -0
- package/cpp/common.h +467 -0
- package/cpp/ggml-aarch64.c +2193 -0
- package/cpp/ggml-aarch64.h +39 -0
- package/cpp/ggml-alloc.c +1041 -0
- package/cpp/ggml-alloc.h +76 -0
- package/cpp/ggml-backend-impl.h +153 -0
- package/cpp/ggml-backend.c +2225 -0
- package/cpp/ggml-backend.h +236 -0
- package/cpp/ggml-common.h +1829 -0
- package/cpp/ggml-impl.h +655 -0
- package/cpp/ggml-metal.h +65 -0
- package/cpp/ggml-metal.m +3273 -0
- package/cpp/ggml-quants.c +15022 -0
- package/cpp/ggml-quants.h +132 -0
- package/cpp/ggml.c +22034 -0
- package/cpp/ggml.h +2444 -0
- package/cpp/grammar-parser.cpp +536 -0
- package/cpp/grammar-parser.h +29 -0
- package/cpp/json-schema-to-grammar.cpp +1045 -0
- package/cpp/json-schema-to-grammar.h +8 -0
- package/cpp/json.hpp +24766 -0
- package/cpp/llama.cpp +21789 -0
- package/cpp/llama.h +1201 -0
- package/cpp/log.h +737 -0
- package/cpp/rn-llama.hpp +630 -0
- package/cpp/sampling.cpp +460 -0
- package/cpp/sampling.h +160 -0
- package/cpp/sgemm.cpp +1027 -0
- package/cpp/sgemm.h +14 -0
- package/cpp/unicode-data.cpp +7032 -0
- package/cpp/unicode-data.h +20 -0
- package/cpp/unicode.cpp +812 -0
- package/cpp/unicode.h +64 -0
- package/ios/RNLlama.h +11 -0
- package/ios/RNLlama.mm +302 -0
- package/ios/RNLlama.xcodeproj/project.pbxproj +278 -0
- package/ios/RNLlamaContext.h +39 -0
- package/ios/RNLlamaContext.mm +426 -0
- package/jest/mock.js +169 -0
- package/lib/commonjs/NativeRNLlama.js +10 -0
- package/lib/commonjs/NativeRNLlama.js.map +1 -0
- package/lib/commonjs/grammar.js +574 -0
- package/lib/commonjs/grammar.js.map +1 -0
- package/lib/commonjs/index.js +151 -0
- package/lib/commonjs/index.js.map +1 -0
- package/lib/module/NativeRNLlama.js +3 -0
- package/lib/module/NativeRNLlama.js.map +1 -0
- package/lib/module/grammar.js +566 -0
- package/lib/module/grammar.js.map +1 -0
- package/lib/module/index.js +129 -0
- package/lib/module/index.js.map +1 -0
- package/lib/typescript/NativeRNLlama.d.ts +107 -0
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -0
- package/lib/typescript/grammar.d.ts +38 -0
- package/lib/typescript/grammar.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +46 -0
- package/lib/typescript/index.d.ts.map +1 -0
- package/llama-rn.podspec +56 -0
- package/package.json +230 -0
- package/src/NativeRNLlama.ts +132 -0
- package/src/grammar.ts +849 -0
- package/src/index.ts +182 -0
package/cpp/rn-llama.hpp
ADDED
@@ -0,0 +1,630 @@
+#ifndef RNLLAMA_H
+#define RNLLAMA_H
+
+#include <sstream>
+#include <iostream>
+#include "common.h"
+#include "llama.h"
+
+namespace rnllama {
+
+static void llama_batch_clear(llama_batch *batch) {
+    batch->n_tokens = 0;
+}
+
+static void llama_batch_add(llama_batch *batch, llama_token id, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
+    batch->token   [batch->n_tokens] = id;
+    batch->pos     [batch->n_tokens] = pos;
+    batch->n_seq_id[batch->n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); i++) {
+        batch->seq_id[batch->n_tokens][i] = seq_ids[i];
+    }
+    batch->logits  [batch->n_tokens] = logits ? 1 : 0;
+    batch->n_tokens += 1;
+}
+
+// NOTE: Edit from https://github.com/ggerganov/llama.cpp/blob/master/examples/server/server.cpp
+
+static void log(const char *level, const char *function, int line,
+                const char *format, ...)
+{
+    printf("[%s] %s:%d ", level, function, line);
+
+    va_list args;
+    va_start(args, format);
+    vprintf(format, args);
+    va_end(args);
+
+    printf("\n");
+}
+
+static bool rnllama_verbose = false;
+
+#if RNLLAMA_VERBOSE != 1
+#define LOG_VERBOSE(MSG, ...)
+#else
+#define LOG_VERBOSE(MSG, ...)                                       \
+    do                                                              \
+    {                                                               \
+        if (rnllama_verbose)                                        \
+        {                                                           \
+            log("VERBOSE", __func__, __LINE__, MSG, ##__VA_ARGS__); \
+        }                                                           \
+    } while (0)
+#endif
+
+#define LOG_ERROR(MSG, ...) log("ERROR", __func__, __LINE__, MSG, ##__VA_ARGS__)
+#define LOG_WARNING(MSG, ...) log("WARNING", __func__, __LINE__, MSG, ##__VA_ARGS__)
+#define LOG_INFO(MSG, ...) log("INFO", __func__, __LINE__, MSG, ##__VA_ARGS__)
+
+enum stop_type
+{
+    STOP_FULL,
+    STOP_PARTIAL,
+};
+
+// completion token output with probabilities
+struct completion_token_output
+{
+    struct token_prob
+    {
+        llama_token tok;
+        float prob;
+    };
+
+    std::vector<token_prob> probs;
+    llama_token tok;
+};
+
+static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
+{
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++)
+    {
+    }
+    return i;
+}
+
+static bool ends_with(const std::string &str, const std::string &suffix)
+{
+    return str.size() >= suffix.size() &&
+           0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
+}
+
+static size_t find_partial_stop_string(const std::string &stop,
+                                       const std::string &text)
+{
+    if (!text.empty() && !stop.empty())
+    {
+        const char text_last_char = text.back();
+        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--)
+        {
+            if (stop[char_index] == text_last_char)
+            {
+                const std::string current_partial = stop.substr(0, char_index + 1);
+                if (ends_with(text, current_partial))
+                {
+                    return text.size() - char_index - 1;
+                }
+            }
+        }
+    }
+    return std::string::npos;
+}
+
+// format incomplete utf-8 multibyte character for output
+static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
+{
+    std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
+    // if the size is 1 and first bit is 1, meaning it's a partial character
+    // (size > 1 meaning it's already a known token)
+    if (out.size() == 1 && (out[0] & 0x80) == 0x80)
+    {
+        std::stringstream ss;
+        ss << std::hex << (out[0] & 0xff);
+        std::string res(ss.str());
+        out = "byte: \\x" + res;
+    }
+    return out;
+}
+
+template <class Iter>
+static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
+{
+    std::string ret;
+    for (; begin != end; ++begin)
+    {
+        ret += llama_token_to_piece(ctx, *begin);
+    }
+    return ret;
+}
+
+struct llama_rn_context
+{
+    bool is_predicting = false;
+    bool is_interrupted = false;
+    bool has_next_token = false;
+    std::string generated_text;
+    std::vector<completion_token_output> generated_token_probs;
+
+    size_t num_prompt_tokens = 0;
+    size_t num_tokens_predicted = 0;
+    size_t n_past = 0;
+    size_t n_remain = 0;
+
+    std::vector<llama_token> embd;
+
+    gpt_params params;
+
+    llama_model *model = nullptr;
+    llama_context *ctx = nullptr;
+    llama_sampling_context *ctx_sampling = nullptr;
+
+    int n_ctx;
+
+    bool truncated = false;
+    bool stopped_eos = false;
+    bool stopped_word = false;
+    bool stopped_limit = false;
+    std::string stopping_word;
+    int32_t multibyte_pending = 0;
+
+    ~llama_rn_context()
+    {
+        if (ctx)
+        {
+            llama_free(ctx);
+            ctx = nullptr;
+        }
+        if (model)
+        {
+            llama_free_model(model);
+            model = nullptr;
+        }
+        if (ctx_sampling != nullptr)
+        {
+            llama_sampling_free(ctx_sampling);
+        }
+    }
+
+    void rewind()
+    {
+        is_interrupted = false;
+        params.antiprompt.clear();
+        params.sparams.grammar.clear();
+        num_prompt_tokens = 0;
+        num_tokens_predicted = 0;
+        generated_text = "";
+        generated_text.reserve(params.n_ctx);
+        generated_token_probs.clear();
+        truncated = false;
+        stopped_eos = false;
+        stopped_word = false;
+        stopped_limit = false;
+        stopping_word = "";
+        multibyte_pending = 0;
+        n_remain = 0;
+        n_past = 0;
+        params.sparams.n_prev = n_ctx;
+    }
+
+    bool initSampling() {
+        if (ctx_sampling != nullptr) {
+            llama_sampling_free(ctx_sampling);
+        }
+        ctx_sampling = llama_sampling_init(params.sparams);
+        return ctx_sampling != nullptr;
+    }
+
+    bool loadModel(gpt_params &params_)
+    {
+        params = params_;
+        std::tie(model, ctx) = llama_init_from_gpt_params(params);
+        if (model == nullptr)
+        {
+            LOG_ERROR("unable to load model: %s", params_.model.c_str());
+            return false;
+        }
+        n_ctx = llama_n_ctx(ctx);
+        return true;
+    }
+
+    void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
+        const int n_left = n_ctx - params.n_keep;
+        const int n_block_size = n_left / 2;
+        const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
+
+        // Keep n_keep tokens at start of prompt (at most n_ctx - 4)
+        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+
+        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+        LOG_VERBOSE("input truncated, n_ctx: %d, n_keep: %d, n_left: %d, new_tokens: %s, num_prompt_tokens: %d",
+            n_ctx,
+            params.n_keep,
+            n_left,
+            tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()).c_str(),
+            new_tokens.size()
+        );
+
+        truncated = true;
+        prompt_tokens = new_tokens;
+    }
+
+    void loadPrompt()
+    {
+        std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        num_prompt_tokens = prompt_tokens.size();
+
+        if (params.n_keep < 0)
+        {
+            params.n_keep = (int)num_prompt_tokens;
+        }
+        params.n_keep = std::min(n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (num_prompt_tokens >= (size_t) n_ctx)
+        {
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();
+
+            LM_GGML_ASSERT(num_prompt_tokens < (size_t) n_ctx);
+        }
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
+        {
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
+        }
+
+        // compare the evaluated prompt with the new prompt
+        n_past = common_part(embd, prompt_tokens);
+
+        embd = prompt_tokens;
+        if (n_past == num_prompt_tokens)
+        {
+            // we have to evaluate at least 1 token to generate logits.
+            n_past--;
+        }
+
+        // since #3228 we now have to manually manage the KV cache
+        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+
+        LOG_VERBOSE("prompt ingested, n_past: %d, cached: %s, to_eval: %s",
+            n_past,
+            tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past).c_str(),
+            tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()).c_str()
+        );
+
+        has_next_token = true;
+    }
+
+    void beginCompletion()
+    {
+        // number of tokens to keep when resetting context
+        n_remain = params.n_predict;
+        llama_set_rng_seed(ctx, params.seed);
+
+        is_predicting = true;
+    }
+
+    completion_token_output nextToken()
+    {
+        completion_token_output result;
+        result.tok = -1;
+
+        if (embd.size() >= (size_t)params.n_ctx)
+        {
+            // Shift context
+
+            const int n_left    = n_past - params.n_keep - 1;
+            const int n_discard = n_left/2;
+
+            llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+            for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
+            {
+                embd[i - n_discard] = embd[i];
+            }
+            embd.resize(embd.size() - n_discard);
+
+            n_past -= n_discard;
+
+            LOG_VERBOSE("input truncated, n_ctx: %d, n_keep: %d, n_left: %d, new_tokens: %s",
+                params.n_ctx,
+                params.n_keep,
+                n_left
+            );
+        }
+
+        bool tg = true;
+        while (n_past < embd.size())
+        {
+            int n_eval = (int)embd.size() - n_past;
+            tg = n_eval == 1;
+            if (n_eval > params.n_batch)
+            {
+                n_eval = params.n_batch;
+            }
+            if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
+            {
+                LOG_ERROR("failed to eval, n_eval: %d, n_past: %d, n_threads: %d, embd: %s",
+                    n_eval,
+                    n_past,
+                    params.n_threads,
+                    tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()).c_str()
+                );
+                has_next_token = false;
+                return result;
+            }
+            n_past += n_eval;
+        }
+
+        if (params.n_predict == 0)
+        {
+            has_next_token = false;
+            result.tok = llama_token_eos(model);
+            return result;
+        }
+
+        {
+            // out of user input, sample next token
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(llama_n_vocab(model));
+
+            result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL);
+
+            llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
+
+            const int32_t n_probs = params.sparams.n_probs;
+            if (params.sparams.temp <= 0 && n_probs > 0)
+            {
+                // For llama_sample_token_greedy we need to sort candidates
+                llama_sample_softmax(ctx, &cur_p);
+            }
+
+            for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
+            {
+                result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
+            }
+            llama_sampling_accept(ctx_sampling, ctx, result.tok, true);
+            if (tg) {
+                num_tokens_predicted++;
+            }
+        }
+
+        // add it to the context
+        embd.push_back(result.tok);
+        // decrement remaining sampling budget
+        --n_remain;
+
+        if (!embd.empty() && embd.back() == llama_token_eos(model))
+        {
+            // stopping_word = llama_token_to_piece(ctx, embd.back());
+            has_next_token = false;
+            stopped_eos = true;
+            LOG_VERBOSE("eos token found", "");
+            return result;
+        }
+
+        has_next_token = params.n_predict == -1 || n_remain != 0;
+        return result;
+    }
+
+    size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
+                               const stop_type type)
+    {
+        size_t stop_pos = std::string::npos;
+        for (const std::string &word : params.antiprompt)
+        {
+            size_t pos;
+            if (type == STOP_FULL)
+            {
+                const size_t tmp = word.size() + last_token_size;
+                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
+                pos = text.find(word, from_pos);
+            }
+            else
+            {
+                pos = find_partial_stop_string(word, text);
+            }
+            if (pos != std::string::npos &&
+                (stop_pos == std::string::npos || pos < stop_pos))
+            {
+                if (type == STOP_FULL)
+                {
+                    stopping_word = word;
+                    stopped_word = true;
+                    has_next_token = false;
+                }
+                stop_pos = pos;
+            }
+        }
+        return stop_pos;
+    }
+
+    completion_token_output doCompletion()
+    {
+        const completion_token_output token_with_probs = nextToken();
+
+        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
+        generated_text += token_text;
+
+        if (params.sparams.n_probs > 0)
+        {
+            generated_token_probs.push_back(token_with_probs);
+        }
+
+        if (multibyte_pending > 0)
+        {
+            multibyte_pending -= token_text.size();
+        }
+        else if (token_text.size() == 1)
+        {
+            const char c = token_text[0];
+            // 2-byte characters: 110xxxxx 10xxxxxx
+            if ((c & 0xE0) == 0xC0)
+            {
+                multibyte_pending = 1;
+                // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
+            }
+            else if ((c & 0xF0) == 0xE0)
+            {
+                multibyte_pending = 2;
+                // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+            }
+            else if ((c & 0xF8) == 0xF0)
+            {
+                multibyte_pending = 3;
+            }
+            else
+            {
+                multibyte_pending = 0;
+            }
+        }
+
+        if (multibyte_pending > 0 && !has_next_token)
+        {
+            has_next_token = true;
+            n_remain++;
+        }
+
+        if (!has_next_token && n_remain == 0)
+        {
+            stopped_limit = true;
+        }
+
+        LOG_VERBOSE("next token, token: %s, token_text: %s, has_next_token: %d, n_remain: %d, num_tokens_predicted: %d, stopped_eos: %d, stopped_word: %d, stopped_limit: %d, stopping_word: %s",
+            llama_token_to_piece(ctx, token_with_probs.tok),
+            tokens_to_output_formatted_string(ctx, token_with_probs.tok).c_str(),
+            has_next_token,
+            n_remain,
+            num_tokens_predicted,
+            stopped_eos,
+            stopped_word,
+            stopped_limit,
+            stopping_word.c_str()
+        );
+        return token_with_probs;
+    }
+
+    std::vector<float> getEmbedding()
+    {
+        static const int n_embd = llama_n_embd(llama_get_model(ctx));
+        if (!params.embedding)
+        {
+            LOG_WARNING("embedding disabled, embedding: %s", params.embedding);
+            return std::vector<float>(n_embd, 0.0f);
+        }
+        const float *data = llama_get_embeddings(ctx);
+        std::vector<float> embedding(data, data + n_embd);
+        return embedding;
+    }
+
+    std::string bench(int pp, int tg, int pl, int nr)
+    {
+        if (is_predicting) {
+            LOG_ERROR("cannot benchmark while predicting", "");
+            return std::string("[]");
+        }
+
+        is_predicting = true;
+
+        double pp_avg = 0;
+        double tg_avg = 0;
+
+        double pp_std = 0;
+        double tg_std = 0;
+
+        // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
+        llama_batch batch = llama_batch_init(512, 0, 1);
+
+        for (int i = 0; i < nr; i++)
+        {
+            llama_batch_clear(&batch);
+
+            const int n_tokens = pp;
+
+            for (int i = 0; i < n_tokens; i++)
+            {
+                llama_batch_add(&batch, 0, i, {0}, false);
+            }
+            batch.logits[batch.n_tokens - 1] = 1; // true
+
+            llama_kv_cache_clear(ctx);
+
+            const int64_t t_pp_start = llama_time_us();
+            if (llama_decode(ctx, batch) != 0)
+            {
+                LOG_ERROR("llama_decode() failed during prompt", "");
+            }
+            const int64_t t_pp_end = llama_time_us();
+            llama_kv_cache_clear(ctx);
+
+            if (is_interrupted) break;
+
+            const int64_t t_tg_start = llama_time_us();
+
+            for (int i = 0; i < tg; i++)
+            {
+                llama_batch_clear(&batch);
+
+                for (int j = 0; j < pl; j++)
+                {
+                    llama_batch_add(&batch, 0, i, {j}, true);
+                }
+
+                if (llama_decode(ctx, batch) != 0)
+                {
+                    LOG_ERROR("llama_decode() failed during text generation", "");
+                }
+                if (is_interrupted) break;
+            }
+
+            const int64_t t_tg_end = llama_time_us();
+
+            llama_kv_cache_clear(ctx);
+
+            const double t_pp = (t_pp_end - t_pp_start) / 1000000.0;
+            const double t_tg = (t_tg_end - t_tg_start) / 1000000.0;
+
+            const double speed_pp = pp / t_pp;
+            const double speed_tg = (pl * tg) / t_tg;
+
+            pp_avg += speed_pp;
+            tg_avg += speed_tg;
+
+            pp_std += speed_pp * speed_pp;
+            tg_std += speed_tg * speed_tg;
+        }
+
+        pp_avg /= nr;
+        tg_avg /= nr;
+
+        if (nr > 1) {
+            pp_std = sqrt(pp_std / (nr - 1) - pp_avg * pp_avg * nr / (nr - 1));
+            tg_std = sqrt(tg_std / (nr - 1) - tg_avg * tg_avg * nr / (nr - 1));
+        } else {
+            pp_std = 0;
+            tg_std = 0;
+        }
+
+        if (is_interrupted) llama_kv_cache_clear(ctx);
+        is_predicting = false;
+
+        char model_desc[128];
+        llama_model_desc(model, model_desc, sizeof(model_desc));
+        return std::string("[\"") + model_desc + std::string("\",") +
+            std::to_string(llama_model_size(model)) + std::string(",") +
+            std::to_string(llama_model_n_params(model)) + std::string(",") +
+            std::to_string(pp_avg) + std::string(",") +
+            std::to_string(pp_std) + std::string(",") +
+            std::to_string(tg_avg) + std::string(",") +
+            std::to_string(tg_std) +
+            std::string("]");
+    }
+};
+
+}
+
+#endif /* LLAMA_H */
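
For orientation only, and not part of the package diff: a minimal, hypothetical sketch of how the `llama_rn_context` API added above might be driven from native code. Only members defined in the header are used; the model path, prompt, and exact call order are assumptions, and in the package itself the real callers appear to be `android/src/main/jni.cpp` and `ios/RNLlamaContext.mm`.

```cpp
// Hypothetical driver for rn-llama.hpp -- illustration only, not shipped in the package.
#include "rn-llama.hpp"

int main() {
    rnllama::llama_rn_context rn_ctx;

    gpt_params params;                        // from the bundled common.h
    params.model     = "/path/to/model.gguf"; // placeholder path
    params.prompt    = "Hello";
    params.n_predict = 64;

    // (a real app would also initialize/tear down the llama backend:
    //  llama_backend_init / llama_backend_free)

    // loadModel() wraps llama_init_from_gpt_params() and records n_ctx.
    if (!rn_ctx.loadModel(params)) return 1;

    // Reset state, build the sampling context, ingest the prompt,
    // then pull tokens until a stop condition is hit.
    rn_ctx.rewind();
    if (!rn_ctx.initSampling()) return 1;
    rn_ctx.loadPrompt();
    rn_ctx.beginCompletion();

    while (rn_ctx.has_next_token && !rn_ctx.is_interrupted) {
        rn_ctx.doCompletion(); // appends the sampled piece to generated_text
        // findStoppingStrings() sets stopped_word / clears has_next_token
        // when any string in params.antiprompt is matched.
        rn_ctx.findStoppingStrings(rn_ctx.generated_text, 0, rnllama::STOP_FULL);
    }

    std::cout << rn_ctx.generated_text << std::endl;
    return 0;
}
```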