@fugood/llama.node 0.3.3 → 0.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +5 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +29 -1
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +15 -5
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +17 -1
- package/src/LlamaContext.cpp +86 -18
- package/src/LlamaContext.h +2 -0
- package/src/llama.cpp/.github/workflows/build.yml +197 -159
- package/src/llama.cpp/.github/workflows/docker.yml +5 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +11 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -2
- package/src/llama.cpp/common/arg.cpp +426 -245
- package/src/llama.cpp/common/common.cpp +143 -80
- package/src/llama.cpp/common/common.h +81 -24
- package/src/llama.cpp/common/sampling.cpp +53 -19
- package/src/llama.cpp/common/sampling.h +22 -1
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +101 -148
- package/src/llama.cpp/examples/CMakeLists.txt +32 -13
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +5 -4
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +262 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/llava.cpp +46 -19
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +9 -5
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
- package/src/llama.cpp/examples/server/server.cpp +1758 -886
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +94 -304
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +4 -0
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
- package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +106 -24
- package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
- package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
- package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
- package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
- package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
- package/src/llama.cpp/ggml/src/ggml.c +367 -207
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +26 -19
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/CMakeLists.txt +2 -7
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +35 -90
- package/src/llama.cpp/src/llama-vocab.cpp +6 -1
- package/src/llama.cpp/src/llama.cpp +1748 -640
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -37
- package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
- package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
- package/src/llama.cpp/tests/test-rope.cpp +61 -20
- package/src/llama.cpp/tests/test-sampling.cpp +2 -2
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/examples/server/utils.hpp

@@ -20,11 +20,11 @@
 #include <sstream>
 #include <string>
 #include <vector>
+#include <memory>

-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"

 using json = nlohmann::ordered_json;
-using llama_tokens = std::vector<llama_token>;

 #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
 #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
@@ -41,17 +41,6 @@ using llama_tokens = std::vector<llama_token>;
 #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)

-// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
-enum error_type {
-    ERROR_TYPE_INVALID_REQUEST,
-    ERROR_TYPE_AUTHENTICATION,
-    ERROR_TYPE_SERVER,
-    ERROR_TYPE_NOT_FOUND,
-    ERROR_TYPE_PERMISSION,
-    ERROR_TYPE_UNAVAILABLE, // custom error
-    ERROR_TYPE_NOT_SUPPORTED, // custom error
-};
-
 template <typename T>
 static T json_value(const json & body, const std::string & key, const T & default_value) {
     // Fallback null to default value
@@ -149,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
  * and multiple prompts (multi-tasks):
  * - "prompt": ["string1", "string2"]
  * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
  * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  */
 static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
@@ -175,9 +165,42 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     } else {
         throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
     }
+    if (result.empty()) {
+        throw std::runtime_error("\"prompt\" must not be empty");
+    }
     return result;
 }

+// return the last index of character that can form a valid string
+// if the last character is potentially cut in half, return the index before the cut
+// if validate_utf8(text) == text.size(), then the whole text is valid utf8
+static size_t validate_utf8(const std::string& text) {
+    size_t len = text.size();
+    if (len == 0) return 0;
+
+    // Check the last few bytes to see if a multi-byte character is cut off
+    for (size_t i = 1; i <= 4 && i <= len; ++i) {
+        unsigned char c = text[len - i];
+        // Check for start of a multi-byte sequence from the end
+        if ((c & 0xE0) == 0xC0) {
+            // 2-byte character start: 110xxxxx
+            // Needs at least 2 bytes
+            if (i < 2) return len - i;
+        } else if ((c & 0xF0) == 0xE0) {
+            // 3-byte character start: 1110xxxx
+            // Needs at least 3 bytes
+            if (i < 3) return len - i;
+        } else if ((c & 0xF8) == 0xF0) {
+            // 4-byte character start: 11110xxx
+            // Needs at least 4 bytes
+            if (i < 4) return len - i;
+        }
+    }
+
+    // If no cut-off multi-byte character is found, return full length
+    return len;
+}
+
 //
 // template utils
 //
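
The validate_utf8() helper added above keeps streamed output from ending in the middle of a multi-byte character. A standalone sketch of the idea (the caller below is hypothetical, not part of this diff; the helper body is condensed from the hunk above so the snippet compiles on its own):

// Illustrative sketch only: shows why a streaming server would call validate_utf8()
// before flushing a chunk; the surrounding main() is invented for demonstration.
#include <cstdio>
#include <string>

static size_t validate_utf8(const std::string & text) {
    const size_t len = text.size();
    if (len == 0) return 0;
    // scan the last few bytes for a multi-byte lead whose continuation bytes are missing
    for (size_t i = 1; i <= 4 && i <= len; ++i) {
        const unsigned char c = text[len - i];
        if      ((c & 0xE0) == 0xC0) { if (i < 2) return len - i; } // 110xxxxx: needs 2 bytes
        else if ((c & 0xF0) == 0xE0) { if (i < 3) return len - i; } // 1110xxxx: needs 3 bytes
        else if ((c & 0xF8) == 0xF0) { if (i < 4) return len - i; } // 11110xxx: needs 4 bytes
    }
    return len;
}

int main() {
    std::string pending = "caf\xC3";                // "café" cut in the middle of 'é' (0xC3 0xA9)
    const size_t n_valid = validate_utf8(pending);  // 3: only "caf" is safe to send now
    std::printf("flush %zu of %zu bytes\n", n_valid, pending.size());
    pending.erase(0, n_valid);                      // keep the dangling lead byte for the next chunk
    return 0;
}
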
@@ -338,12 +361,12 @@ static std::string llama_get_chat_template(const struct llama_model * model) {
     std::string template_key = "tokenizer.chat_template";
     // call with NULL buffer to get the total size of the string
     int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res <
+    if (res < 2) {
         return "";
     } else {
-        std::vector<char> model_template(res, 0);
+        std::vector<char> model_template(res + 1, 0);
         llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size());
+        return std::string(model_template.data(), model_template.size() - 1);
     }
 }

@@ -439,62 +462,6 @@ static std::string gen_chatcmplid() {
 // other common utils
 //

-static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
-    size_t i;
-    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
-
-    return i;
-}
-
-static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
-    // check for empty sequences
-    if (a.empty() || b.empty()) {
-        return 0;
-    }
-
-    // get the lengths of the input sequences
-    size_t a_len = a.size();
-    size_t b_len = b.size();
-
-    // initialize the maximum length of the longest common subsequence (LCS)
-    size_t max_length = 0;
-
-    // use two rows instead of a 2D matrix to optimize space
-    std::vector<size_t> prev_row(b_len + 1, 0);
-    std::vector<size_t> curr_row(b_len + 1, 0);
-
-    // iterate through the elements of a
-    for (size_t i = 1; i <= a_len; i++) {
-        // iterate through the elements of b
-        for (size_t j = 1; j <= b_len; j++) {
-            // if elements at the current positions match
-            if (a[i - 1] == b[j - 1]) {
-                // if it's the first element of either sequences, set LCS length to 1
-                if (i == 1 || j == 1) {
-                    curr_row[j] = 1;
-                } else {
-                    // increment LCS length by 1 compared to the previous element
-                    curr_row[j] = prev_row[j - 1] + 1;
-                }
-
-                // update max_length if necessary
-                if (curr_row[j] > max_length) {
-                    max_length = curr_row[j];
-                }
-            } else {
-                // reset LCS length if elements don't match
-                curr_row[j] = 0;
-            }
-        }
-
-        // update the previous row for the next iteration
-        prev_row = curr_row;
-    }
-
-    // return the maximum length of the LCS
-    return max_length;
-}
-
 static bool ends_with(const std::string & str, const std::string & suffix) {
     return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
@@ -542,48 +509,11 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
     return out;
 }

-struct completion_token_output {
-    llama_token tok;
-    std::string text_to_send;
-
-    struct token_prob {
-        llama_token tok;
-        float prob;
-    };
-
-    std::vector<token_prob> probs;
-};
-
-// convert a vector of completion_token_output to json
-static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
-    json out = json::array();
-
-    for (const auto & prob : probs) {
-        json probs_for_token = json::array();
-
-        for (const auto & p : prob.probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
-            probs_for_token.push_back(json {
-                {"tok_str", tok_str},
-                {"prob", p.prob},
-            });
-        }
-
-        const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
-        out.push_back(json {
-            {"content", tok_str},
-            {"probs", probs_for_token},
-        });
-    }
-
-    return out;
-}
-
 static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
     const std::string str =
         std::string(event) + ": " +
         data.dump(-1, ' ', false, json::error_handler_t::replace) +
-        "\n\n"; //
+        "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).

     LOG_DBG("data stream, to_send: %s", str.c_str());

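
For illustration (the event name and payload below are invented), the framing server_sent_event() writes to the sink is the event name, a colon and space, the JSON rendered with invalid UTF-8 replaced, and a blank line that terminates the SSE message:

// Standalone sketch of the same concatenation: "<event>: <json>\n\n"
#include <nlohmann/json.hpp>
#include <cstdio>
#include <string>

using json = nlohmann::ordered_json;

int main() {
    const json data = { {"content", "Hi"} };
    const std::string msg = std::string("data") + ": " +
        data.dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n";
    std::printf("%s", msg.c_str());  // prints: data: {"content":"Hi"} followed by a blank line
    return 0;
}
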
@@ -600,8 +530,6 @@ static json oaicompat_completion_params_parse(
     const std::string & chat_template) {
     json llama_params;

-    llama_params["__oaicompat"] = true;
-
     // Apply chat template to the list of messages
     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));

@@ -661,157 +589,9 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }

-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
-    bool stopped_word = result.count("stopped_word") != 0;
-    bool stopped_eos = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
-    std::string content = json_value(result, "content", std::string(""));
-
-    std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"message", json{{"content", content},
-                                                       {"role", "assistant"}}}}});
-
-    std::time_t t = std::time(0);
-
-    json res = json {
-        {"choices", choices},
-        {"created", t},
-        {"model",
-            json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-        {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
-        {"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens", num_prompt_tokens},
-            {"total_tokens", num_tokens_predicted + num_prompt_tokens}
-        }},
-        {"id", completion_id}
-    };
-
-    // extra fields for debugging purposes
-    if (verbose) {
-        res["__verbose"] = result;
-    }
-
-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
-
-    return res;
-}
-
-// return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({result});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
-    bool stopped_word = json_value(result, "stopped_word", false);
-    bool stopped_eos = json_value(result, "stopped_eos", false);
-    bool stopped_limit = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content", std::string(""));
-
-    std::string finish_reason;
-    if (stopped_word || stopped_eos) {
-        finish_reason = "stop";
-    }
-    if (stopped_limit) {
-        finish_reason = "length";
-    }
-
-    std::time_t t = std::time(0);
-
-    json choices;
-
-    if (!finish_reason.empty()) {
-        choices = json::array({json{{"finish_reason", finish_reason},
-                                    {"index", 0},
-                                    {"delta", json::object()}}});
-    } else {
-        if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"}
-                                        }}}})},
-                            {"created", t},
-                            {"id", completion_id},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                          {"index", 0},
-                                                          {"delta", json{
-                                                              {"content", content}}}
-                                                          }})},
-                            {"created", t},
-                            {"id", completion_id},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
-        } else {
-            // Some idiosyncrasy in task processing logic makes several trailing calls
-            // with empty content, we ignore these at the calee site.
-            if (content.empty()) {
-                return std::vector<json>({json::object()});
-            }
-
-            choices = json::array({json{
-                {"finish_reason", nullptr},
-                {"index", 0},
-                {"delta",
-                    json{
-                        {"content", content},
-                    }},
-            }});
-        }
-    }
-
-    json ret = json {
-        {"choices", choices},
-        {"created", t},
-        {"id", completion_id},
-        {"model", modelname},
-        {"object", "chat.completion.chunk"}
-    };
-    if (!finish_reason.empty()) {
-        int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-        int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
-        ret.push_back({"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens", num_prompt_tokens},
-            {"total_tokens", num_tokens_predicted + num_prompt_tokens}
-        }});
-    }
-
-    return std::vector<json>({ret});
-}
-
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
         data.push_back(json{
@@ -819,14 +599,16 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
             {"index", i++},
             {"object", "embedding"}
         });
+
+        n_tokens += json_value(elem, "tokens_evaluated", 0);
     }

     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json {
-            {"prompt_tokens",
-            {"total_tokens",
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"data", data}
     };
@@ -836,20 +618,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso

 static json format_response_rerank(const json & request, const json & ranks) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & rank : ranks) {
         data.push_back(json{
             {"index", i++},
             {"relevance_score", json_value(rank, "score", 0.0)},
         });
+
+        n_tokens += json_value(rank, "tokens_evaluated", 0);
     }

     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json {
-            {"prompt_tokens",
-            {"total_tokens",
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
        }},
         {"results", data}
     };
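
A small standalone illustration of the accumulation added to both response builders above (the scores and token counts are invented): the per-item tokens_evaluated values are summed and reported in the usage block instead of being left unfilled.

#include <nlohmann/json.hpp>
#include <cstdint>
#include <cstdio>

using json = nlohmann::ordered_json;

int main() {
    // two fake rerank results, each reporting how many tokens its evaluation consumed
    json ranks = json::array();
    ranks.push_back({ {"score", 0.87}, {"tokens_evaluated", 9} });
    ranks.push_back({ {"score", 0.12}, {"tokens_evaluated", 14} });

    int32_t n_tokens = 0;
    for (const auto & rank : ranks) {
        n_tokens += rank.value("tokens_evaluated", 0);
    }

    const json usage = { {"prompt_tokens", n_tokens}, {"total_tokens", n_tokens} };
    std::printf("%s\n", usage.dump().c_str());  // {"prompt_tokens":23,"total_tokens":23}
    return 0;
}
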
@@ -902,42 +687,47 @@ static json format_detokenized_response(const std::string & content) {
     };
 }

-static json
-
-
-
-
-
-
-            break;
-        case ERROR_TYPE_AUTHENTICATION:
-            type_str = "authentication_error";
-            code = 401;
-            break;
-        case ERROR_TYPE_NOT_FOUND:
-            type_str = "not_found_error";
-            code = 404;
-            break;
-        case ERROR_TYPE_SERVER:
-            type_str = "server_error";
-            code = 500;
-            break;
-        case ERROR_TYPE_PERMISSION:
-            type_str = "permission_error";
-            code = 403;
-            break;
-        case ERROR_TYPE_NOT_SUPPORTED:
-            type_str = "not_supported_error";
-            code = 501;
-            break;
-        case ERROR_TYPE_UNAVAILABLE:
-            type_str = "unavailable_error";
-            code = 503;
-            break;
+static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
+    json data = json::array();
+    for (const auto & lb : logit_bias) {
+        data.push_back(json{
+            {"bias", lb.bias},
+            {"token", lb.token},
+        });
     }
-    return
-
-
-
-
+    return data;
+}
+
+static std::string safe_json_to_str(json data) {
+    return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
+
+static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
+    std::vector<llama_token_data> cur;
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    // sort tokens by logits
+    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    });
+
+    // apply softmax
+    float max_l = cur[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < cur.size(); ++i) {
+        float p = expf(cur[i].logit - max_l);
+        cur[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < cur.size(); ++i) {
+        cur[i].p /= cum_sum;
+    }
+
+    return cur;
 }
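
The get_token_probabilities() helper added above converts raw logits to probabilities with the usual max-subtraction softmax. A standalone sketch of that step on made-up logit values (nothing below comes from the diff itself):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
    std::vector<float> logits = {2.0f, 1.0f, 0.1f};           // invented values
    std::sort(logits.begin(), logits.end(), std::greater<float>());

    const float max_l = logits[0];                            // subtract the max for numerical stability
    float cum_sum = 0.0f;
    std::vector<float> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_l);
        cum_sum += probs[i];
    }
    for (float & p : probs) {
        p /= cum_sum;                                         // normalize so the probabilities sum to 1
    }
    for (size_t i = 0; i < probs.size(); ++i) {
        std::printf("logit %.2f -> p %.3f\n", logits[i], probs[i]);
    }
    return 0;
}
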
package/src/llama.cpp/examples/simple/CMakeLists.txt

@@ -2,4 +2,4 @@ set(TARGET llama-simple)
 add_executable(${TARGET} simple.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE
+target_compile_features(${TARGET} PRIVATE cxx_std_17)

package/src/llama.cpp/examples/simple-chat/CMakeLists.txt

@@ -2,4 +2,4 @@ set(TARGET llama-simple-chat)
 add_executable(${TARGET} simple-chat.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE
+target_compile_features(${TARGET} PRIVATE cxx_std_17)

package/src/llama.cpp/examples/speculative/CMakeLists.txt

@@ -2,4 +2,4 @@ set(TARGET llama-speculative)
 add_executable(${TARGET} speculative.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
package/src/llama.cpp/examples/speculative/speculative.cpp

@@ -12,7 +12,7 @@
 #include <string>
 #include <vector>

-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

 struct seq_draft {
@@ -33,7 +33,7 @@ int main(int argc, char ** argv) {
     common_params params;

     // needed to get candidate probs even for temp <= 0.0
-    params.
+    params.sampling.n_probs = 128;

     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
@@ -46,7 +46,7 @@ int main(int argc, char ** argv) {

     common_init();

-    if (params.
+    if (params.speculative.model.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
@@ -55,9 +55,9 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;

     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
-    const float
+    const float p_draft_split = params.speculative.p_split;

-    std::default_random_engine rng(params.
+    std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed);
     std::uniform_real_distribution<> u_dist;

     // init llama.cpp
@@ -76,13 +76,14 @@
     ctx_tgt = llama_init_tgt.context;

     // load the draft model
-    params.
-    params.
-
-
+    params.devices = params.speculative.devices;
+    params.model = params.speculative.model;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
+    if (params.speculative.cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
     }

-    params.cpuparams_batch.n_threads = params.
+    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
     common_init_result llama_init_dft = common_init_from_params(params);
     model_dft = llama_init_dft.model;
     ctx_dft = llama_init_dft.context;
@@ -170,7 +171,7 @@ int main(int argc, char ** argv) {
     //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));

     // how many tokens to draft each time
-    int n_draft = params.
+    int n_draft = params.speculative.n_max;

     int n_predict = 0;
     int n_drafted = 0;
@@ -183,14 +184,14 @@
     bool has_eos = false;

     // target model sampling context (reuse the llama_context's sampling instance)
-    struct common_sampler * smpl = common_sampler_init(model_tgt, params.
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);

     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);

     for (int s = 0; s < n_seq_dft; ++s) {
         // allocate llama_sampler for each draft sequence
-        drafts[s].smpl = common_sampler_init(model_dft, params.
+        drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
     }

     llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
@@ -230,7 +231,7 @@ int main(int argc, char ** argv) {
     // for stochastic sampling, attempt to match the token with the drafted tokens
     {
         bool accept = false;
-        if (params.
+        if (params.sampling.temp > 0) {
             // stochastic verification
             common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);

@@ -494,7 +495,7 @@ int main(int argc, char ** argv) {

     // attempt to split the branch if the probability is high enough
     for (int f = 1; f < 8; ++f) {
-        if (n_seq_cur < n_seq_dft && cur_p->data[f].p >
+        if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
             LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);

             llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
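
The speculative.cpp changes above read draft-model settings from a params.speculative group and sampling settings from params.sampling instead of the older flat fields. A rough sketch of the grouping, inferred only from the member accesses visible in this diff (the real definitions live in common/common.h and differ in types, defaults, and extra members):

// Sketch only: field names taken from the accesses in the hunks above
// (params.speculative.model, .devices, .n_gpu_layers, .n_max, .p_split,
// .cpuparams, .cpuparams_batch); everything else here is a placeholder.
#include <string>
#include <vector>

struct cpu_params_sketch {
    int n_threads = -1;                  // stand-in for the real cpu_params
};

struct common_params_speculative_sketch {
    std::vector<void *> devices;         // stand-in for the real device handle type
    std::string         model;           // draft model path (--model-draft)
    int                 n_gpu_layers = 0;
    int                 n_max        = 0;    // how many tokens to draft each time
    float               p_split      = 0.0f; // branch-split probability threshold
    cpu_params_sketch   cpuparams;
    cpu_params_sketch   cpuparams_batch;
};
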