llama-rb 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,460 @@
+ #include "common.h"
+ #include "llama.h"
+
+ #include <cassert>
+ #include <cinttypes>
+ #include <cmath>
+ #include <cstdio>
+ #include <cstring>
+ #include <fstream>
+ #include <iostream>
+ #include <string>
+ #include <vector>
+
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+ #include <signal.h>
+ #include <unistd.h>
+ #elif defined (_WIN32)
+ #include <signal.h>
+ #endif
+
+ static console_state con_st;
+
+ static bool is_interacting = false;
+
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+ void sigint_handler(int signo) {
+     set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+     printf("\n"); // this also forces a flush of stdout.
+     if (signo == SIGINT) {
+         if (!is_interacting) {
+             is_interacting = true;
+         } else {
+             _exit(130);
+         }
+     }
+ }
+ #endif
+
+ int main(int argc, char ** argv) {
+     gpt_params params;
+     params.model = "models/llama-7B/ggml-model.bin";
+
+     if (gpt_params_parse(argc, argv, params) == false) {
+         return 1;
+     }
+
+     // save choice to use color for later
+     // (note for later: this is a slightly awkward choice)
+     con_st.use_color = params.use_color;
+
+ #if defined (_WIN32)
+     win32_console_init(params.use_color);
+ #endif
+
+     if (params.perplexity) {
+         printf("\n************\n");
+         printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
+         printf("************\n\n");
+
+         return 0;
+     }
+
+     if (params.embedding) {
+         printf("\n************\n");
+         printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
+         printf("************\n\n");
+
+         return 0;
+     }
+
+     if (params.n_ctx > 2048) {
+         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
+                 " expect poor results\n", __func__, params.n_ctx);
+     }
+
+     if (params.seed <= 0) {
+         params.seed = time(NULL);
+     }
+
+     fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+
+     std::mt19937 rng(params.seed);
+     if (params.random_prompt) {
+         params.prompt = gpt_random_prompt(rng);
+     }
+
+     // params.prompt = R"(// this function checks if the number n is prime
+     //bool is_prime(int n) {)";
+
+     llama_context * ctx;
+
+     // load the model
+     {
+         auto lparams = llama_context_default_params();
+
+         lparams.n_ctx     = params.n_ctx;
+         lparams.n_parts   = params.n_parts;
+         lparams.seed      = params.seed;
+         lparams.f16_kv    = params.memory_f16;
+         lparams.use_mlock = params.use_mlock;
+
+         ctx = llama_init_from_file(params.model.c_str(), lparams);
+
+         if (ctx == NULL) {
+             fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+             return 1;
+         }
+     }
+
+     // print system information
+     {
+         fprintf(stderr, "\n");
+         fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+     }
+
+     // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+     // uncomment the "used_mem" line in llama.cpp to see the results
+     if (params.mem_test) {
+         {
+             const std::vector<llama_token> tmp(params.n_batch, 0);
+             llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+         }
+
+         {
+             const std::vector<llama_token> tmp = { 0, };
+             llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+         }
+
+         llama_print_timings(ctx);
+         llama_free(ctx);
+
+         return 0;
+     }
+
+     // Add a space in front of the first character to match OG llama tokenizer behavior
+     params.prompt.insert(0, 1, ' ');
+
+     // tokenize the prompt
+     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+
+     const int n_ctx = llama_n_ctx(ctx);
+
+     if ((int) embd_inp.size() > n_ctx - 4) {
+         fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
+         return 1;
+     }
+
+     // number of tokens to keep when resetting context
+     if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
+         params.n_keep = (int)embd_inp.size();
+     }
+
+     // prefix & suffix for instruct mode
+     const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
+     const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
+
+     // in instruct mode, we inject a prefix and a suffix to each input by the user
+     if (params.instruct) {
+         params.interactive_start = true;
+         params.antiprompt.push_back("### Instruction:\n\n");
+     }
+
+     // enable interactive mode if reverse prompt or interactive start is specified
+     if (params.antiprompt.size() != 0 || params.interactive_start) {
+         params.interactive = true;
+     }
+
+     // determine newline token
+     auto llama_token_newline = ::llama_tokenize(ctx, "\n", false);
+
+     if (params.verbose_prompt) {
+         fprintf(stderr, "\n");
+         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+         for (int i = 0; i < (int) embd_inp.size(); i++) {
+             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
+         }
+         if (params.n_keep > 0) {
+             fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
+             for (int i = 0; i < params.n_keep; i++) {
+                 fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]));
+             }
+             fprintf(stderr, "'\n");
+         }
+         fprintf(stderr, "\n");
+     }
+
+     if (params.interactive) {
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+         struct sigaction sigint_action;
+         sigint_action.sa_handler = sigint_handler;
+         sigemptyset (&sigint_action.sa_mask);
+         sigint_action.sa_flags = 0;
+         sigaction(SIGINT, &sigint_action, NULL);
+ #elif defined (_WIN32)
+         signal(SIGINT, sigint_handler);
+ #endif
+
+         fprintf(stderr, "%s: interactive mode on.\n", __func__);
+
+         if (params.antiprompt.size()) {
+             for (auto antiprompt : params.antiprompt) {
+                 fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
+             }
+         }
+
+         if (!params.input_prefix.empty()) {
+             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
+         }
+     }
+     fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
+             params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+     fprintf(stderr, "\n\n");
+
+     // TODO: replace with ring-buffer
+     std::vector<llama_token> last_n_tokens(n_ctx);
+     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+
+     if (params.interactive) {
+         fprintf(stderr, "== Running in interactive mode. ==\n"
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+                 " - Press Ctrl+C to interject at any time.\n"
+ #endif
+                 " - Press Return to return control to LLaMa.\n"
+                 " - If you want to submit another line, end your input in '\\'.\n\n");
+         is_interacting = params.interactive_start;
+     }
+
+     bool is_antiprompt = false;
+     bool input_noecho  = false;
+
+     int n_past     = 0;
+     int n_remain   = params.n_predict;
+     int n_consumed = 0;
+
+     // the first thing we will do is to output the prompt, so set color accordingly
+     set_console_color(con_st, CONSOLE_COLOR_PROMPT);
+
+     std::vector<llama_token> embd;
+
+     while (n_remain != 0 || params.interactive) {
+         // predict
+         if (embd.size() > 0) {
+             // infinite text generation via context swapping
+             // if we run out of context:
+             // - take the n_keep first tokens from the original prompt (via n_past)
+             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
+             if (n_past + (int) embd.size() > n_ctx) {
+                 const int n_left = n_past - params.n_keep;
+
+                 n_past = params.n_keep;
+
+                 // insert n_left/2 tokens at the start of embd from last_n_tokens
+                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
+
+                 //printf("\n---\n");
+                 //printf("resetting: '");
+                 //for (int i = 0; i < (int) embd.size(); i++) {
+                 //    printf("%s", llama_token_to_str(ctx, embd[i]));
+                 //}
+                 //printf("'\n");
+                 //printf("\n---\n");
+             }
+
+             if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+                 fprintf(stderr, "%s : failed to eval\n", __func__);
+                 return 1;
+             }
+         }
+
+         n_past += embd.size();
+         embd.clear();
+
+         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
+             // out of user input, sample next token
+             const int32_t top_k          = params.top_k;
+             const float   top_p          = params.top_p;
+             const float   temp           = params.temp;
+             const float   repeat_penalty = params.repeat_penalty;
+
+             llama_token id = 0;
+
+             {
+                 auto logits = llama_get_logits(ctx);
+
+                 if (params.ignore_eos) {
+                     logits[llama_token_eos()] = 0;
+                 }
+
+                 id = llama_sample_top_p_top_k(ctx,
+                         last_n_tokens.data() + n_ctx - params.repeat_last_n,
+                         params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+
+                 last_n_tokens.erase(last_n_tokens.begin());
+                 last_n_tokens.push_back(id);
+             }
+
+             // replace end of text token with newline token when in interactive mode
+             if (id == llama_token_eos() && params.interactive && !params.instruct) {
+                 id = llama_token_newline.front();
+                 if (params.antiprompt.size() != 0) {
+                     // tokenize and inject first reverse prompt
+                     const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                     embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                 }
+             }
+
+             // add it to the context
+             embd.push_back(id);
+
+             // echo this to console
+             input_noecho = false;
+
+             // decrement remaining sampling budget
+             --n_remain;
+         } else {
+             // some user input remains from prompt or interaction, forward it to processing
+             while ((int) embd_inp.size() > n_consumed) {
+                 embd.push_back(embd_inp[n_consumed]);
+                 last_n_tokens.erase(last_n_tokens.begin());
+                 last_n_tokens.push_back(embd_inp[n_consumed]);
+                 ++n_consumed;
+                 if ((int) embd.size() >= params.n_batch) {
+                     break;
+                 }
+             }
+         }
+
+         // display text
+         if (!input_noecho) {
+             for (auto id : embd) {
+                 printf("%s", llama_token_to_str(ctx, id));
+             }
+             fflush(stdout);
+         }
+         // reset color to default if there is no pending user input
+         if (!input_noecho && (int)embd_inp.size() == n_consumed) {
+             set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+         }
+
+         // in interactive mode, and not currently processing queued inputs;
+         // check if we should prompt the user for more
+         if (params.interactive && (int) embd_inp.size() <= n_consumed) {
+
+             // check for reverse prompt
+             if (params.antiprompt.size()) {
+                 std::string last_output;
+                 for (auto id : last_n_tokens) {
+                     last_output += llama_token_to_str(ctx, id);
+                 }
+
+                 is_antiprompt = false;
+                 // Check if each of the reverse prompts appears at the end of the output.
+                 for (std::string & antiprompt : params.antiprompt) {
+                     if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
+                         is_interacting = true;
+                         is_antiprompt = true;
+                         set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+                         fflush(stdout);
+                         break;
+                     }
+                 }
+             }
+
+             if (n_past > 0 && is_interacting) {
+                 // potentially set color to indicate we are taking user input
+                 set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+
+ #if defined (_WIN32)
+                 // Windows: must reactivate sigint handler after each signal
+                 signal(SIGINT, sigint_handler);
+ #endif
+
+                 if (params.instruct) {
+                     printf("\n> ");
+                 }
+
+                 std::string buffer;
+                 if (!params.input_prefix.empty()) {
+                     buffer += params.input_prefix;
+                     printf("%s", buffer.c_str());
+                 }
+
+                 std::string line;
+                 bool another_line = true;
+                 do {
+                     if (!std::getline(std::cin, line)) {
+                         // input stream is bad or EOF received
+                         return 0;
+                     }
+                     if (line.empty() || line.back() != '\\') {
+                         another_line = false;
+                     } else {
+                         line.pop_back(); // Remove the continuation character
+                     }
+                     buffer += line + '\n'; // Append the line to the result
+                 } while (another_line);
+
+                 // done taking input, reset color
+                 set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+
+                 // Add tokens to embd only if the input buffer is non-empty
+                 // Entering an empty line lets the user pass control back
+                 if (buffer.length() > 1) {
+
+                     // instruct mode: insert instruction prefix
+                     if (params.instruct && !is_antiprompt) {
+                         n_consumed = embd_inp.size();
+                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+                     }
+
+                     auto line_inp = ::llama_tokenize(ctx, buffer, false);
+                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+
+                     // instruct mode: insert response suffix
+                     if (params.instruct) {
+                         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                     }
+
+                     n_remain -= line_inp.size();
+                 }
+
+                 input_noecho = true; // do not echo this again
+             }
+
+             if (n_past > 0) {
+                 is_interacting = false;
+             }
+         }
+
+         // end of text token
+         if (embd.back() == llama_token_eos()) {
+             if (params.instruct) {
+                 is_interacting = true;
+             } else {
+                 fprintf(stderr, " [end of text]\n");
+                 break;
+             }
+         }
+
+         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
+         if (params.interactive && n_remain <= 0 && params.n_predict != -1) {
+             n_remain = params.n_predict;
+             is_interacting = true;
+         }
+     }
+
+ #if defined (_WIN32)
+     signal(SIGINT, SIG_DFL);
+ #endif
+
+     llama_print_timings(ctx);
+     llama_free(ctx);
+
+     set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+
+     return 0;
+ }