@fugood/llama.node 0.3.3 → 0.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +5 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +29 -1
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +15 -5
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +17 -1
- package/src/LlamaContext.cpp +86 -18
- package/src/LlamaContext.h +2 -0
- package/src/llama.cpp/.github/workflows/build.yml +197 -159
- package/src/llama.cpp/.github/workflows/docker.yml +5 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +11 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -2
- package/src/llama.cpp/common/arg.cpp +426 -245
- package/src/llama.cpp/common/common.cpp +143 -80
- package/src/llama.cpp/common/common.h +81 -24
- package/src/llama.cpp/common/sampling.cpp +53 -19
- package/src/llama.cpp/common/sampling.h +22 -1
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +101 -148
- package/src/llama.cpp/examples/CMakeLists.txt +32 -13
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +5 -4
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +262 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/llava.cpp +46 -19
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +9 -5
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
- package/src/llama.cpp/examples/server/server.cpp +1758 -886
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +94 -304
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +4 -0
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
- package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +106 -24
- package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
- package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
- package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
- package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
- package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
- package/src/llama.cpp/ggml/src/ggml.c +367 -207
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +26 -19
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/CMakeLists.txt +2 -7
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +35 -90
- package/src/llama.cpp/src/llama-vocab.cpp +6 -1
- package/src/llama.cpp/src/llama.cpp +1748 -640
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -37
- package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
- package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
- package/src/llama.cpp/tests/test-rope.cpp +61 -20
- package/src/llama.cpp/tests/test-sampling.cpp +2 -2
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/common/sampling.cpp (+53 -19):

@@ -99,7 +99,7 @@ struct ring_buffer {
 };
 
 struct common_sampler {
-    common_sampler_params params;
+    common_params_sampling params;
 
     struct llama_sampler * grmr;
    struct llama_sampler * chain;
@@ -125,7 +125,7 @@ struct common_sampler {
     }
 };
 
-std::string common_sampler_params::print() const {
+std::string common_params_sampling::print() const {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
     lparams.no_perf = params.no_perf;
@@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             params.logit_bias.size(),
             params.logit_bias.data()));
 
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_penalties(
-                llama_n_vocab  (model),
-                llama_token_eos(model),
-                llama_token_nl (model),
-                params.penalty_last_n,
-                params.penalty_repeat,
-                params.penalty_freq,
-                params.penalty_present,
-                params.penalize_nl,
-                params.ignore_eos));
-
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_DRY:
+                case COMMON_SAMPLER_TYPE_DRY:
                     {
-                        std::vector<const char*> c_breakers;
+                        std::vector<const char *> c_breakers;
                         c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto& str : params.dry_sequence_breakers) {
+                        for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }
 
                         llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
-                    break;
+                    break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
                     break;
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_INFILL:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
                     break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
@@ -320,6 +311,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return cur_p.data[cur_p.selected].id;
 }
 
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
+
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+}
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
     return llama_sampler_get_seed(gsmpl->chain);
 }
@@ -376,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
         case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
         default : return '?';
     }
 }
@@ -390,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
         case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
         default : return "";
     }
 }
@@ -404,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
         { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
     };
 
     // since samplers names are written multiple ways
@@ -450,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
     };
 
     std::vector<common_sampler_type> samplers;
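A note on the sampling.cpp hunks above: repetition penalties are no longer pushed into every sampler chain unconditionally. They become an ordinary, user-selectable stage (`COMMON_SAMPLER_TYPE_PENALTIES`, short form `e`), and `llama_sampler_init_penalties` is reduced to the four penalty parameters. Below is a minimal sketch of how a caller of the bundled llama.cpp `common` API might opt into penalties under the new scheme; the helper name and the chosen values are illustrative, not part of this package.

```cpp
// Minimal sketch, not part of the package: assumes `model` was loaded elsewhere
// and that common.h/sampling.h from the bundled llama.cpp are on the include path.
#include "common.h"
#include "sampling.h"

static struct common_sampler * make_sampler_with_penalties(struct llama_model * model) {
    common_params_sampling params; // renamed from common_sampler_params in this release

    params.penalty_last_n  = 64;   // how many recent tokens the penalties consider
    params.penalty_repeat  = 1.1f; // > 1.0f discourages repetition
    params.penalty_freq    = 0.0f;
    params.penalty_present = 0.0f;

    // the penalties stage now runs only if it is listed in the chain
    params.samplers = {
        COMMON_SAMPLER_TYPE_PENALTIES, // -> llama_sampler_init_penalties(last_n, repeat, freq, present)
        COMMON_SAMPLER_TYPE_TOP_K,
        COMMON_SAMPLER_TYPE_TEMPERATURE,
    };

    return common_sampler_init(model, params);
}
```

Because the chain is built from `params.samplers`, the penalties stage is skipped entirely unless it is listed there.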
package/src/llama.cpp/common/sampling.h (+22 -1):

@@ -36,7 +36,7 @@ struct common_sampler;
 
 // llama_sampler API overloads
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
 
 void common_sampler_free(struct common_sampler * gsmpl);
 
@@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 //
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+// generalized version of common_sampler_sample
+//
+// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
+// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
+//
+//      common_sampler_sample_n(gsmpl, ctx, { idx }, {});
+//
+// is equivalent to
+//
+//      common_sampler_sample(gsmpl, ctx, idx);
+//      common_sampler_accept(gsmpl, token, true);
+//
+// requires: idxs.size() == draft.size() + 1
+//
+// returns at least 1 token, up to idxs.size()
+//
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+
+// assume idxs == [ 0, 1, 2, ..., draft.size() ]
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
 // helpers
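The header comments above spell out the contract of the new `common_sampler_sample_and_accept_n` overloads, which back the speculative decoding support elsewhere in this release. A short sketch of the intended call pattern, assuming the most recent `llama_decode` on `ctx` requested logits for `draft.size() + 1` batch positions; the wrapper function is illustrative only.

```cpp
#include "sampling.h"

// Illustrative helper: accept as much of a draft as the target sampler agrees with.
// Assumes the last llama_decode() on `ctx` requested logits at batch indices
// 0 .. draft.size(), i.e. one position per draft token plus one trailing position.
static size_t accept_draft(struct common_sampler * gsmpl,
                           struct llama_context  * ctx,
                           const llama_tokens    & draft) {
    // convenience overload: idxs is implicitly { 0, 1, ..., draft.size() }
    const std::vector<llama_token> accepted =
            common_sampler_sample_and_accept_n(gsmpl, ctx, draft);

    // at least one token comes back; all but the last matched the draft,
    // and the last one is the first token sampled where the draft ended or diverged
    return accepted.size() - 1; // number of draft tokens that were accepted
}
```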
package/src/llama.cpp/common/speculative.cpp (+274 -0, new file):

@@ -0,0 +1,274 @@
+#include "speculative.h"
+
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+
+#include <cstring>
+
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
+struct common_speculative {
+    struct llama_context * ctx;
+    struct common_sampler * smpl;
+
+    llama_batch batch;
+    llama_tokens prompt;
+};
+
+struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_dft) {
+    auto * result = new common_speculative {
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt = */ {},
+    };
+
+    // TODO: optimize or pass from outside?
+#if 0
+    {
+        common_params_sampling params;
+        params.no_perf = false;
+
+        params.top_k = 40;
+        params.top_p = 0.9;
+
+        params.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+            COMMON_SAMPLER_TYPE_TOP_P,
+            COMMON_SAMPLER_TYPE_INFILL,
+        };
+
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#else
+    {
+        common_params_sampling params;
+        params.no_perf = false;
+
+        params.top_k = 10;
+
+        params.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+        };
+
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#endif
+
+    return result;
+}
+
+void common_speculative_free(struct common_speculative * spec) {
+    if (spec == nullptr) {
+        return;
+    }
+
+    common_sampler_free(spec->smpl);
+
+    llama_batch_free(spec->batch);
+
+    delete spec;
+}
+
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt));
+        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft));
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+llama_tokens common_speculative_gen_draft(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt_tgt,
+        llama_token id_last) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
+
+    int reuse_i = 0;
+    int reuse_n = 0;
+
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
+
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+    for (int i = 0; i < (int) prompt.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+            cur++;
+        }
+
+        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+    llama_tokens result;
+    result.reserve(params.n_draft);
+
+    if (reuse_n == 0) {
+        llama_kv_cache_clear(ctx);
+
+        prompt.clear();
+    } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
+        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+                result.push_back(prompt[i]);
+
+                if (params.n_draft <= (int) result.size()) {
+                    break;
+                }
+            }
+
+            return result;
+        }
+
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
+    }
+
+    // prepare a batch to evaluate any new tokens in the prompt
+    common_batch_clear(batch);
+
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+
+        prompt.push_back(prompt_tgt[i]);
+    }
+
+    // we should rarely end-up here during normal decoding
+    if (batch.n_tokens > 0) {
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+        llama_decode(ctx, batch);
+    }
+
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    common_batch_clear(batch);
+    common_batch_add (batch, id_last, n_past, { 0 }, true);
+
+    prompt.push_back(id_last);
+
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+    llama_decode(ctx, batch);
+
+    common_sampler_reset(smpl);
+
+    // sample n_draft tokens from the draft model
+    for (int i = 0; i < params.n_draft; ++i) {
+        common_batch_clear(batch);
+
+        common_sampler_sample(smpl, ctx, 0, true);
+
+        const auto * cur_p = common_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < params.p_min) {
+            break;
+        }
+
+        common_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+
+        if (params.n_draft <= (int) result.size()) {
+            break;
+        }
+
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+        // evaluate the drafted tokens on the draft model
+        llama_decode(ctx, batch);
+
+        prompt.push_back(id);
+    }
+
+    return result;
+}
package/src/llama.cpp/common/speculative.h (+28 -0, new file):

@@ -0,0 +1,28 @@
+#pragma once
+
+#include "llama.h"
+#include "common.h"
+
+struct common_speculative;
+
+struct common_speculative_params {
+    int n_draft = 16;  // max drafted tokens
+    int n_reuse = 256;
+
+    float p_min = 0.9f; // min probabiliy required to accept a token in the draft
+};
+
+struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+
+void common_speculative_free(struct common_speculative * spec);
+
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft);
+
+// sample up to n_draft tokens and add them to the batch using the draft model
+llama_tokens common_speculative_gen_draft(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last);