@fugood/llama.node 0.3.7 → 0.3.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -2
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +8 -0
- package/lib/index.js +16 -1
- package/lib/index.ts +16 -0
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +4 -3
- package/src/LlamaCompletionWorker.cpp +4 -2
- package/src/LlamaContext.cpp +156 -6
- package/src/LlamaContext.h +5 -0
- package/src/common.hpp +6 -11
- package/src/llama.cpp/.github/workflows/build.yml +19 -17
- package/src/llama.cpp/.github/workflows/docker.yml +77 -30
- package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +22 -3
- package/src/llama.cpp/CMakeLists.txt +49 -24
- package/src/llama.cpp/common/arg.cpp +82 -26
- package/src/llama.cpp/common/arg.h +3 -0
- package/src/llama.cpp/common/common.cpp +192 -72
- package/src/llama.cpp/common/common.h +51 -18
- package/src/llama.cpp/common/ngram-cache.cpp +12 -12
- package/src/llama.cpp/common/ngram-cache.h +2 -2
- package/src/llama.cpp/common/sampling.cpp +11 -6
- package/src/llama.cpp/common/speculative.cpp +18 -15
- package/src/llama.cpp/docs/build.md +2 -0
- package/src/llama.cpp/examples/batched/batched.cpp +9 -7
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
- package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
- package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
- package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
- package/src/llama.cpp/examples/infill/infill.cpp +23 -24
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
- package/src/llama.cpp/examples/llava/clip.cpp +4 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
- package/src/llama.cpp/examples/llava/llava.cpp +2 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
- package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
- package/src/llama.cpp/examples/main/main.cpp +51 -29
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
- package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
- package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
- package/src/llama.cpp/examples/run/run.cpp +175 -61
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
- package/src/llama.cpp/examples/server/httplib.h +1295 -409
- package/src/llama.cpp/examples/server/server.cpp +387 -181
- package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
- package/src/llama.cpp/examples/server/utils.hpp +170 -58
- package/src/llama.cpp/examples/simple/simple.cpp +9 -8
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
- package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
- package/src/llama.cpp/examples/tts/tts.cpp +64 -23
- package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
- package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
- package/src/llama.cpp/ggml/include/ggml.h +36 -145
- package/src/llama.cpp/ggml/include/gguf.h +202 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
- package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml.c +117 -1327
- package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
- package/src/llama.cpp/include/llama-cpp.h +6 -1
- package/src/llama.cpp/include/llama.h +138 -75
- package/src/llama.cpp/src/CMakeLists.txt +13 -1
- package/src/llama.cpp/src/llama-adapter.cpp +347 -0
- package/src/llama.cpp/src/llama-adapter.h +74 -0
- package/src/llama.cpp/src/llama-arch.cpp +1487 -0
- package/src/llama.cpp/src/llama-arch.h +400 -0
- package/src/llama.cpp/src/llama-batch.cpp +368 -0
- package/src/llama.cpp/src/llama-batch.h +88 -0
- package/src/llama.cpp/src/llama-chat.cpp +578 -0
- package/src/llama.cpp/src/llama-chat.h +52 -0
- package/src/llama.cpp/src/llama-context.cpp +1775 -0
- package/src/llama.cpp/src/llama-context.h +128 -0
- package/src/llama.cpp/src/llama-cparams.cpp +1 -0
- package/src/llama.cpp/src/llama-cparams.h +37 -0
- package/src/llama.cpp/src/llama-grammar.cpp +5 -4
- package/src/llama.cpp/src/llama-grammar.h +3 -1
- package/src/llama.cpp/src/llama-hparams.cpp +71 -0
- package/src/llama.cpp/src/llama-hparams.h +139 -0
- package/src/llama.cpp/src/llama-impl.cpp +167 -0
- package/src/llama.cpp/src/llama-impl.h +16 -136
- package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
- package/src/llama.cpp/src/llama-kv-cache.h +218 -0
- package/src/llama.cpp/src/llama-mmap.cpp +589 -0
- package/src/llama.cpp/src/llama-mmap.h +67 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
- package/src/llama.cpp/src/llama-model-loader.h +167 -0
- package/src/llama.cpp/src/llama-model.cpp +3953 -0
- package/src/llama.cpp/src/llama-model.h +370 -0
- package/src/llama.cpp/src/llama-quant.cpp +934 -0
- package/src/llama.cpp/src/llama-quant.h +1 -0
- package/src/llama.cpp/src/llama-sampling.cpp +147 -32
- package/src/llama.cpp/src/llama-sampling.h +3 -19
- package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
- package/src/llama.cpp/src/llama-vocab.h +97 -142
- package/src/llama.cpp/src/llama.cpp +7160 -20314
- package/src/llama.cpp/src/unicode.cpp +8 -3
- package/src/llama.cpp/tests/CMakeLists.txt +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
- package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
- package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
- package/src/llama.cpp/tests/test-gguf.cpp +222 -187
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +0 -1
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
@@ -3,6 +3,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "common/base64.hpp"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -56,6 +57,8 @@ static T json_value(const json & body, const std::string & key, const T & defaul
     }
 }
 
+const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
+
 //
 // tokenizer and input processing utils
 //
@@ -88,12 +91,34 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
     return false;
 }
 
+// get value by path(key1 / key2)
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
+    json result = json::object();
+
+    for (const std::string & path : paths) {
+        json current = js;
+        const auto keys = string_split<std::string>(path, /*separator*/ '/');
+        bool valid_path = true;
+        for (const std::string & k : keys) {
+            if (valid_path && current.is_object() && current.contains(k)) {
+                current = current[k];
+            } else {
+                valid_path = false;
+            }
+        }
+        if (valid_path) {
+            result[path] = current;
+        }
+    }
+    return result;
+}
+
 /**
  * this handles 2 cases:
  * - only string, example: "string"
  * - mixed string and tokens, example: [12, 34, "string", 56, 78]
  */
-static llama_tokens tokenize_mixed(const
+static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
     // If `add_bos` is true, we only add BOS, when json_prompt is a string,
     // or the first element of the json_prompt array is a string.
     llama_tokens prompt_tokens;
@@ -106,10 +131,10 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
 
             llama_tokens p;
             if (first) {
-                p = common_tokenize(
+                p = common_tokenize(vocab, s, add_special, parse_special);
                 first = false;
             } else {
-                p = common_tokenize(
+                p = common_tokenize(vocab, s, false, parse_special);
             }
 
             prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@@ -123,7 +148,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
         }
     } else {
         auto s = json_prompt.template get<std::string>();
-        prompt_tokens = common_tokenize(
+        prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
     }
 
     return prompt_tokens;
@@ -141,11 +166,11 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
  * - "prompt": [[12, 34, 56], [78, 90, 12]]
  * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  */
-static std::vector<llama_tokens> tokenize_input_prompts(
+static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
     std::vector<llama_tokens> result;
     if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
         // string or mixed
-        result.push_back(tokenize_mixed(
+        result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
     } else if (json_is_array_of_numbers(json_prompt)) {
         // array of tokens
         result.push_back(json_prompt.get<llama_tokens>());
@@ -154,7 +179,7 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
         result.reserve(json_prompt.size());
         for (const auto & p : json_prompt) {
             if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
-                result.push_back(tokenize_mixed(
+                result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
             } else if (json_is_array_of_numbers(p)) {
                 // array of tokens
                 result.push_back(p.get<llama_tokens>());
@@ -206,21 +231,23 @@ static size_t validate_utf8(const std::string& text) {
 //
 
 // format rerank task: [BOS]query[EOS][SEP]doc[EOS]
-static llama_tokens format_rerank(const struct
+static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
     llama_tokens result;
+
     result.reserve(doc.size() + query.size() + 4);
-    result.push_back(
+    result.push_back(llama_vocab_bos(vocab));
     result.insert(result.end(), query.begin(), query.end());
-    result.push_back(
-    result.push_back(
+    result.push_back(llama_vocab_eos(vocab));
+    result.push_back(llama_vocab_sep(vocab));
     result.insert(result.end(), doc.begin(), doc.end());
-    result.push_back(
+    result.push_back(llama_vocab_eos(vocab));
+
     return result;
 }
 
 // format infill task
 static llama_tokens format_infill(
-        const
+        const llama_vocab * vocab,
         const json & input_prefix,
         const json & input_suffix,
         const json & input_extra,
@@ -247,15 +274,14 @@ static llama_tokens format_infill(
     llama_tokens extra_tokens;
     extra_tokens.reserve(n_ctx);
 
-    auto
-    auto
-    auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
+    auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false);
+    auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false);
 
-    if (
+    if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
         // TODO: make project name an input
-        static const auto k_fim_repo = common_tokenize(
+        static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false);
 
-        extra_tokens.push_back(
+        extra_tokens.push_back(llama_vocab_fim_rep(vocab));
         extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
     }
     for (const auto & chunk : input_extra) {
@@ -263,28 +289,28 @@ static llama_tokens format_infill(
         const std::string text = json_value(chunk, "text", std::string());
         const std::string filename = json_value(chunk, "filename", std::string("tmp"));
 
-        if (
-            const auto k_fim_file = common_tokenize(
+        if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
+            const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false);
 
-            extra_tokens.insert(extra_tokens.end(),
+            extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
             extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
         } else {
             // chunk separator in binary form to avoid confusing the AI
             static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
-            static const auto k_chunk_prefix_tokens = common_tokenize(
+            static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
 
             extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
         }
 
-        const auto chunk_tokens = common_tokenize(
+        const auto chunk_tokens = common_tokenize(vocab, text, false, false);
         extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
     }
 
-    if (
+    if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
         // TODO: current filename
-        static const auto k_fim_file = common_tokenize(
+        static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false);
 
-        extra_tokens.insert(extra_tokens.end(),
+        extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
         extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
     }
 
@@ -300,15 +326,15 @@ static llama_tokens format_infill(
     tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
     tokens_suffix.resize(n_suffix_take);
 
-    tokens_prefix.insert(tokens_prefix.begin(),
+    tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
     tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
-    tokens_suffix.insert(tokens_suffix.begin(),
+    tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
 
     auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
     auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
 
-    if (
-        embd_inp.insert(embd_inp.begin(),
+    if (llama_vocab_get_add_bos(vocab)) {
+        embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
     }
 
     SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
@@ -317,7 +343,7 @@ static llama_tokens format_infill(
     embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
 
     embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-    embd_inp.push_back(
+    embd_inp.push_back(llama_vocab_fim_mid(vocab));
 
     return embd_inp;
 }
@@ -357,19 +383,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }
 
-static std::string llama_get_chat_template(const struct llama_model * model) {
-    std::string template_key = "tokenizer.chat_template";
-    // call with NULL buffer to get the total size of the string
-    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 2) {
-        return "";
-    } else {
-        std::vector<char> model_template(res + 1, 0);
-        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size() - 1);
-    }
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -495,7 +508,7 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
 
 // format incomplete utf-8 multibyte character for output
 static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
-    std::string out = token ==
+    std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token);
 
     // if the size is 1 and first bit is 1, meaning it's a partial character
     // (size > 1 meaning it's already a known token)
@@ -524,10 +537,49 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 // OAI utils
 //
 
-static json oaicompat_completion_params_parse(
-
-
-
+static json oaicompat_completion_params_parse(const json & body) {
+    json llama_params;
+
+    if (!body.contains("prompt")) {
+        throw std::runtime_error("\"prompt\" is required");
+    }
+
+    // Handle "stop" field
+    if (body.contains("stop") && body.at("stop").is_string()) {
+        llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
+    } else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }
+
+    // Handle "n" field
+    int n_choices = json_value(body, "n", 1);
+    if (n_choices != 1) {
+        throw std::runtime_error("Only one completion choice is allowed");
+    }
+
+    // Params supported by OAI but unsupported by llama.cpp
+    static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
+    for (const auto & param : unsupported_params) {
+        if (body.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
+    }
+
+    // Copy remaining properties to llama_params
+    for (const auto & item : body.items()) {
+        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+            llama_params[item.key()] = item.value();
+        }
+    }
+
+    return llama_params;
+}
+
+static json oaicompat_chat_completion_params_parse(
+        const struct llama_model * model,
+        const json & body, /* openai api json semantics */
+        const std::string & chat_template) {
     json llama_params;
 
     // Apply chat template to the list of messages
@@ -589,16 +641,31 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
     json data = json::array();
     int32_t n_tokens = 0;
    int i = 0;
     for (const auto & elem : embeddings) {
-
-
-
-
-
+        json embedding_obj;
+
+        if (use_base64) {
+            const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
+            const char* data_ptr = reinterpret_cast<const char*>(vec.data());
+            size_t data_size = vec.size() * sizeof(float);
+            embedding_obj = {
+                {"embedding", base64::encode(data_ptr, data_size)},
+                {"index", i++},
+                {"object", "embedding"},
+                {"encoding_format", "base64"}
+            };
+        } else {
+            embedding_obj = {
+                {"embedding", json_value(elem, "embedding", json::array())},
+                {"index", i++},
+                {"object", "embedding"}
+            };
+        }
+        data.push_back(embedding_obj);
 
         n_tokens += json_value(elem, "tokens_evaluated", 0);
     }
@@ -698,14 +765,18 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
     return data;
 }
 
-static std::string safe_json_to_str(json data) {
+static std::string safe_json_to_str(const json & data) {
     return data.dump(-1, ' ', false, json::error_handler_t::replace);
 }
 
 static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
     std::vector<llama_token_data> cur;
     const auto * logits = llama_get_logits_ith(ctx, idx);
-
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
     cur.resize(n_vocab);
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -731,3 +802,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
 
     return cur;
 }
+
+static bool are_lora_equal(
+        const std::vector<common_adapter_lora_info> & l1,
+        const std::vector<common_adapter_lora_info> & l2) {
+    if (l1.size() != l2.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < l1.size(); ++i) {
+        // we don't check lora.path to reduce the time complexity
+        if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
+            return false;
+        }
+    }
+    return true;
+}
+
+// parse lora config from JSON request, returned a copy of lora_base with updated scale
+static std::vector<common_adapter_lora_info> parse_lora_request(
+        const std::vector<common_adapter_lora_info> & lora_base,
+        const json & data) {
+    std::vector<common_adapter_lora_info> lora(lora_base);
+    int max_idx = lora.size();
+
+    // clear existing value
+    for (auto & entry : lora) {
+        entry.scale = 0.0f;
+    }
+
+    // set value
+    for (const auto & entry : data) {
+        int id = json_value(entry, "id", -1);
+        float scale = json_value(entry, "scale", 0.0f);
+        if (0 <= id && id < max_idx) {
+            lora[id].scale = scale;
+        } else {
+            throw std::runtime_error("invalid adapter id");
+        }
+    }
+
+    return lora;
+}
@@ -83,7 +83,8 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
-    llama_model * model =
+    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
 
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
@@ -93,11 +94,11 @@ int main(int argc, char ** argv) {
     // tokenize the prompt
 
     // find the number of tokens in the prompt
-    const int n_prompt = -llama_tokenize(
+    const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
 
     // allocate space for the tokens and tokenize the prompt
     std::vector<llama_token> prompt_tokens(n_prompt);
-    if (llama_tokenize(
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
         fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
         return 1;
     }
@@ -112,7 +113,7 @@ int main(int argc, char ** argv) {
     // enable performance counters
     ctx_params.no_perf = false;
 
-    llama_context * ctx =
+    llama_context * ctx = llama_init_from_model(model, ctx_params);
 
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@@ -131,7 +132,7 @@ int main(int argc, char ** argv) {
 
     for (auto id : prompt_tokens) {
         char buf[128];
-        int n = llama_token_to_piece(
+        int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
         if (n < 0) {
             fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
             return 1;
@@ -164,12 +165,12 @@ int main(int argc, char ** argv) {
         new_token_id = llama_sampler_sample(smpl, ctx, -1);
 
         // is it an end of generation?
-        if (
+        if (llama_vocab_is_eog(vocab, new_token_id)) {
            break;
         }
 
         char buf[128];
-        int n = llama_token_to_piece(
+        int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
         if (n < 0) {
             fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
             return 1;
@@ -199,7 +200,7 @@ int main(int argc, char ** argv) {
 
     llama_sampler_free(smpl);
     llama_free(ctx);
-
+    llama_model_free(model);
 
     return 0;
 }
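The hunks above replace the old model-centric calls with the explicit `llama_vocab` API: the model is loaded with `llama_model_load_from_file`, its vocabulary is fetched with `llama_model_get_vocab`, and tokenization, detokenization, and end-of-generation checks now take the `llama_vocab *` rather than the model or context. The following is a minimal sketch of that flow; it is not part of the package, it assumes the llama.cpp headers bundled with this release, and the `model.gguf` path is a placeholder.

```cpp
// Minimal sketch (not part of this package) of the vocab-based API the hunks
// above migrate to. Assumes the llama.cpp headers bundled with this release;
// "model.gguf" is a placeholder path.
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // tokenization now goes through the vocab object instead of the model/context
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const std::string prompt = "Hello";

    // a sizing call with a NULL buffer returns the negated required token count
    const int n_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);

    std::vector<llama_token> tokens(n_tokens);
    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), tokens.data(), tokens.size(), true, true) < 0) {
        fprintf(stderr, "failed to tokenize the prompt\n");
        return 1;
    }

    // detokenization also takes the vocab now
    for (llama_token id : tokens) {
        char buf[128];
        const int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
        if (n >= 0) {
            printf("%.*s", n, buf);
        }
    }
    printf("\n");

    llama_model_free(model);   // replaces the older llama_free_model()
    llama_backend_free();
    return 0;
}
```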
@@ -69,18 +69,20 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
-    llama_model * model =
+    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
     if (!model) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return 1;
     }
 
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     // initialize the context
     llama_context_params ctx_params = llama_context_default_params();
     ctx_params.n_ctx = n_ctx;
     ctx_params.n_batch = n_ctx;
 
-    llama_context * ctx =
+    llama_context * ctx = llama_init_from_model(model, ctx_params);
     if (!ctx) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
         return 1;
@@ -93,13 +95,13 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
 
     // helper function to evaluate a prompt and generate a response
-    auto generate = [&](const std::string & prompt) {
+    auto generate = [&](const std::string & prompt, bool is_first) {
         std::string response;
 
         // tokenize the prompt
-        const int n_prompt_tokens = -llama_tokenize(
+        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
         std::vector<llama_token> prompt_tokens(n_prompt_tokens);
-        if (llama_tokenize(
+        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
             GGML_ABORT("failed to tokenize the prompt\n");
         }
 
@@ -124,13 +126,13 @@ int main(int argc, char ** argv) {
             new_token_id = llama_sampler_sample(smpl, ctx, -1);
 
             // is it an end of generation?
-            if (
+            if (llama_vocab_is_eog(vocab, new_token_id)) {
                break;
             }
 
             // convert the token to a string, print it and add it to the response
             char buf[256];
-            int n = llama_token_to_piece(
+            int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
             if (n < 0) {
                GGML_ABORT("failed to convert token to piece\n");
             }
@@ -159,12 +161,14 @@ int main(int argc, char ** argv) {
             break;
         }
 
+        const char * tmpl = llama_model_chat_template(model);
+
         // add the user input to the message list and format it
         messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(
+        int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size()) {
             formatted.resize(new_len);
-            new_len = llama_chat_apply_template(
+            new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         }
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
@@ -176,12 +180,12 @@ int main(int argc, char ** argv) {
 
         // generate a response
         printf("\033[33m");
-        std::string response = generate(prompt);
+        std::string response = generate(prompt, prev_len == 0);
         printf("\n\033[0m");
 
         // add the response to the messages
         messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(
+        prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
@@ -194,7 +198,7 @@ int main(int argc, char ** argv) {
     }
     llama_sampler_free(smpl);
     llama_free(ctx);
-
+    llama_model_free(model);
 
     return 0;
 }
@@ -72,8 +72,9 @@ int main(int argc, char ** argv) {
 
     // load the target model
     common_init_result llama_init_tgt = common_init_from_params(params);
-
-
+
+    model_tgt = llama_init_tgt.model.get();
+    ctx_tgt = llama_init_tgt.context.get();
 
     // load the draft model
     params.devices = params.speculative.devices;
@@ -85,13 +86,17 @@ int main(int argc, char ** argv) {
 
     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
     common_init_result llama_init_dft = common_init_from_params(params);
-    model_dft = llama_init_dft.model;
-    ctx_dft = llama_init_dft.context;
 
-
+    model_dft = llama_init_dft.model.get();
+    ctx_dft = llama_init_dft.context.get();
+
+    const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
+    const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
+
+    const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
     LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
 
-    const bool vocab_type_dft = llama_vocab_type(
+    const bool vocab_type_dft = llama_vocab_type(vocab_dft);
     LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
 
     if (vocab_type_tgt != vocab_type_dft) {
@@ -101,18 +106,18 @@ int main(int argc, char ** argv) {
     }
 
     if (
-
-
-
-
+        llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+        llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
+        llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
+        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
     ) {
         LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
         return 1;
     }
 
     {
-        const int n_vocab_tgt =
-        const int n_vocab_dft =
+        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
+        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
         const int vocab_diff = n_vocab_tgt > n_vocab_dft
             ? n_vocab_tgt - n_vocab_dft
             : n_vocab_dft - n_vocab_tgt;
@@ -120,13 +125,13 @@ int main(int argc, char ** argv) {
         if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
             LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
             LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
-                    n_vocab_tgt,
+                    n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
             return 1;
         }
 
         for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
-            const char * token_text_tgt =
-            const char * token_text_dft =
+            const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
+            const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
             if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
                 LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
                 LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
@@ -168,7 +173,7 @@ int main(int argc, char ** argv) {
     const auto t_enc_end = ggml_time_us();
 
     // the 2 models should have the same vocab
-    //GGML_ASSERT(n_vocab ==
+    //GGML_ASSERT(n_vocab == llama_vocab_n_tokens(model_dft));
 
     // how many tokens to draft each time
     int n_draft = params.speculative.n_max;
@@ -384,7 +389,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        if (
+        if (llama_vocab_is_eog(vocab_tgt, token_id)) {
             has_eos = true;
         }
         ++n_predict;
@@ -631,12 +636,6 @@ int main(int argc, char ** argv) {
 
     llama_batch_free(batch_dft);
 
-    llama_free(ctx_tgt);
-    llama_free_model(model_tgt);
-
-    llama_free(ctx_dft);
-    llama_free_model(model_dft);
-
     llama_backend_free();
 
     LOG("\n\n");