@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -1,50 +1,112 @@
1
- #include "arg.h"
2
- #include "common.h"
3
- #include "log.h"
4
1
  #include "llama.h"
5
-
2
+ #include <cstdio>
3
+ #include <cstring>
4
+ #include <string>
6
5
  #include <vector>
7
6
 
8
7
  static void print_usage(int, char ** argv) {
9
- LOG("\nexample usage:\n");
10
- LOG("\n %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
11
- LOG("\n");
8
+ printf("\nexample usage:\n");
9
+ printf("\n %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [prompt]\n", argv[0]);
10
+ printf("\n");
12
11
  }
13
12
 
14
13
  int main(int argc, char ** argv) {
15
- gpt_params params;
16
-
17
- params.prompt = "Hello my name is";
18
- params.n_predict = 32;
19
-
20
- if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
21
- return 1;
14
+ // path to the model gguf file
15
+ std::string model_path;
16
+ // prompt to generate text from
17
+ std::string prompt = "Hello my name is";
18
+ // number of layers to offload to the GPU
19
+ int ngl = 99;
20
+ // number of tokens to predict
21
+ int n_predict = 32;
22
+
23
+ // parse command line arguments
24
+
25
+ {
26
+ int i = 1;
27
+ for (; i < argc; i++) {
28
+ if (strcmp(argv[i], "-m") == 0) {
29
+ if (i + 1 < argc) {
30
+ model_path = argv[++i];
31
+ } else {
32
+ print_usage(argc, argv);
33
+ return 1;
34
+ }
35
+ } else if (strcmp(argv[i], "-n") == 0) {
36
+ if (i + 1 < argc) {
37
+ try {
38
+ n_predict = std::stoi(argv[++i]);
39
+ } catch (...) {
40
+ print_usage(argc, argv);
41
+ return 1;
42
+ }
43
+ } else {
44
+ print_usage(argc, argv);
45
+ return 1;
46
+ }
47
+ } else if (strcmp(argv[i], "-ngl") == 0) {
48
+ if (i + 1 < argc) {
49
+ try {
50
+ ngl = std::stoi(argv[++i]);
51
+ } catch (...) {
52
+ print_usage(argc, argv);
53
+ return 1;
54
+ }
55
+ } else {
56
+ print_usage(argc, argv);
57
+ return 1;
58
+ }
59
+ } else {
60
+ // prompt starts here
61
+ break;
62
+ }
63
+ }
64
+ if (model_path.empty()) {
65
+ print_usage(argc, argv);
66
+ return 1;
67
+ }
68
+ if (i < argc) {
69
+ prompt = argv[i++];
70
+ for (; i < argc; i++) {
71
+ prompt += " ";
72
+ prompt += argv[i];
73
+ }
74
+ }
22
75
  }
23
76
 
24
- gpt_init();
25
-
26
- // total length of the sequence including the prompt
27
- const int n_predict = params.n_predict;
28
-
29
- // init LLM
30
-
31
- llama_backend_init();
32
- llama_numa_init(params.numa);
33
-
34
77
  // initialize the model
35
78
 
36
- llama_model_params model_params = llama_model_params_from_gpt_params(params);
79
+ llama_model_params model_params = llama_model_default_params();
80
+ model_params.n_gpu_layers = ngl;
37
81
 
38
- llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
82
+ llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
39
83
 
40
84
  if (model == NULL) {
41
85
  fprintf(stderr , "%s: error: unable to load model\n" , __func__);
42
86
  return 1;
43
87
  }
44
88
 
89
+ // tokenize the prompt
90
+
91
+ // find the number of tokens in the prompt
92
+ const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
93
+
94
+ // allocate space for the tokens and tokenize the prompt
95
+ std::vector<llama_token> prompt_tokens(n_prompt);
96
+ if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
97
+ fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
98
+ return 1;
99
+ }
100
+
45
101
  // initialize the context
46
102
 
47
- llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
103
+ llama_context_params ctx_params = llama_context_default_params();
104
+ // n_ctx is the context size
105
+ ctx_params.n_ctx = n_prompt + n_predict - 1;
106
+ // n_batch is the maximum number of tokens that can be processed in a single call to llama_decode
107
+ ctx_params.n_batch = n_prompt;
108
+ // enable performance counters
109
+ ctx_params.no_perf = false;
48
110
 
49
111
  llama_context * ctx = llama_new_context_with_model(model, ctx_params);
50
112
 
@@ -53,117 +115,87 @@ int main(int argc, char ** argv) {
53
115
  return 1;
54
116
  }
55
117
 
56
- auto sparams = llama_sampler_chain_default_params();
118
+ // initialize the sampler
57
119
 
120
+ auto sparams = llama_sampler_chain_default_params();
58
121
  sparams.no_perf = false;
59
-
60
122
  llama_sampler * smpl = llama_sampler_chain_init(sparams);
61
123
 
62
124
  llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
63
125
 
64
- // tokenize the prompt
65
-
66
- std::vector<llama_token> tokens_list;
67
- tokens_list = ::llama_tokenize(ctx, params.prompt, true);
68
-
69
- const int n_ctx = llama_n_ctx(ctx);
70
- const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
71
-
72
- LOG("\n");
73
- LOG_INF("%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
74
-
75
- // make sure the KV cache is big enough to hold all the prompt and generated tokens
76
- if (n_kv_req > n_ctx) {
77
- LOG_ERR("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
78
- LOG_ERR("%s: either reduce n_predict or increase n_ctx\n", __func__);
79
- return 1;
80
- }
81
-
82
126
  // print the prompt token-by-token
83
127
 
84
- LOG("\n");
85
-
86
- for (auto id : tokens_list) {
87
- LOG("%s", llama_token_to_piece(ctx, id).c_str());
88
- }
89
-
90
- // create a llama_batch with size 512
91
- // we use this object to submit token data for decoding
92
-
93
- llama_batch batch = llama_batch_init(512, 0, 1);
94
-
95
- // evaluate the initial prompt
96
- for (size_t i = 0; i < tokens_list.size(); i++) {
97
- llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
128
+ for (auto id : prompt_tokens) {
129
+ char buf[128];
130
+ int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true);
131
+ if (n < 0) {
132
+ fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
133
+ return 1;
134
+ }
135
+ std::string s(buf, n);
136
+ printf("%s", s.c_str());
98
137
  }
99
138
 
100
- // llama_decode will output logits only for the last token of the prompt
101
- batch.logits[batch.n_tokens - 1] = true;
139
+ // prepare a batch for the prompt
102
140
 
103
- if (llama_decode(ctx, batch) != 0) {
104
- LOG("%s: llama_decode() failed\n", __func__);
105
- return 1;
106
- }
141
+ llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
107
142
 
108
143
  // main loop
109
144
 
110
- int n_cur = batch.n_tokens;
145
+ const auto t_main_start = ggml_time_us();
111
146
  int n_decode = 0;
147
+ llama_token new_token_id;
112
148
 
113
- const auto t_main_start = ggml_time_us();
149
+ for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
150
+ // evaluate the current batch with the transformer model
151
+ if (llama_decode(ctx, batch)) {
152
+ fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
153
+ return 1;
154
+ }
155
+
156
+ n_pos += batch.n_tokens;
114
157
 
115
- while (n_cur <= n_predict) {
116
158
  // sample the next token
117
159
  {
118
- const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
160
+ new_token_id = llama_sampler_sample(smpl, ctx, -1);
119
161
 
120
162
  // is it an end of generation?
121
- if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
122
- LOG("\n");
123
-
163
+ if (llama_token_is_eog(model, new_token_id)) {
124
164
  break;
125
165
  }
126
166
 
127
- LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
167
+ char buf[128];
168
+ int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
169
+ if (n < 0) {
170
+ fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
171
+ return 1;
172
+ }
173
+ std::string s(buf, n);
174
+ printf("%s", s.c_str());
128
175
  fflush(stdout);
129
176
 
130
- // prepare the next batch
131
- llama_batch_clear(batch);
132
-
133
- // push this new token for next evaluation
134
- llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
177
+ // prepare the next batch with the sampled token
178
+ batch = llama_batch_get_one(&new_token_id, 1);
135
179
 
136
180
  n_decode += 1;
137
181
  }
138
-
139
- n_cur += 1;
140
-
141
- // evaluate the current batch with the transformer model
142
- if (llama_decode(ctx, batch)) {
143
- LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
144
- return 1;
145
- }
146
182
  }
147
183
 
148
- LOG("\n");
184
+ printf("\n");
149
185
 
150
186
  const auto t_main_end = ggml_time_us();
151
187
 
152
- LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
188
+ fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
153
189
  __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
154
190
 
155
- LOG("\n");
191
+ fprintf(stderr, "\n");
156
192
  llama_perf_sampler_print(smpl);
157
193
  llama_perf_context_print(ctx);
194
+ fprintf(stderr, "\n");
158
195
 
159
- LOG("\n");
160
-
161
- llama_batch_free(batch);
162
196
  llama_sampler_free(smpl);
163
197
  llama_free(ctx);
164
198
  llama_free_model(model);
165
199
 
166
- llama_backend_free();
167
-
168
200
  return 0;
169
201
  }
@@ -0,0 +1,5 @@
1
+ set(TARGET llama-simple-chat)
2
+ add_executable(${TARGET} simple-chat.cpp)
3
+ install(TARGETS ${TARGET} RUNTIME)
4
+ target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
5
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)
@@ -0,0 +1,197 @@
1
+ #include "llama.h"
2
+ #include <cstdio>
3
+ #include <cstring>
4
+ #include <iostream>
5
+ #include <string>
6
+ #include <vector>
7
+
8
+ static void print_usage(int, char ** argv) {
9
+ printf("\nexample usage:\n");
10
+ printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
11
+ printf("\n");
12
+ }
13
+
14
+ int main(int argc, char ** argv) {
15
+ std::string model_path;
16
+ int ngl = 99;
17
+ int n_ctx = 2048;
18
+
19
+ // parse command line arguments
20
+ for (int i = 1; i < argc; i++) {
21
+ try {
22
+ if (strcmp(argv[i], "-m") == 0) {
23
+ if (i + 1 < argc) {
24
+ model_path = argv[++i];
25
+ } else {
26
+ print_usage(argc, argv);
27
+ return 1;
28
+ }
29
+ } else if (strcmp(argv[i], "-c") == 0) {
30
+ if (i + 1 < argc) {
31
+ n_ctx = std::stoi(argv[++i]);
32
+ } else {
33
+ print_usage(argc, argv);
34
+ return 1;
35
+ }
36
+ } else if (strcmp(argv[i], "-ngl") == 0) {
37
+ if (i + 1 < argc) {
38
+ ngl = std::stoi(argv[++i]);
39
+ } else {
40
+ print_usage(argc, argv);
41
+ return 1;
42
+ }
43
+ } else {
44
+ print_usage(argc, argv);
45
+ return 1;
46
+ }
47
+ } catch (std::exception & e) {
48
+ fprintf(stderr, "error: %s\n", e.what());
49
+ print_usage(argc, argv);
50
+ return 1;
51
+ }
52
+ }
53
+ if (model_path.empty()) {
54
+ print_usage(argc, argv);
55
+ return 1;
56
+ }
57
+
58
+ // only print errors
59
+ llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
60
+ if (level >= GGML_LOG_LEVEL_ERROR) {
61
+ fprintf(stderr, "%s", text);
62
+ }
63
+ }, nullptr);
64
+
65
+ // initialize the model
66
+ llama_model_params model_params = llama_model_default_params();
67
+ model_params.n_gpu_layers = ngl;
68
+
69
+ llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
70
+ if (!model) {
71
+ fprintf(stderr , "%s: error: unable to load model\n" , __func__);
72
+ return 1;
73
+ }
74
+
75
+ // initialize the context
76
+ llama_context_params ctx_params = llama_context_default_params();
77
+ ctx_params.n_ctx = n_ctx;
78
+ ctx_params.n_batch = n_ctx;
79
+
80
+ llama_context * ctx = llama_new_context_with_model(model, ctx_params);
81
+ if (!ctx) {
82
+ fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
83
+ return 1;
84
+ }
85
+
86
+ // initialize the sampler
87
+ llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
88
+ llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
89
+ llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
90
+ llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
91
+
92
+ // helper function to evaluate a prompt and generate a response
93
+ auto generate = [&](const std::string & prompt) {
94
+ std::string response;
95
+
96
+ // tokenize the prompt
97
+ const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
98
+ std::vector<llama_token> prompt_tokens(n_prompt_tokens);
99
+ if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
100
+ GGML_ABORT("failed to tokenize the prompt\n");
101
+ }
102
+
103
+ // prepare a batch for the prompt
104
+ llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
105
+ llama_token new_token_id;
106
+ while (true) {
107
+ // check if we have enough space in the context to evaluate this batch
108
+ int n_ctx = llama_n_ctx(ctx);
109
+ int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
110
+ if (n_ctx_used + batch.n_tokens > n_ctx) {
111
+ printf("\033[0m\n");
112
+ fprintf(stderr, "context size exceeded\n");
113
+ exit(0);
114
+ }
115
+
116
+ if (llama_decode(ctx, batch)) {
117
+ GGML_ABORT("failed to decode\n");
118
+ }
119
+
120
+ // sample the next token
121
+ new_token_id = llama_sampler_sample(smpl, ctx, -1);
122
+
123
+ // is it an end of generation?
124
+ if (llama_token_is_eog(model, new_token_id)) {
125
+ break;
126
+ }
127
+
128
+ // convert the token to a string, print it and add it to the response
129
+ char buf[256];
130
+ int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
131
+ if (n < 0) {
132
+ GGML_ABORT("failed to convert token to piece\n");
133
+ }
134
+ std::string piece(buf, n);
135
+ printf("%s", piece.c_str());
136
+ fflush(stdout);
137
+ response += piece;
138
+
139
+ // prepare the next batch with the sampled token
140
+ batch = llama_batch_get_one(&new_token_id, 1);
141
+ }
142
+
143
+ return response;
144
+ };
145
+
146
+ std::vector<llama_chat_message> messages;
147
+ std::vector<char> formatted(llama_n_ctx(ctx));
148
+ int prev_len = 0;
149
+ while (true) {
150
+ // get user input
151
+ printf("\033[32m> \033[0m");
152
+ std::string user;
153
+ std::getline(std::cin, user);
154
+
155
+ if (user.empty()) {
156
+ break;
157
+ }
158
+
159
+ // add the user input to the message list and format it
160
+ messages.push_back({"user", strdup(user.c_str())});
161
+ int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
162
+ if (new_len > (int)formatted.size()) {
163
+ formatted.resize(new_len);
164
+ new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
165
+ }
166
+ if (new_len < 0) {
167
+ fprintf(stderr, "failed to apply the chat template\n");
168
+ return 1;
169
+ }
170
+
171
+ // remove previous messages to obtain the prompt to generate the response
172
+ std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
173
+
174
+ // generate a response
175
+ printf("\033[33m");
176
+ std::string response = generate(prompt);
177
+ printf("\n\033[0m");
178
+
179
+ // add the response to the messages
180
+ messages.push_back({"assistant", strdup(response.c_str())});
181
+ prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
182
+ if (prev_len < 0) {
183
+ fprintf(stderr, "failed to apply the chat template\n");
184
+ return 1;
185
+ }
186
+ }
187
+
188
+ // free resources
189
+ for (auto & msg : messages) {
190
+ free(const_cast<char *>(msg.content));
191
+ }
192
+ llama_sampler_free(smpl);
193
+ llama_free(ctx);
194
+ llama_free_model(model);
195
+
196
+ return 0;
197
+ }