@fugood/llama.node 0.3.2 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +2 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +8 -9
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +43 -9
- package/src/llama.cpp/.github/workflows/docker.yml +3 -0
- package/src/llama.cpp/CMakeLists.txt +7 -4
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +0 -2
- package/src/llama.cpp/common/arg.cpp +642 -607
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +79 -281
- package/src/llama.cpp/common/common.h +130 -100
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +116 -108
- package/src/llama.cpp/common/sampling.h +20 -20
- package/src/llama.cpp/docs/build.md +37 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +14 -14
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
- package/src/llama.cpp/examples/infill/infill.cpp +40 -86
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/clip.cpp +1 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +37 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
- package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
- package/src/llama.cpp/examples/main/main.cpp +64 -109
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
- package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
- package/src/llama.cpp/examples/server/server.cpp +553 -691
- package/src/llama.cpp/examples/server/utils.hpp +312 -25
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +128 -96
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
- package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +53 -393
- package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
- package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
- package/src/llama.cpp/include/llama.h +67 -33
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +745 -105
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +49 -9
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +2636 -2406
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/tests/CMakeLists.txt +1 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
- package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +1 -0
- package/src/llama.cpp/tests/test-sampling.cpp +162 -137
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
- /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
|
@@ -33,8 +33,8 @@
|
|
|
33
33
|
|
|
34
34
|
static llama_context ** g_ctx;
|
|
35
35
|
static llama_model ** g_model;
|
|
36
|
-
static
|
|
37
|
-
static
|
|
36
|
+
static common_sampler ** g_smpl;
|
|
37
|
+
static common_params * g_params;
|
|
38
38
|
static std::vector<llama_token> * g_input_tokens;
|
|
39
39
|
static std::ostringstream * g_output_ss;
|
|
40
40
|
static std::vector<llama_token> * g_output_tokens;
|
|
@@ -62,49 +62,6 @@ static bool file_is_empty(const std::string & path) {
|
|
|
62
62
|
return f.tellg() == 0;
|
|
63
63
|
}
|
|
64
64
|
|
|
65
|
-
static void write_logfile(
|
|
66
|
-
const llama_context * ctx, const gpt_params & params, const llama_model * model,
|
|
67
|
-
const std::vector<llama_token> & input_tokens, const std::string & output,
|
|
68
|
-
const std::vector<llama_token> & output_tokens
|
|
69
|
-
) {
|
|
70
|
-
if (params.logdir.empty()) {
|
|
71
|
-
return;
|
|
72
|
-
}
|
|
73
|
-
|
|
74
|
-
const std::string timestamp = string_get_sortable_timestamp();
|
|
75
|
-
|
|
76
|
-
const bool success = fs_create_directory_with_parents(params.logdir);
|
|
77
|
-
if (!success) {
|
|
78
|
-
LOG_ERR("%s: failed to create logdir %s, cannot write logfile\n", __func__, params.logdir.c_str());
|
|
79
|
-
return;
|
|
80
|
-
}
|
|
81
|
-
|
|
82
|
-
const std::string logfile_path = params.logdir + timestamp + ".yml";
|
|
83
|
-
FILE * logfile = fopen(logfile_path.c_str(), "w");
|
|
84
|
-
|
|
85
|
-
if (logfile == NULL) {
|
|
86
|
-
LOG_ERR("%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
|
|
87
|
-
return;
|
|
88
|
-
}
|
|
89
|
-
|
|
90
|
-
fprintf(logfile, "binary: main\n");
|
|
91
|
-
char model_desc[128];
|
|
92
|
-
llama_model_desc(model, model_desc, sizeof(model_desc));
|
|
93
|
-
yaml_dump_non_result_info(logfile, params, ctx, timestamp, input_tokens, model_desc);
|
|
94
|
-
|
|
95
|
-
fprintf(logfile, "\n");
|
|
96
|
-
fprintf(logfile, "######################\n");
|
|
97
|
-
fprintf(logfile, "# Generation Results #\n");
|
|
98
|
-
fprintf(logfile, "######################\n");
|
|
99
|
-
fprintf(logfile, "\n");
|
|
100
|
-
|
|
101
|
-
yaml_dump_string_multiline(logfile, "output", output.c_str());
|
|
102
|
-
yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
|
|
103
|
-
|
|
104
|
-
llama_perf_dump_yaml(logfile, ctx);
|
|
105
|
-
fclose(logfile);
|
|
106
|
-
}
|
|
107
|
-
|
|
108
65
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
|
109
66
|
static void sigint_handler(int signo) {
|
|
110
67
|
if (signo == SIGINT) {
|
|
@@ -114,12 +71,11 @@ static void sigint_handler(int signo) {
|
|
|
114
71
|
} else {
|
|
115
72
|
console::cleanup();
|
|
116
73
|
LOG("\n");
|
|
117
|
-
|
|
118
|
-
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
|
|
74
|
+
common_perf_print(*g_ctx, *g_smpl);
|
|
119
75
|
|
|
120
76
|
// make sure all logs are flushed
|
|
121
77
|
LOG("Interrupted by user\n");
|
|
122
|
-
|
|
78
|
+
common_log_pause(common_log_main());
|
|
123
79
|
|
|
124
80
|
_exit(130);
|
|
125
81
|
}
|
|
@@ -127,22 +83,22 @@ static void sigint_handler(int signo) {
|
|
|
127
83
|
}
|
|
128
84
|
#endif
|
|
129
85
|
|
|
130
|
-
static std::string chat_add_and_format(struct llama_model * model, std::vector<
|
|
131
|
-
|
|
132
|
-
auto formatted =
|
|
86
|
+
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
|
|
87
|
+
common_chat_msg new_msg{role, content};
|
|
88
|
+
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
|
133
89
|
chat_msgs.push_back({role, content});
|
|
134
90
|
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
|
135
91
|
return formatted;
|
|
136
92
|
}
|
|
137
93
|
|
|
138
94
|
int main(int argc, char ** argv) {
|
|
139
|
-
|
|
95
|
+
common_params params;
|
|
140
96
|
g_params = &params;
|
|
141
|
-
if (!
|
|
97
|
+
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) {
|
|
142
98
|
return 1;
|
|
143
99
|
}
|
|
144
100
|
|
|
145
|
-
|
|
101
|
+
common_init();
|
|
146
102
|
|
|
147
103
|
auto & sparams = params.sparams;
|
|
148
104
|
|
|
@@ -187,9 +143,9 @@ int main(int argc, char ** argv) {
|
|
|
187
143
|
|
|
188
144
|
llama_model * model = nullptr;
|
|
189
145
|
llama_context * ctx = nullptr;
|
|
190
|
-
|
|
146
|
+
common_sampler * smpl = nullptr;
|
|
191
147
|
|
|
192
|
-
std::vector<
|
|
148
|
+
std::vector<common_chat_msg> chat_msgs;
|
|
193
149
|
|
|
194
150
|
g_model = &model;
|
|
195
151
|
g_ctx = &ctx;
|
|
@@ -197,7 +153,7 @@ int main(int argc, char ** argv) {
|
|
|
197
153
|
|
|
198
154
|
// load the model and apply lora adapter, if any
|
|
199
155
|
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
|
|
200
|
-
|
|
156
|
+
common_init_result llama_init = common_init_from_params(params);
|
|
201
157
|
|
|
202
158
|
model = llama_init.model;
|
|
203
159
|
ctx = llama_init.context;
|
|
@@ -246,7 +202,7 @@ int main(int argc, char ** argv) {
|
|
|
246
202
|
// print chat template example in conversation mode
|
|
247
203
|
if (params.conversation) {
|
|
248
204
|
if (params.enable_chat_template) {
|
|
249
|
-
LOG_INF("%s: chat template example:\n%s\n", __func__,
|
|
205
|
+
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
|
|
250
206
|
} else {
|
|
251
207
|
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
|
|
252
208
|
}
|
|
@@ -255,7 +211,7 @@ int main(int argc, char ** argv) {
|
|
|
255
211
|
// print system information
|
|
256
212
|
{
|
|
257
213
|
LOG_INF("\n");
|
|
258
|
-
LOG_INF("%s\n",
|
|
214
|
+
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
|
259
215
|
LOG_INF("\n");
|
|
260
216
|
}
|
|
261
217
|
|
|
@@ -296,7 +252,7 @@ int main(int argc, char ** argv) {
|
|
|
296
252
|
: params.prompt;
|
|
297
253
|
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
|
298
254
|
LOG_DBG("tokenize the prompt\n");
|
|
299
|
-
embd_inp =
|
|
255
|
+
embd_inp = common_tokenize(ctx, prompt, true, true);
|
|
300
256
|
} else {
|
|
301
257
|
LOG_DBG("use session tokens\n");
|
|
302
258
|
embd_inp = session_tokens;
|
|
@@ -379,13 +335,13 @@ int main(int argc, char ** argv) {
|
|
|
379
335
|
LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
|
380
336
|
LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
|
|
381
337
|
for (int i = 0; i < (int) embd_inp.size(); i++) {
|
|
382
|
-
LOG_INF("%6d -> '%s'\n", embd_inp[i],
|
|
338
|
+
LOG_INF("%6d -> '%s'\n", embd_inp[i], common_token_to_piece(ctx, embd_inp[i]).c_str());
|
|
383
339
|
}
|
|
384
340
|
|
|
385
341
|
if (params.n_keep > add_bos) {
|
|
386
342
|
LOG_INF("%s: static prompt based on n_keep: '", __func__);
|
|
387
343
|
for (int i = 0; i < params.n_keep; i++) {
|
|
388
|
-
LOG_CNT("%s",
|
|
344
|
+
LOG_CNT("%s", common_token_to_piece(ctx, embd_inp[i]).c_str());
|
|
389
345
|
}
|
|
390
346
|
LOG_CNT("'\n");
|
|
391
347
|
}
|
|
@@ -415,9 +371,9 @@ int main(int argc, char ** argv) {
|
|
|
415
371
|
for (const auto & antiprompt : params.antiprompt) {
|
|
416
372
|
LOG_INF("Reverse prompt: '%s'\n", antiprompt.c_str());
|
|
417
373
|
if (params.verbose_prompt) {
|
|
418
|
-
auto tmp =
|
|
374
|
+
auto tmp = common_tokenize(ctx, antiprompt, false, true);
|
|
419
375
|
for (int i = 0; i < (int) tmp.size(); i++) {
|
|
420
|
-
LOG_INF("%6d -> '%s'\n", tmp[i],
|
|
376
|
+
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
|
|
421
377
|
}
|
|
422
378
|
}
|
|
423
379
|
}
|
|
@@ -430,9 +386,9 @@ int main(int argc, char ** argv) {
|
|
|
430
386
|
if (!params.input_prefix.empty()) {
|
|
431
387
|
LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str());
|
|
432
388
|
if (params.verbose_prompt) {
|
|
433
|
-
auto tmp =
|
|
389
|
+
auto tmp = common_tokenize(ctx, params.input_prefix, true, true);
|
|
434
390
|
for (int i = 0; i < (int) tmp.size(); i++) {
|
|
435
|
-
LOG_INF("%6d -> '%s'\n", tmp[i],
|
|
391
|
+
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
|
|
436
392
|
}
|
|
437
393
|
}
|
|
438
394
|
}
|
|
@@ -440,23 +396,23 @@ int main(int argc, char ** argv) {
|
|
|
440
396
|
if (!params.input_suffix.empty()) {
|
|
441
397
|
LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str());
|
|
442
398
|
if (params.verbose_prompt) {
|
|
443
|
-
auto tmp =
|
|
399
|
+
auto tmp = common_tokenize(ctx, params.input_suffix, false, true);
|
|
444
400
|
for (int i = 0; i < (int) tmp.size(); i++) {
|
|
445
|
-
LOG_INF("%6d -> '%s'\n", tmp[i],
|
|
401
|
+
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
|
|
446
402
|
}
|
|
447
403
|
}
|
|
448
404
|
}
|
|
449
405
|
}
|
|
450
406
|
|
|
451
|
-
smpl =
|
|
407
|
+
smpl = common_sampler_init(model, sparams);
|
|
452
408
|
if (!smpl) {
|
|
453
409
|
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
|
|
454
410
|
return 1;
|
|
455
411
|
}
|
|
456
412
|
|
|
457
|
-
LOG_INF("sampler seed: %u\n",
|
|
413
|
+
LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl));
|
|
458
414
|
LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
|
|
459
|
-
LOG_INF("sampler chain: %s\n",
|
|
415
|
+
LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str());
|
|
460
416
|
|
|
461
417
|
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
|
462
418
|
|
|
@@ -521,14 +477,14 @@ int main(int argc, char ** argv) {
|
|
|
521
477
|
|
|
522
478
|
antiprompt_ids.reserve(params.antiprompt.size());
|
|
523
479
|
for (const std::string & antiprompt : params.antiprompt) {
|
|
524
|
-
antiprompt_ids.emplace_back(::
|
|
480
|
+
antiprompt_ids.emplace_back(::common_tokenize(ctx, antiprompt, false, true));
|
|
525
481
|
}
|
|
526
482
|
|
|
527
483
|
if (llama_model_has_encoder(model)) {
|
|
528
484
|
int enc_input_size = embd_inp.size();
|
|
529
485
|
llama_token * enc_input_buf = embd_inp.data();
|
|
530
486
|
|
|
531
|
-
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size
|
|
487
|
+
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
|
|
532
488
|
LOG_ERR("%s : failed to eval\n", __func__);
|
|
533
489
|
return 1;
|
|
534
490
|
}
|
|
@@ -569,30 +525,30 @@ int main(int argc, char ** argv) {
|
|
|
569
525
|
if (!params.ctx_shift){
|
|
570
526
|
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
|
|
571
527
|
break;
|
|
572
|
-
}
|
|
573
|
-
if (params.n_predict == -2) {
|
|
574
|
-
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
|
575
|
-
break;
|
|
576
|
-
}
|
|
528
|
+
}
|
|
577
529
|
|
|
578
|
-
|
|
579
|
-
|
|
530
|
+
if (params.n_predict == -2) {
|
|
531
|
+
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
|
532
|
+
break;
|
|
533
|
+
}
|
|
580
534
|
|
|
581
|
-
|
|
582
|
-
|
|
535
|
+
const int n_left = n_past - params.n_keep;
|
|
536
|
+
const int n_discard = n_left/2;
|
|
583
537
|
|
|
584
|
-
|
|
585
|
-
|
|
538
|
+
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
|
539
|
+
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
|
586
540
|
|
|
587
|
-
|
|
541
|
+
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
|
542
|
+
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
|
588
543
|
|
|
589
|
-
|
|
544
|
+
n_past -= n_discard;
|
|
590
545
|
|
|
591
|
-
|
|
546
|
+
LOG_DBG("after swap: n_past = %d\n", n_past);
|
|
592
547
|
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
548
|
+
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
|
|
549
|
+
|
|
550
|
+
LOG_DBG("clear session path\n");
|
|
551
|
+
path_session.clear();
|
|
596
552
|
}
|
|
597
553
|
} else {
|
|
598
554
|
// context extension via Self-Extend
|
|
@@ -648,7 +604,7 @@ int main(int argc, char ** argv) {
|
|
|
648
604
|
|
|
649
605
|
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
|
650
606
|
|
|
651
|
-
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval
|
|
607
|
+
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
|
|
652
608
|
LOG_ERR("%s : failed to eval\n", __func__);
|
|
653
609
|
return 1;
|
|
654
610
|
}
|
|
@@ -679,9 +635,9 @@ int main(int argc, char ** argv) {
|
|
|
679
635
|
LOG_DBG("saved session to %s\n", path_session.c_str());
|
|
680
636
|
}
|
|
681
637
|
|
|
682
|
-
const llama_token id =
|
|
638
|
+
const llama_token id = common_sampler_sample(smpl, ctx, -1);
|
|
683
639
|
|
|
684
|
-
|
|
640
|
+
common_sampler_accept(smpl, id, /* accept_grammar= */ true);
|
|
685
641
|
|
|
686
642
|
// LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
|
|
687
643
|
|
|
@@ -702,7 +658,7 @@ int main(int argc, char ** argv) {
|
|
|
702
658
|
|
|
703
659
|
// push the prompt in the sampling context in order to apply repetition penalties later
|
|
704
660
|
// for the prompt, we don't apply grammar rules
|
|
705
|
-
|
|
661
|
+
common_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);
|
|
706
662
|
|
|
707
663
|
++n_consumed;
|
|
708
664
|
if ((int) embd.size() >= params.n_batch) {
|
|
@@ -714,7 +670,7 @@ int main(int argc, char ** argv) {
|
|
|
714
670
|
// display text
|
|
715
671
|
if (input_echo && display) {
|
|
716
672
|
for (auto id : embd) {
|
|
717
|
-
const std::string token_str =
|
|
673
|
+
const std::string token_str = common_token_to_piece(ctx, id, params.special);
|
|
718
674
|
|
|
719
675
|
// Console/Stream Output
|
|
720
676
|
LOG("%s", token_str.c_str());
|
|
@@ -743,7 +699,7 @@ int main(int argc, char ** argv) {
|
|
|
743
699
|
// check for reverse prompt in the last n_prev tokens
|
|
744
700
|
if (!params.antiprompt.empty()) {
|
|
745
701
|
const int n_prev = 32;
|
|
746
|
-
const std::string last_output =
|
|
702
|
+
const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);
|
|
747
703
|
|
|
748
704
|
is_antiprompt = false;
|
|
749
705
|
// Check if each of the reverse prompts appears at the end of the output.
|
|
@@ -765,7 +721,7 @@ int main(int argc, char ** argv) {
|
|
|
765
721
|
}
|
|
766
722
|
|
|
767
723
|
// check for reverse prompt using special tokens
|
|
768
|
-
llama_token last_token =
|
|
724
|
+
llama_token last_token = common_sampler_last(smpl);
|
|
769
725
|
for (std::vector<llama_token> ids : antiprompt_ids) {
|
|
770
726
|
if (ids.size() == 1 && last_token == ids[0]) {
|
|
771
727
|
if (params.interactive) {
|
|
@@ -782,13 +738,13 @@ int main(int argc, char ** argv) {
|
|
|
782
738
|
}
|
|
783
739
|
|
|
784
740
|
// deal with end of generation tokens in interactive mode
|
|
785
|
-
if (llama_token_is_eog(model,
|
|
741
|
+
if (llama_token_is_eog(model, common_sampler_last(smpl))) {
|
|
786
742
|
LOG_DBG("found an EOG token\n");
|
|
787
743
|
|
|
788
744
|
if (params.interactive) {
|
|
789
745
|
if (!params.antiprompt.empty()) {
|
|
790
746
|
// tokenize and inject first reverse prompt
|
|
791
|
-
const auto first_antiprompt =
|
|
747
|
+
const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true);
|
|
792
748
|
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
|
793
749
|
is_antiprompt = true;
|
|
794
750
|
}
|
|
@@ -803,8 +759,8 @@ int main(int argc, char ** argv) {
|
|
|
803
759
|
|
|
804
760
|
// if current token is not EOG, we add it to current assistant message
|
|
805
761
|
if (params.conversation) {
|
|
806
|
-
const auto id =
|
|
807
|
-
assistant_ss <<
|
|
762
|
+
const auto id = common_sampler_last(smpl);
|
|
763
|
+
assistant_ss << common_token_to_piece(ctx, id, false);
|
|
808
764
|
}
|
|
809
765
|
|
|
810
766
|
if (n_past > 0 && is_interacting) {
|
|
@@ -862,9 +818,9 @@ int main(int argc, char ** argv) {
|
|
|
862
818
|
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
|
|
863
819
|
: std::move(buffer);
|
|
864
820
|
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
|
|
865
|
-
const auto line_pfx =
|
|
866
|
-
const auto line_inp =
|
|
867
|
-
const auto line_sfx =
|
|
821
|
+
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
|
|
822
|
+
const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat);
|
|
823
|
+
const auto line_sfx = common_tokenize(ctx, params.input_suffix, false, true);
|
|
868
824
|
|
|
869
825
|
LOG_DBG("input tokens: %s\n", string_from(ctx, line_inp).c_str());
|
|
870
826
|
|
|
@@ -882,7 +838,7 @@ int main(int argc, char ** argv) {
|
|
|
882
838
|
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
|
883
839
|
const llama_token token = embd_inp[i];
|
|
884
840
|
output_tokens.push_back(token);
|
|
885
|
-
output_ss <<
|
|
841
|
+
output_ss << common_token_to_piece(ctx, token);
|
|
886
842
|
}
|
|
887
843
|
|
|
888
844
|
// reset assistant message
|
|
@@ -899,7 +855,7 @@ int main(int argc, char ** argv) {
|
|
|
899
855
|
|
|
900
856
|
if (n_past > 0) {
|
|
901
857
|
if (is_interacting) {
|
|
902
|
-
|
|
858
|
+
common_sampler_reset(smpl);
|
|
903
859
|
}
|
|
904
860
|
is_interacting = false;
|
|
905
861
|
}
|
|
@@ -925,10 +881,9 @@ int main(int argc, char ** argv) {
|
|
|
925
881
|
}
|
|
926
882
|
|
|
927
883
|
LOG("\n\n");
|
|
928
|
-
|
|
929
|
-
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
|
884
|
+
common_perf_print(ctx, smpl);
|
|
930
885
|
|
|
931
|
-
|
|
886
|
+
common_sampler_free(smpl);
|
|
932
887
|
|
|
933
888
|
llama_free(ctx);
|
|
934
889
|
llama_free_model(model);
|
|
@@ -54,7 +54,7 @@ static std::vector<std::string> k_prompts = {
|
|
|
54
54
|
struct client {
|
|
55
55
|
~client() {
|
|
56
56
|
if (smpl) {
|
|
57
|
-
|
|
57
|
+
common_sampler_free(smpl);
|
|
58
58
|
}
|
|
59
59
|
}
|
|
60
60
|
|
|
@@ -75,7 +75,7 @@ struct client {
|
|
|
75
75
|
std::string prompt;
|
|
76
76
|
std::string response;
|
|
77
77
|
|
|
78
|
-
struct
|
|
78
|
+
struct common_sampler * smpl = nullptr;
|
|
79
79
|
};
|
|
80
80
|
|
|
81
81
|
static void print_date_time() {
|
|
@@ -103,13 +103,13 @@ static std::vector<std::string> split_string(const std::string& input, char deli
|
|
|
103
103
|
int main(int argc, char ** argv) {
|
|
104
104
|
srand(1234);
|
|
105
105
|
|
|
106
|
-
|
|
106
|
+
common_params params;
|
|
107
107
|
|
|
108
|
-
if (!
|
|
108
|
+
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
|
|
109
109
|
return 1;
|
|
110
110
|
}
|
|
111
111
|
|
|
112
|
-
|
|
112
|
+
common_init();
|
|
113
113
|
|
|
114
114
|
// number of simultaneous "clients" to simulate
|
|
115
115
|
const int32_t n_clients = params.n_parallel;
|
|
@@ -130,7 +130,7 @@ int main(int argc, char ** argv) {
|
|
|
130
130
|
llama_numa_init(params.numa);
|
|
131
131
|
|
|
132
132
|
// load the target model
|
|
133
|
-
|
|
133
|
+
common_init_result llama_init = common_init_from_params(params);
|
|
134
134
|
|
|
135
135
|
llama_model * model = llama_init.model;
|
|
136
136
|
llama_context * ctx = llama_init.context;
|
|
@@ -160,11 +160,11 @@ int main(int argc, char ** argv) {
|
|
|
160
160
|
for (size_t i = 0; i < clients.size(); ++i) {
|
|
161
161
|
auto & client = clients[i];
|
|
162
162
|
client.id = i;
|
|
163
|
-
client.smpl =
|
|
163
|
+
client.smpl = common_sampler_init(model, params.sparams);
|
|
164
164
|
}
|
|
165
165
|
|
|
166
166
|
std::vector<llama_token> tokens_system;
|
|
167
|
-
tokens_system =
|
|
167
|
+
tokens_system = common_tokenize(ctx, k_system, true);
|
|
168
168
|
const int32_t n_tokens_system = tokens_system.size();
|
|
169
169
|
|
|
170
170
|
llama_seq_id g_seq_id = 0;
|
|
@@ -189,7 +189,7 @@ int main(int argc, char ** argv) {
|
|
|
189
189
|
LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
|
|
190
190
|
|
|
191
191
|
for (int32_t i = 0; i < n_tokens_system; ++i) {
|
|
192
|
-
|
|
192
|
+
common_batch_add(batch, tokens_system[i], i, { 0 }, false);
|
|
193
193
|
}
|
|
194
194
|
|
|
195
195
|
if (llama_decode(ctx, batch) != 0) {
|
|
@@ -210,10 +210,10 @@ int main(int argc, char ** argv) {
|
|
|
210
210
|
while (true) {
|
|
211
211
|
if (dump_kv_cache) {
|
|
212
212
|
llama_kv_cache_view_update(ctx, &kvc_view);
|
|
213
|
-
|
|
213
|
+
common_kv_cache_dump_view_seqs(kvc_view, 40);
|
|
214
214
|
}
|
|
215
215
|
|
|
216
|
-
|
|
216
|
+
common_batch_clear(batch);
|
|
217
217
|
|
|
218
218
|
// decode any currently ongoing sequences
|
|
219
219
|
for (auto & client : clients) {
|
|
@@ -223,7 +223,7 @@ int main(int argc, char ** argv) {
|
|
|
223
223
|
|
|
224
224
|
client.i_batch = batch.n_tokens;
|
|
225
225
|
|
|
226
|
-
|
|
226
|
+
common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
|
|
227
227
|
|
|
228
228
|
client.n_decoded += 1;
|
|
229
229
|
}
|
|
@@ -252,14 +252,14 @@ int main(int argc, char ** argv) {
|
|
|
252
252
|
client.prompt = client.input + "\nAssistant:";
|
|
253
253
|
client.response = "";
|
|
254
254
|
|
|
255
|
-
|
|
255
|
+
common_sampler_reset(client.smpl);
|
|
256
256
|
|
|
257
257
|
// do not prepend BOS because we have a system prompt!
|
|
258
258
|
std::vector<llama_token> tokens_prompt;
|
|
259
|
-
tokens_prompt =
|
|
259
|
+
tokens_prompt = common_tokenize(ctx, client.prompt, false);
|
|
260
260
|
|
|
261
261
|
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
|
262
|
-
|
|
262
|
+
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
|
|
263
263
|
}
|
|
264
264
|
|
|
265
265
|
// extract the logits only for the last token
|
|
@@ -308,7 +308,6 @@ int main(int argc, char ** argv) {
|
|
|
308
308
|
batch.n_seq_id + i,
|
|
309
309
|
batch.seq_id + i,
|
|
310
310
|
batch.logits + i,
|
|
311
|
-
0, 0, 0, // unused
|
|
312
311
|
};
|
|
313
312
|
|
|
314
313
|
const int ret = llama_decode(ctx, batch_view);
|
|
@@ -340,9 +339,9 @@ int main(int argc, char ** argv) {
|
|
|
340
339
|
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
|
|
341
340
|
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
|
342
341
|
|
|
343
|
-
const llama_token id =
|
|
342
|
+
const llama_token id = common_sampler_sample(client.smpl, ctx, client.i_batch - i);
|
|
344
343
|
|
|
345
|
-
|
|
344
|
+
common_sampler_accept(client.smpl, id, true);
|
|
346
345
|
|
|
347
346
|
if (client.n_decoded == 1) {
|
|
348
347
|
// start measuring generation time after the first token to make sure all concurrent clients
|
|
@@ -350,7 +349,7 @@ int main(int argc, char ** argv) {
|
|
|
350
349
|
client.t_start_gen = ggml_time_us();
|
|
351
350
|
}
|
|
352
351
|
|
|
353
|
-
const std::string token_str =
|
|
352
|
+
const std::string token_str = common_token_to_piece(ctx, id);
|
|
354
353
|
|
|
355
354
|
client.response += token_str;
|
|
356
355
|
client.sampled = id;
|
|
@@ -15,17 +15,17 @@ static void print_usage(int, char ** argv) {
|
|
|
15
15
|
}
|
|
16
16
|
|
|
17
17
|
int main(int argc, char ** argv) {
|
|
18
|
-
|
|
18
|
+
common_params params;
|
|
19
19
|
|
|
20
20
|
params.n_junk = 250;
|
|
21
21
|
params.n_keep = 32;
|
|
22
22
|
params.i_pos = -1;
|
|
23
23
|
|
|
24
|
-
if (!
|
|
24
|
+
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) {
|
|
25
25
|
return 1;
|
|
26
26
|
}
|
|
27
27
|
|
|
28
|
-
|
|
28
|
+
common_init();
|
|
29
29
|
|
|
30
30
|
int n_junk = params.n_junk;
|
|
31
31
|
int n_keep = params.n_keep;
|
|
@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
|
|
|
61
61
|
|
|
62
62
|
// initialize the model
|
|
63
63
|
|
|
64
|
-
llama_model_params model_params =
|
|
64
|
+
llama_model_params model_params = common_model_params_to_llama(params);
|
|
65
65
|
|
|
66
66
|
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
|
|
67
67
|
|
|
@@ -72,7 +72,7 @@ int main(int argc, char ** argv) {
|
|
|
72
72
|
|
|
73
73
|
// initialize the context
|
|
74
74
|
|
|
75
|
-
llama_context_params ctx_params =
|
|
75
|
+
llama_context_params ctx_params = common_context_params_to_llama(params);
|
|
76
76
|
|
|
77
77
|
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
|
|
78
78
|
|
|
@@ -92,10 +92,10 @@ int main(int argc, char ** argv) {
|
|
|
92
92
|
|
|
93
93
|
// tokenize the prompt
|
|
94
94
|
std::vector<llama_token> tokens_list;
|
|
95
|
-
tokens_list =
|
|
95
|
+
tokens_list = common_tokenize(ctx, params.prompt, true);
|
|
96
96
|
|
|
97
97
|
// tokenize the prefix and use it as a sink
|
|
98
|
-
const int n_tokens_prefix =
|
|
98
|
+
const int n_tokens_prefix = common_tokenize(ctx, prompt_prefix, true).size();
|
|
99
99
|
|
|
100
100
|
const int n_tokens_all = tokens_list.size();
|
|
101
101
|
|
|
@@ -137,10 +137,10 @@ int main(int argc, char ** argv) {
|
|
|
137
137
|
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
|
138
138
|
}
|
|
139
139
|
|
|
140
|
-
|
|
140
|
+
common_batch_clear(batch);
|
|
141
141
|
|
|
142
142
|
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
|
143
|
-
|
|
143
|
+
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
|
|
144
144
|
}
|
|
145
145
|
|
|
146
146
|
if (i + n_batch >= n_tokens_all) {
|
|
@@ -171,10 +171,10 @@ int main(int argc, char ** argv) {
|
|
|
171
171
|
|
|
172
172
|
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
|
173
173
|
|
|
174
|
-
|
|
174
|
+
common_batch_clear(batch);
|
|
175
175
|
|
|
176
176
|
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
|
177
|
-
|
|
177
|
+
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
|
|
178
178
|
}
|
|
179
179
|
|
|
180
180
|
if (i + n_batch >= n_tokens_all) {
|
|
@@ -229,15 +229,15 @@ int main(int argc, char ** argv) {
|
|
|
229
229
|
break;
|
|
230
230
|
}
|
|
231
231
|
|
|
232
|
-
LOG("%s",
|
|
232
|
+
LOG("%s", common_token_to_piece(ctx, new_token_id).c_str());
|
|
233
233
|
|
|
234
234
|
n_decode += 1;
|
|
235
235
|
|
|
236
236
|
// prepare the next batch
|
|
237
|
-
|
|
237
|
+
common_batch_clear(batch);
|
|
238
238
|
|
|
239
239
|
// push this new token for next evaluation
|
|
240
|
-
|
|
240
|
+
common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
|
|
241
241
|
}
|
|
242
242
|
|
|
243
243
|
n_cur += 1;
|