llama-rb 0.1.0 → 0.2.0

@@ -0,0 +1,460 @@
+ #include "common.h"
+ #include "llama.h"
+
+ #include <cassert>
+ #include <cinttypes>
+ #include <cmath>
+ #include <cstdio>
+ #include <cstring>
+ #include <fstream>
+ #include <iostream>
+ #include <string>
+ #include <vector>
+
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+ #include <signal.h>
+ #include <unistd.h>
+ #elif defined (_WIN32)
+ #include <signal.h>
+ #endif
+
+ static console_state con_st;
+
+ static bool is_interacting = false;
+
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+ void sigint_handler(int signo) {
+     set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+     printf("\n"); // this also force-flushes stdout
+     if (signo == SIGINT) {
+         if (!is_interacting) {
+             is_interacting = true;
+         } else {
+             _exit(130);
+         }
+     }
+ }
+ #endif
+
+ int main(int argc, char ** argv) {
+     gpt_params params;
+     params.model = "models/llama-7B/ggml-model.bin";
+
+     if (gpt_params_parse(argc, argv, params) == false) {
+         return 1;
+     }
+
+     // save choice to use color for later
+     // (note for later: this is a slightly awkward choice)
+     con_st.use_color = params.use_color;
+
+ #if defined (_WIN32)
+     win32_console_init(params.use_color);
+ #endif
+
+     if (params.perplexity) {
+         printf("\n************\n");
+         printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
+         printf("************\n\n");
+
+         return 0;
+     }
+
+     if (params.embedding) {
+         printf("\n************\n");
+         printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
+         printf("************\n\n");
+
+         return 0;
+     }
+
+     if (params.n_ctx > 2048) {
+         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified); "
+                 "expect poor results\n", __func__, params.n_ctx);
+     }
+
+     if (params.seed <= 0) {
+         params.seed = time(NULL);
+     }
+
+     fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+
+     std::mt19937 rng(params.seed);
+     if (params.random_prompt) {
+         params.prompt = gpt_random_prompt(rng);
+     }
+
+     // params.prompt = R"(// this function checks if the number n is prime
+     //bool is_prime(int n) {)";
+
+     llama_context * ctx;
+
+     // load the model
+     {
+         auto lparams = llama_context_default_params();
+
+         lparams.n_ctx     = params.n_ctx;
+         lparams.n_parts   = params.n_parts;
+         lparams.seed      = params.seed;
+         lparams.f16_kv    = params.memory_f16;
+         lparams.use_mlock = params.use_mlock;
+
+         ctx = llama_init_from_file(params.model.c_str(), lparams);
+
+         if (ctx == NULL) {
+             fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+             return 1;
+         }
+     }
+
+     // print system information
+     {
+         fprintf(stderr, "\n");
+         fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+     }
+
+     // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+     // uncomment the "used_mem" line in llama.cpp to see the results
+     if (params.mem_test) {
+         {
+             const std::vector<llama_token> tmp(params.n_batch, 0);
+             llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+         }
+
+         {
+             const std::vector<llama_token> tmp = { 0, };
+             llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+         }
+
+         llama_print_timings(ctx);
+         llama_free(ctx);
+
+         return 0;
+     }
+
+     // Add a space in front of the first character to match OG llama tokenizer behavior
+     params.prompt.insert(0, 1, ' ');
+
+     // tokenize the prompt
+     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+
+     const int n_ctx = llama_n_ctx(ctx);
+
+     if ((int) embd_inp.size() > n_ctx - 4) {
+         fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
+         return 1;
+     }
+
+     // number of tokens to keep when resetting context
+     if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
+         params.n_keep = (int) embd_inp.size();
+     }
+
+     // prefix & suffix for instruct mode
+     const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
+     const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
+
+     // in instruct mode, we inject a prefix and a suffix to each input by the user
+     if (params.instruct) {
+         params.interactive_start = true;
+         params.antiprompt.push_back("### Instruction:\n\n");
+     }
+
+     // enable interactive mode if reverse prompt or interactive start is specified
+     if (params.antiprompt.size() != 0 || params.interactive_start) {
+         params.interactive = true;
+     }
+
+     // determine newline token
+     auto llama_token_newline = ::llama_tokenize(ctx, "\n", false);
+
+     if (params.verbose_prompt) {
+         fprintf(stderr, "\n");
+         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+         for (int i = 0; i < (int) embd_inp.size(); i++) {
+             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
+         }
+         if (params.n_keep > 0) {
+             fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
+             for (int i = 0; i < params.n_keep; i++) {
+                 fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]));
+             }
+             fprintf(stderr, "'\n");
+         }
+         fprintf(stderr, "\n");
+     }
+
+     if (params.interactive) {
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+         struct sigaction sigint_action;
+         sigint_action.sa_handler = sigint_handler;
+         sigemptyset (&sigint_action.sa_mask);
+         sigint_action.sa_flags = 0;
+         sigaction(SIGINT, &sigint_action, NULL);
+ #elif defined (_WIN32)
+         signal(SIGINT, sigint_handler);
+ #endif
+
+         fprintf(stderr, "%s: interactive mode on.\n", __func__);
+
+         if (params.antiprompt.size()) {
+             for (auto antiprompt : params.antiprompt) {
+                 fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
+             }
+         }
+
+         if (!params.input_prefix.empty()) {
+             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
+         }
+     }
+     fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
+             params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+     fprintf(stderr, "\n\n");
+
+     // TODO: replace with ring-buffer
+     std::vector<llama_token> last_n_tokens(n_ctx);
+     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+
+     if (params.interactive) {
+         fprintf(stderr, "== Running in interactive mode. ==\n"
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+                         " - Press Ctrl+C to interject at any time.\n"
+ #endif
+                         " - Press Return to return control to LLaMa.\n"
+                         " - If you want to submit another line, end your input in '\\'.\n\n");
+         is_interacting = params.interactive_start;
+     }
+
+     bool is_antiprompt = false;
+     bool input_noecho  = false;
+
+     int n_past     = 0;
+     int n_remain   = params.n_predict;
+     int n_consumed = 0;
+
+     // the first thing we will do is to output the prompt, so set color accordingly
+     set_console_color(con_st, CONSOLE_COLOR_PROMPT);
+
+     std::vector<llama_token> embd;
+
+     while (n_remain != 0 || params.interactive) {
+         // predict
+         if (embd.size() > 0) {
+             // infinite text generation via context swapping
+             // if we run out of context:
+             // - take the n_keep first tokens from the original prompt (via n_past)
+             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
+             if (n_past + (int) embd.size() > n_ctx) {
+                 const int n_left = n_past - params.n_keep;
+
+                 n_past = params.n_keep;
+
+                 // insert n_left/2 tokens at the start of embd from last_n_tokens
+                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
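+                 // worked example (hypothetical numbers, added for illustration): with n_ctx = 2048,
+                 // n_keep = 48 and n_past = 2048, n_left = 2000; the first 48 prompt tokens stay in the
+                 // context via n_past, and the most recent n_left/2 = 1000 tokens are re-evaluated in a batch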
+
+                 //printf("\n---\n");
+                 //printf("resetting: '");
+                 //for (int i = 0; i < (int) embd.size(); i++) {
+                 //    printf("%s", llama_token_to_str(ctx, embd[i]));
+                 //}
+                 //printf("'\n");
+                 //printf("\n---\n");
+             }
+
+             if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+                 fprintf(stderr, "%s : failed to eval\n", __func__);
+                 return 1;
+             }
+         }
+
+         n_past += embd.size();
+         embd.clear();
+
+         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
+             // out of user input, sample next token
+             const int32_t top_k          = params.top_k;
+             const float   top_p          = params.top_p;
+             const float   temp           = params.temp;
+             const float   repeat_penalty = params.repeat_penalty;
+
+             llama_token id = 0;
+
+             {
+                 auto logits = llama_get_logits(ctx);
+
+                 if (params.ignore_eos) {
+                     logits[llama_token_eos()] = 0;
+                 }
+
+                 id = llama_sample_top_p_top_k(ctx,
+                         last_n_tokens.data() + n_ctx - params.repeat_last_n,
+                         params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+
+                 last_n_tokens.erase(last_n_tokens.begin());
+                 last_n_tokens.push_back(id);
+             }
+
+             // replace end of text token with newline token when in interactive mode
+             if (id == llama_token_eos() && params.interactive && !params.instruct) {
+                 id = llama_token_newline.front();
+                 if (params.antiprompt.size() != 0) {
+                     // tokenize and inject first reverse prompt
+                     const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                     embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                 }
+             }
+
+             // add it to the context
+             embd.push_back(id);
+
+             // echo this to console
+             input_noecho = false;
+
+             // decrement remaining sampling budget
+             --n_remain;
+         } else {
+             // some user input remains from prompt or interaction, forward it to processing
+             while ((int) embd_inp.size() > n_consumed) {
+                 embd.push_back(embd_inp[n_consumed]);
+                 last_n_tokens.erase(last_n_tokens.begin());
+                 last_n_tokens.push_back(embd_inp[n_consumed]);
+                 ++n_consumed;
+                 if ((int) embd.size() >= params.n_batch) {
+                     break;
+                 }
+             }
+         }
+
+         // display text
+         if (!input_noecho) {
+             for (auto id : embd) {
+                 printf("%s", llama_token_to_str(ctx, id));
+             }
+             fflush(stdout);
+         }
+         // reset color to default if there is no pending user input
+         if (!input_noecho && (int) embd_inp.size() == n_consumed) {
+             set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+         }
+
+         // in interactive mode, and not currently processing queued inputs;
+         // check if we should prompt the user for more
+         if (params.interactive && (int) embd_inp.size() <= n_consumed) {
+
+             // check for reverse prompt
+             if (params.antiprompt.size()) {
+                 std::string last_output;
+                 for (auto id : last_n_tokens) {
+                     last_output += llama_token_to_str(ctx, id);
+                 }
+
+                 is_antiprompt = false;
+                 // Check if each of the reverse prompts appears at the end of the output.
+                 for (std::string & antiprompt : params.antiprompt) {
+                     if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
+                         is_interacting = true;
+                         is_antiprompt = true;
+                         set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+                         fflush(stdout);
+                         break;
+                     }
+                 }
+             }
+
+             if (n_past > 0 && is_interacting) {
+                 // potentially set color to indicate we are taking user input
+                 set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+
+ #if defined (_WIN32)
+                 // Windows: must reactivate sigint handler after each signal
+                 signal(SIGINT, sigint_handler);
+ #endif
+
+                 if (params.instruct) {
+                     printf("\n> ");
+                 }
+
+                 std::string buffer;
+                 if (!params.input_prefix.empty()) {
+                     buffer += params.input_prefix;
+                     printf("%s", buffer.c_str());
+                 }
+
+                 std::string line;
+                 bool another_line = true;
+                 do {
+                     if (!std::getline(std::cin, line)) {
+                         // input stream is bad or EOF received
+                         return 0;
+                     }
+                     if (line.empty() || line.back() != '\\') {
+                         another_line = false;
+                     } else {
+                         line.pop_back(); // Remove the continuation character
+                     }
+                     buffer += line + '\n'; // Append the line to the result
+                 } while (another_line);
+
+                 // done taking input, reset color
+                 set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+
+                 // Add tokens to embd only if the input buffer is non-empty
+                 // Entering an empty line lets the user pass control back
+                 if (buffer.length() > 1) {
+
+                     // instruct mode: insert instruction prefix
+                     if (params.instruct && !is_antiprompt) {
+                         n_consumed = embd_inp.size();
+                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+                     }
+
+                     auto line_inp = ::llama_tokenize(ctx, buffer, false);
+                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+
+                     // instruct mode: insert response suffix
+                     if (params.instruct) {
+                         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                     }
+
+                     n_remain -= line_inp.size();
+                 }
+
+                 input_noecho = true; // do not echo this again
+             }
+
+             if (n_past > 0) {
+                 is_interacting = false;
+             }
+         }
+
+
+         // end of text token (guard against an empty embd before calling back())
+         if (!embd.empty() && embd.back() == llama_token_eos()) {
+             if (params.instruct) {
+                 is_interacting = true;
+             } else {
+                 fprintf(stderr, " [end of text]\n");
+                 break;
+             }
+         }
+
+         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
+         if (params.interactive && n_remain <= 0 && params.n_predict != -1) {
+             n_remain = params.n_predict;
+             is_interacting = true;
+         }
+     }
+
+ #if defined (_WIN32)
+     signal(SIGINT, SIG_DFL);
+ #endif
+
+     llama_print_timings(ctx);
+     llama_free(ctx);
+
+     set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+
+     return 0;
+ }
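
For reference, the call sequence the file above wraps can be reduced to a much smaller loop. The sketch below is not part of the gem; it is a minimal, non-interactive distillation that assumes the same llama.cpp API vendored here (llama_init_from_file, llama_eval, llama_sample_top_p_top_k, and so on), and it hard-codes the thread count, token budget, and sampling parameters purely for illustration.

    #include "common.h"
    #include "llama.h"

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        // load the model with default parameters (the path is only an example)
        auto lparams = llama_context_default_params();
        llama_context * ctx = llama_init_from_file("models/llama-7B/ggml-model.bin", lparams);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        // tokenize a prompt (the leading space mirrors the tokenizer note in the file above)
        std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, " Hello", true);
        std::vector<llama_token> last_n_tokens(llama_n_ctx(ctx), 0);

        // evaluate the whole prompt in one batch
        if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), 0, 4)) {
            fprintf(stderr, "failed to eval\n");
            return 1;
        }
        int n_past = embd_inp.size();

        // sample and print up to 64 tokens, feeding each one back into the context
        for (int i = 0; i < 64; i++) {
            llama_token id = llama_sample_top_p_top_k(ctx,
                    last_n_tokens.data() + last_n_tokens.size() - 64,
                    64, 40, 0.95f, 0.80f, 1.10f);
            last_n_tokens.erase(last_n_tokens.begin());
            last_n_tokens.push_back(id);

            if (id == llama_token_eos()) {
                break;
            }
            printf("%s", llama_token_to_str(ctx, id));
            fflush(stdout);

            if (llama_eval(ctx, &id, 1, n_past, 4)) {
                fprintf(stderr, "failed to eval\n");
                return 1;
            }
            n_past += 1;
        }
        printf("\n");

        llama_print_timings(ctx);
        llama_free(ctx);
        return 0;
    }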