whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,14 +1,16 @@
|
|
1
1
|
#include "llama-context.h"
|
2
2
|
|
3
3
|
#include "llama-impl.h"
|
4
|
+
#include "llama-batch.h"
|
4
5
|
#include "llama-io.h"
|
6
|
+
#include "llama-memory.h"
|
5
7
|
#include "llama-mmap.h"
|
6
8
|
#include "llama-model.h"
|
7
|
-
#include "llama-kv-cache.h"
|
8
9
|
|
10
|
+
#include <cinttypes>
|
9
11
|
#include <cstring>
|
12
|
+
#include <limits>
|
10
13
|
#include <stdexcept>
|
11
|
-
#include <cinttypes>
|
12
14
|
|
13
15
|
//
|
14
16
|
// llama_context
|
@@ -17,7 +19,8 @@
|
|
17
19
|
llama_context::llama_context(
|
18
20
|
const llama_model & model,
|
19
21
|
llama_context_params params) :
|
20
|
-
model(model)
|
22
|
+
model(model),
|
23
|
+
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
21
24
|
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
22
25
|
|
23
26
|
t_start_us = model.t_start_us;
|
@@ -26,8 +29,8 @@ llama_context::llama_context(
|
|
26
29
|
const auto & hparams = model.hparams;
|
27
30
|
|
28
31
|
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
29
|
-
if (cparams.n_seq_max >
|
30
|
-
throw std::runtime_error("n_seq_max must be <= " + std::to_string(
|
32
|
+
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
|
33
|
+
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
31
34
|
}
|
32
35
|
|
33
36
|
cparams.n_threads = params.n_threads;
|
@@ -122,6 +125,11 @@ llama_context::llama_context(
|
|
122
125
|
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
123
126
|
}
|
124
127
|
|
128
|
+
if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
|
129
|
+
LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
|
130
|
+
__func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
|
131
|
+
}
|
132
|
+
|
125
133
|
if (!hparams.vocab_only) {
|
126
134
|
// GPU backends
|
127
135
|
for (auto * dev : model.devices) {
|
@@ -259,15 +267,9 @@ llama_context::llama_context(
|
|
259
267
|
|
260
268
|
// reserve worst-case graph
|
261
269
|
if (!hparams.vocab_only && memory) {
|
262
|
-
const uint32_t n_seqs =
|
270
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
263
271
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
264
272
|
|
265
|
-
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
266
|
-
|
267
|
-
// restore later
|
268
|
-
// TODO: something cleaner
|
269
|
-
const auto n_outputs_save = n_outputs;
|
270
|
-
|
271
273
|
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
272
274
|
|
273
275
|
int n_splits_pp = -1;
|
@@ -277,25 +279,18 @@ llama_context::llama_context(
|
|
277
279
|
int n_nodes_tg = -1;
|
278
280
|
|
279
281
|
// simulate full KV cache
|
280
|
-
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
281
282
|
|
282
|
-
|
283
|
+
const auto mctx = memory->init_full();
|
284
|
+
if (!mctx) {
|
285
|
+
throw std::runtime_error("failed to initialize KV cache");
|
286
|
+
}
|
283
287
|
|
284
288
|
cross.v_embd.clear();
|
285
289
|
|
286
290
|
// reserve pp graph first so that buffers are only allocated once
|
287
291
|
{
|
288
|
-
|
289
|
-
|
290
|
-
// max number of outputs
|
291
|
-
n_outputs = ubatch_pp.n_tokens;
|
292
|
-
|
293
|
-
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
294
|
-
|
295
|
-
auto * gf = graph_init();
|
296
|
-
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
297
|
-
|
298
|
-
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
292
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
293
|
+
if (!gf) {
|
299
294
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
300
295
|
}
|
301
296
|
|
@@ -305,16 +300,8 @@ llama_context::llama_context(
|
|
305
300
|
|
306
301
|
// reserve with tg graph to get the number of splits and nodes
|
307
302
|
{
|
308
|
-
|
309
|
-
|
310
|
-
n_outputs = ubatch_tg.n_tokens;
|
311
|
-
|
312
|
-
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
|
313
|
-
|
314
|
-
auto * gf = graph_init();
|
315
|
-
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
316
|
-
|
317
|
-
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
303
|
+
auto * gf = graph_reserve(1, 1, 1, mctx.get());
|
304
|
+
if (!gf) {
|
318
305
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
319
306
|
}
|
320
307
|
|
@@ -324,22 +311,12 @@ llama_context::llama_context(
|
|
324
311
|
|
325
312
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
326
313
|
{
|
327
|
-
|
328
|
-
|
329
|
-
n_outputs = ubatch_pp.n_tokens;
|
330
|
-
|
331
|
-
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
332
|
-
|
333
|
-
auto * gf = graph_init();
|
334
|
-
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
335
|
-
|
336
|
-
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
314
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
315
|
+
if (!gf) {
|
337
316
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
338
317
|
}
|
339
318
|
}
|
340
319
|
|
341
|
-
n_outputs = n_outputs_save;
|
342
|
-
|
343
320
|
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
344
321
|
ggml_backend_t backend = backend_ptrs[i];
|
345
322
|
ggml_backend_buffer_type_t buft = backend_buft[i];
|
@@ -443,46 +420,71 @@ uint32_t llama_context::n_threads_batch() const {
|
|
443
420
|
return cparams.n_threads_batch;
|
444
421
|
}
|
445
422
|
|
446
|
-
|
447
|
-
|
448
|
-
return kv_self;
|
423
|
+
llama_memory_t llama_context::get_memory() const {
|
424
|
+
return memory.get();
|
449
425
|
}
|
450
426
|
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
void llama_context::kv_self_update() {
|
457
|
-
bool need_reserve = false;
|
427
|
+
// deprecated
|
428
|
+
void llama_context::kv_self_defrag_sched() {
|
429
|
+
if (!memory) {
|
430
|
+
return;
|
431
|
+
}
|
458
432
|
|
459
|
-
|
433
|
+
memory_force_optimize = true;
|
434
|
+
}
|
460
435
|
|
461
|
-
|
436
|
+
// deprecated
|
437
|
+
bool llama_context::kv_self_update(bool optimize) {
|
438
|
+
if (!memory) {
|
439
|
+
return false;
|
440
|
+
}
|
462
441
|
|
463
|
-
|
464
|
-
|
465
|
-
|
442
|
+
{
|
443
|
+
// TODO: remove in the future
|
444
|
+
optimize |= memory_force_optimize;
|
445
|
+
memory_force_optimize = false;
|
466
446
|
|
467
|
-
|
468
|
-
|
469
|
-
|
447
|
+
const auto mctx = memory->init_update(this, optimize);
|
448
|
+
switch (mctx->get_status()) {
|
449
|
+
case LLAMA_MEMORY_STATUS_SUCCESS:
|
450
|
+
{
|
451
|
+
// noop
|
452
|
+
} break;
|
453
|
+
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
454
|
+
{
|
455
|
+
// no updates need to be performed
|
456
|
+
return false;
|
457
|
+
}
|
458
|
+
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
459
|
+
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
460
|
+
{
|
461
|
+
LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
|
462
|
+
return false;
|
463
|
+
}
|
464
|
+
}
|
470
465
|
|
471
|
-
|
472
|
-
|
466
|
+
if (!mctx->apply()) {
|
467
|
+
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
468
|
+
}
|
469
|
+
}
|
473
470
|
|
474
|
-
|
475
|
-
|
471
|
+
// if the memory module did any computation, we have to reserve a new worst-case graph
|
472
|
+
{
|
473
|
+
const auto mctx = memory->init_full();
|
474
|
+
if (!mctx) {
|
475
|
+
throw std::runtime_error("failed to initialize memory context");
|
476
|
+
}
|
476
477
|
|
477
|
-
|
478
|
-
|
478
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
479
|
+
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
479
480
|
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
481
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
482
|
+
if (!gf) {
|
483
|
+
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
484
484
|
}
|
485
485
|
}
|
486
|
+
|
487
|
+
return true;
|
486
488
|
}
|
487
489
|
|
488
490
|
enum llama_pooling_type llama_context::pooling_type() const {
|
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
|
|
494
496
|
}
|
495
497
|
|
496
498
|
float * llama_context::get_logits_ith(int32_t i) {
|
497
|
-
|
499
|
+
int64_t j = -1;
|
498
500
|
|
499
501
|
try {
|
500
502
|
if (logits == nullptr) {
|
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
517
519
|
}
|
518
520
|
if (j >= n_outputs) {
|
519
521
|
// This should not happen
|
520
|
-
throw std::runtime_error(format("corrupt output buffer (j=%
|
522
|
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
521
523
|
}
|
522
524
|
|
523
525
|
return logits + j*model.vocab.n_tokens();
|
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
|
|
536
538
|
}
|
537
539
|
|
538
540
|
float * llama_context::get_embeddings_ith(int32_t i) {
|
539
|
-
|
541
|
+
int64_t j = -1;
|
540
542
|
|
541
543
|
try {
|
542
544
|
if (embd == nullptr) {
|
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
|
559
561
|
}
|
560
562
|
if (j >= n_outputs) {
|
561
563
|
// This should not happen
|
562
|
-
throw std::runtime_error(format("corrupt output buffer (j=%
|
564
|
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
563
565
|
}
|
564
566
|
|
565
567
|
return embd + j*model.hparams.n_embd;
|
@@ -676,69 +678,95 @@ bool llama_context::apply_adapter_cvec(
|
|
676
678
|
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
677
679
|
}
|
678
680
|
|
679
|
-
|
680
|
-
if (
|
681
|
-
LLAMA_LOG_ERROR("%s:
|
682
|
-
|
681
|
+
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
682
|
+
if (mctx && !mctx->apply()) {
|
683
|
+
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
|
684
|
+
ret = GGML_STATUS_FAILED;
|
685
|
+
return nullptr;
|
683
686
|
}
|
684
687
|
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
+
auto * gf = graph_init();
|
689
|
+
if (!gf) {
|
690
|
+
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
|
691
|
+
ret = GGML_STATUS_FAILED;
|
692
|
+
return nullptr;
|
693
|
+
}
|
688
694
|
|
689
|
-
|
690
|
-
|
695
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
|
696
|
+
if (!res) {
|
697
|
+
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
698
|
+
ret = GGML_STATUS_FAILED;
|
699
|
+
return nullptr;
|
700
|
+
}
|
691
701
|
|
692
|
-
|
702
|
+
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
693
703
|
|
694
|
-
|
704
|
+
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
|
705
|
+
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
|
706
|
+
ret = GGML_STATUS_ALLOC_FAILED;
|
707
|
+
return nullptr;
|
708
|
+
}
|
695
709
|
|
696
|
-
|
697
|
-
if (batch.token) {
|
698
|
-
for (int32_t i = 0; i < n_tokens; ++i) {
|
699
|
-
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
700
|
-
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
701
|
-
return -1;
|
702
|
-
}
|
710
|
+
res->set_inputs(&ubatch);
|
703
711
|
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
712
|
+
const auto status = graph_compute(gf, ubatch.n_tokens > 1);
|
713
|
+
if (status != GGML_STATUS_SUCCESS) {
|
714
|
+
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
|
715
|
+
ret = status;
|
716
|
+
return nullptr;
|
709
717
|
}
|
710
718
|
|
719
|
+
ret = GGML_STATUS_SUCCESS;
|
720
|
+
|
721
|
+
return res;
|
722
|
+
}
|
723
|
+
|
724
|
+
int llama_context::encode(const llama_batch & batch_inp) {
|
725
|
+
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
726
|
+
|
727
|
+
if (batch_inp.n_tokens == 0) {
|
728
|
+
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
729
|
+
return -1;
|
730
|
+
}
|
731
|
+
|
732
|
+
const auto & hparams = model.hparams;
|
733
|
+
|
734
|
+
const int64_t n_embd = hparams.n_embd;
|
735
|
+
|
736
|
+
// note: during encode, we always pass the full sequence starting from pos = 0
|
737
|
+
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
|
738
|
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
739
|
+
return -1;
|
740
|
+
}
|
741
|
+
|
742
|
+
const uint32_t n_tokens = balloc->get_n_tokens();
|
743
|
+
|
744
|
+
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
|
745
|
+
|
711
746
|
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
712
|
-
GGML_ASSERT(cparams.n_ubatch >=
|
747
|
+
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
|
713
748
|
|
714
749
|
if (t_compute_start_us == 0) {
|
715
750
|
t_compute_start_us = ggml_time_us();
|
716
751
|
}
|
717
752
|
|
753
|
+
// TODO: this clear of the buffer can easily be forgotten - need something better
|
718
754
|
embd_seq.clear();
|
719
755
|
|
720
756
|
n_queued_tokens += n_tokens;
|
721
757
|
|
722
|
-
const int64_t n_embd = hparams.n_embd;
|
723
|
-
|
724
|
-
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
725
|
-
|
726
|
-
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
727
|
-
|
728
758
|
// reserve output buffer
|
729
759
|
if (output_reserve(n_tokens) < n_tokens) {
|
730
760
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
731
761
|
return -2;
|
732
762
|
};
|
733
763
|
|
734
|
-
for (
|
764
|
+
for (uint32_t i = 0; i < n_tokens; ++i) {
|
735
765
|
output_ids[i] = i;
|
736
766
|
}
|
737
767
|
|
738
768
|
n_outputs = n_tokens;
|
739
769
|
|
740
|
-
//batch_manager->prepare(ubatch);
|
741
|
-
|
742
770
|
ggml_backend_sched_reset(sched.get());
|
743
771
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
744
772
|
|
@@ -749,26 +777,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
749
777
|
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
750
778
|
cparams.causal_attn = false;
|
751
779
|
|
752
|
-
|
753
|
-
auto res =
|
754
|
-
|
755
|
-
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
756
|
-
|
757
|
-
res->set_inputs(&ubatch);
|
780
|
+
ggml_status status;
|
781
|
+
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
|
758
782
|
|
759
783
|
cparams.causal_attn = causal_attn_org;
|
760
784
|
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
return -2;
|
769
|
-
case GGML_STATUS_FAILED:
|
770
|
-
default:
|
771
|
-
return -3;
|
785
|
+
if (!res) {
|
786
|
+
switch (status) {
|
787
|
+
case GGML_STATUS_ABORTED: return 2;
|
788
|
+
case GGML_STATUS_ALLOC_FAILED: return -2;
|
789
|
+
case GGML_STATUS_FAILED: return -3;
|
790
|
+
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
|
791
|
+
}
|
772
792
|
}
|
773
793
|
|
774
794
|
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
@@ -793,31 +813,28 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
793
813
|
{
|
794
814
|
// extract sequence embeddings
|
795
815
|
auto & embd_seq_out = embd_seq;
|
796
|
-
embd_seq_out.clear();
|
797
816
|
|
798
|
-
|
817
|
+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
818
|
+
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
819
|
+
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
799
820
|
|
800
|
-
for (int32_t i = 0; i < n_tokens; i++) {
|
801
|
-
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
802
|
-
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
803
|
-
continue;
|
804
|
-
}
|
805
821
|
embd_seq_out[seq_id].resize(n_embd);
|
806
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
|
822
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
|
807
823
|
}
|
808
824
|
} break;
|
809
825
|
case LLAMA_POOLING_TYPE_RANK:
|
810
826
|
{
|
811
|
-
// extract the rerank score -
|
827
|
+
// extract the rerank score - n_cls_out floats per sequence
|
812
828
|
auto & embd_seq_out = embd_seq;
|
813
829
|
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
830
|
+
const uint32_t n_cls_out = hparams.n_cls_out;
|
831
|
+
|
832
|
+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
833
|
+
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
834
|
+
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
835
|
+
|
836
|
+
embd_seq_out[seq_id].resize(n_cls_out);
|
837
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
|
821
838
|
}
|
822
839
|
} break;
|
823
840
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
@@ -842,12 +859,16 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
842
859
|
cross.v_embd.resize(cross.n_embd*cross.n_enc);
|
843
860
|
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
|
844
861
|
|
862
|
+
const auto & batch = balloc->get_batch();
|
863
|
+
|
845
864
|
// remember the sequence ids used during the encoding - needed for cross attention later
|
846
865
|
cross.seq_ids_enc.resize(n_tokens);
|
847
|
-
for (
|
866
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
848
867
|
cross.seq_ids_enc[i].clear();
|
849
|
-
|
850
|
-
|
868
|
+
|
869
|
+
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
870
|
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
871
|
+
|
851
872
|
cross.seq_ids_enc[i].insert(seq_id);
|
852
873
|
}
|
853
874
|
}
|
@@ -856,55 +877,42 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
856
877
|
return 0;
|
857
878
|
}
|
858
879
|
|
859
|
-
int llama_context::decode(llama_batch &
|
880
|
+
int llama_context::decode(const llama_batch & batch_inp) {
|
881
|
+
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
882
|
+
|
860
883
|
if (!memory) {
|
861
884
|
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
862
|
-
return encode(
|
885
|
+
return encode(batch_inp);
|
863
886
|
}
|
864
887
|
|
865
|
-
if (
|
888
|
+
if (batch_inp.n_tokens == 0) {
|
866
889
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
867
890
|
return -1;
|
868
891
|
}
|
869
892
|
|
870
|
-
if (!inp_batch.pos) {
|
871
|
-
if (inp_batch.seq_id) {
|
872
|
-
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
|
873
|
-
return -1;
|
874
|
-
}
|
875
|
-
}
|
876
|
-
|
877
|
-
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
878
|
-
|
879
|
-
// temporary allocate memory for the input batch if needed
|
880
|
-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
|
881
|
-
|
882
|
-
const llama_batch & batch = batch_allocr.batch;
|
883
|
-
|
884
893
|
const auto & vocab = model.vocab;
|
885
894
|
const auto & hparams = model.hparams;
|
886
895
|
|
887
896
|
const int32_t n_vocab = vocab.n_tokens();
|
897
|
+
const int64_t n_embd = hparams.n_embd;
|
888
898
|
|
889
|
-
|
890
|
-
const
|
899
|
+
// when computing embeddings, all tokens are output
|
900
|
+
const bool output_all = cparams.embeddings;
|
891
901
|
|
892
|
-
|
902
|
+
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
|
903
|
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
904
|
+
return -1;
|
905
|
+
}
|
893
906
|
|
894
|
-
|
907
|
+
const uint32_t n_tokens_all = balloc->get_n_tokens();
|
908
|
+
const uint32_t n_outputs_all = balloc->get_n_outputs();
|
895
909
|
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
}
|
903
|
-
|
904
|
-
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
905
|
-
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
|
906
|
-
return -1;
|
907
|
-
}
|
910
|
+
if (output_all) {
|
911
|
+
// require that all tokens are output
|
912
|
+
if (n_outputs_all != n_tokens_all) {
|
913
|
+
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
|
914
|
+
__func__, n_outputs_all, n_tokens_all);
|
915
|
+
return -1;
|
908
916
|
}
|
909
917
|
}
|
910
918
|
|
@@ -917,49 +925,77 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
917
925
|
}
|
918
926
|
n_queued_tokens += n_tokens_all;
|
919
927
|
|
920
|
-
// this
|
921
|
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
922
|
-
|
928
|
+
// TODO: this clear of the buffer can easily be forgotten - need something better
|
923
929
|
embd_seq.clear();
|
924
930
|
|
925
|
-
|
931
|
+
bool did_optimize = false;
|
932
|
+
|
933
|
+
// handle any pending defrags/shifts
|
934
|
+
kv_self_update(false);
|
926
935
|
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
936
|
+
llama_memory_context_ptr mctx;
|
937
|
+
|
938
|
+
while (true) {
|
939
|
+
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
940
|
+
if (!mctx) {
|
941
|
+
return -2;
|
931
942
|
}
|
932
|
-
} else if (embd_pooled) {
|
933
|
-
n_outputs_all = n_tokens_all;
|
934
|
-
} else {
|
935
|
-
// keep last output only
|
936
|
-
n_outputs_all = 1;
|
937
|
-
}
|
938
943
|
|
939
|
-
|
944
|
+
switch (mctx->get_status()) {
|
945
|
+
case LLAMA_MEMORY_STATUS_SUCCESS:
|
946
|
+
{
|
947
|
+
} break;
|
948
|
+
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
949
|
+
{
|
950
|
+
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
|
951
|
+
|
952
|
+
return -2;
|
953
|
+
}
|
954
|
+
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
955
|
+
{
|
956
|
+
if (!did_optimize) {
|
957
|
+
did_optimize = true;
|
958
|
+
|
959
|
+
if (kv_self_update(true)) {
|
960
|
+
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
|
961
|
+
|
962
|
+
continue;
|
963
|
+
}
|
964
|
+
}
|
965
|
+
|
966
|
+
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
|
967
|
+
|
968
|
+
return 1;
|
969
|
+
}
|
970
|
+
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
971
|
+
{
|
972
|
+
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
|
973
|
+
|
974
|
+
return -2;
|
975
|
+
}
|
976
|
+
}
|
977
|
+
|
978
|
+
break;
|
979
|
+
}
|
940
980
|
|
941
981
|
// reserve output buffer
|
942
982
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
943
|
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
983
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
944
984
|
return -2;
|
945
985
|
};
|
946
986
|
|
947
|
-
// handle any pending defrags/shifts
|
948
|
-
kv_self_update();
|
949
|
-
|
950
987
|
int64_t n_outputs_prev = 0;
|
951
988
|
|
952
|
-
|
953
|
-
|
989
|
+
do {
|
990
|
+
const auto & ubatch = mctx->get_ubatch();
|
954
991
|
|
955
|
-
// count the outputs in this
|
992
|
+
// count the outputs in this ubatch
|
956
993
|
{
|
957
994
|
int32_t n_outputs_new = 0;
|
958
995
|
|
959
996
|
if (n_outputs_all == n_tokens_all) {
|
960
997
|
n_outputs_new = ubatch.n_tokens;
|
961
998
|
} else {
|
962
|
-
GGML_ASSERT(ubatch.output);
|
963
999
|
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
964
1000
|
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
965
1001
|
}
|
@@ -969,33 +1005,40 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
969
1005
|
n_outputs = n_outputs_new;
|
970
1006
|
}
|
971
1007
|
|
972
|
-
// find KV slot
|
973
|
-
if (!kv_self->find_slot(ubatch)) {
|
974
|
-
return 1;
|
975
|
-
}
|
976
|
-
|
977
1008
|
ggml_backend_sched_reset(sched.get());
|
978
1009
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
979
1010
|
|
980
|
-
|
981
|
-
auto res =
|
1011
|
+
ggml_status status;
|
1012
|
+
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
982
1013
|
|
983
|
-
|
1014
|
+
if (!res) {
|
1015
|
+
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
1016
|
+
llama_pos pos_min[LLAMA_MAX_SEQ];
|
1017
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
1018
|
+
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
1019
|
+
}
|
984
1020
|
|
985
|
-
|
1021
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
1022
|
+
const auto & seq_id = ubatch.seq_id[i][0];
|
986
1023
|
|
987
|
-
|
1024
|
+
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
1025
|
+
}
|
988
1026
|
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
1027
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
1028
|
+
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
1029
|
+
continue;
|
1030
|
+
}
|
1031
|
+
|
1032
|
+
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
1033
|
+
|
1034
|
+
memory->seq_rm(s, pos_min[s], -1);
|
1035
|
+
}
|
1036
|
+
|
1037
|
+
switch (status) {
|
1038
|
+
case GGML_STATUS_ABORTED: return 2;
|
1039
|
+
case GGML_STATUS_ALLOC_FAILED: return -2;
|
1040
|
+
case GGML_STATUS_FAILED: return -3;
|
1041
|
+
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
|
999
1042
|
}
|
1000
1043
|
}
|
1001
1044
|
|
@@ -1004,7 +1047,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1004
1047
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
1005
1048
|
//}
|
1006
1049
|
|
1007
|
-
auto * t_logits =
|
1050
|
+
auto * t_logits = res->get_logits();
|
1008
1051
|
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
1009
1052
|
|
1010
1053
|
if (t_embd && res->get_embd_pooled()) {
|
@@ -1051,27 +1094,27 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1051
1094
|
// extract sequence embeddings (cleared before processing each batch)
|
1052
1095
|
auto & embd_seq_out = embd_seq;
|
1053
1096
|
|
1054
|
-
for (uint32_t s = 0; s < ubatch.
|
1055
|
-
const llama_seq_id seq_id
|
1056
|
-
|
1057
|
-
|
1058
|
-
}
|
1097
|
+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
1098
|
+
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
1099
|
+
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
1100
|
+
|
1059
1101
|
embd_seq_out[seq_id].resize(n_embd);
|
1060
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
|
1102
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
|
1061
1103
|
}
|
1062
1104
|
} break;
|
1063
1105
|
case LLAMA_POOLING_TYPE_RANK:
|
1064
1106
|
{
|
1065
|
-
// extract the rerank score -
|
1107
|
+
// extract the rerank score - n_cls_out floats per sequence
|
1066
1108
|
auto & embd_seq_out = embd_seq;
|
1067
1109
|
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1110
|
+
const uint32_t n_cls_out = hparams.n_cls_out;
|
1111
|
+
|
1112
|
+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
1113
|
+
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
1114
|
+
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
1115
|
+
|
1116
|
+
embd_seq_out[seq_id].resize(n_cls_out);
|
1117
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
|
1075
1118
|
}
|
1076
1119
|
} break;
|
1077
1120
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
@@ -1082,23 +1125,20 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1082
1125
|
}
|
1083
1126
|
|
1084
1127
|
n_outputs_prev += n_outputs;
|
1085
|
-
}
|
1086
|
-
|
1087
|
-
// finalize the batch processing
|
1088
|
-
kv_guard.commit();
|
1128
|
+
} while (mctx->next());
|
1089
1129
|
|
1090
1130
|
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
1091
1131
|
n_outputs = n_outputs_all;
|
1092
1132
|
|
1093
1133
|
// set output mappings
|
1094
|
-
{
|
1134
|
+
if (n_outputs > 0) {
|
1095
1135
|
bool sorted_output = true;
|
1096
1136
|
|
1097
|
-
auto & out_ids =
|
1137
|
+
auto & out_ids = balloc->get_out_ids();
|
1098
1138
|
|
1099
|
-
GGML_ASSERT(out_ids.size() == (size_t)
|
1139
|
+
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
|
1100
1140
|
|
1101
|
-
for (int64_t i = 0; i <
|
1141
|
+
for (int64_t i = 0; i < n_outputs; ++i) {
|
1102
1142
|
int64_t out_id = out_ids[i];
|
1103
1143
|
output_ids[out_id] = i;
|
1104
1144
|
if (out_id != i) {
|
@@ -1110,20 +1150,22 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1110
1150
|
// note: this is mostly relevant for recurrent models atm
|
1111
1151
|
if (!sorted_output) {
|
1112
1152
|
const uint32_t n_vocab = model.vocab.n_tokens();
|
1113
|
-
const
|
1153
|
+
const uint64_t n_embd = model.hparams.n_embd;
|
1114
1154
|
|
1115
1155
|
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
1116
1156
|
|
1117
1157
|
// TODO: is there something more efficient which also minimizes swaps?
|
1118
1158
|
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
1119
|
-
for (
|
1120
|
-
|
1121
|
-
for (
|
1159
|
+
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
|
1160
|
+
uint32_t j_min = i;
|
1161
|
+
for (uint32_t j = i + 1; j < n_outputs; ++j) {
|
1122
1162
|
if (out_ids[j] < out_ids[j_min]) {
|
1123
1163
|
j_min = j;
|
1124
1164
|
}
|
1125
1165
|
}
|
1126
|
-
if (j_min == i) {
|
1166
|
+
if (j_min == i) {
|
1167
|
+
continue;
|
1168
|
+
}
|
1127
1169
|
std::swap(out_ids[i], out_ids[j_min]);
|
1128
1170
|
if (logits_size > 0) {
|
1129
1171
|
for (uint32_t k = 0; k < n_vocab; k++) {
|
@@ -1136,8 +1178,10 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1136
1178
|
}
|
1137
1179
|
}
|
1138
1180
|
}
|
1181
|
+
|
1139
1182
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
1140
|
-
|
1183
|
+
|
1184
|
+
for (uint32_t i = 0; i < n_outputs; ++i) {
|
1141
1185
|
output_ids[out_ids[i]] = i;
|
1142
1186
|
}
|
1143
1187
|
}
|
@@ -1146,11 +1190,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1146
1190
|
// wait for the computation to finish (automatically done when obtaining the model output)
|
1147
1191
|
//synchronize();
|
1148
1192
|
|
1149
|
-
// decide if we need to defrag the kv cache
|
1150
|
-
if (cparams.defrag_thold > 0.0f) {
|
1151
|
-
kv_self->defrag_sched(cparams.defrag_thold);
|
1152
|
-
}
|
1153
|
-
|
1154
1193
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
1155
1194
|
// overlap with device computation.
|
1156
1195
|
ggml_backend_sched_reset(sched.get());
|
@@ -1162,7 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1162
1201
|
// output
|
1163
1202
|
//
|
1164
1203
|
|
1165
|
-
|
1204
|
+
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
1166
1205
|
const auto & hparams = model.hparams;
|
1167
1206
|
const auto & vocab = model.vocab;
|
1168
1207
|
|
@@ -1172,9 +1211,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
1172
1211
|
const auto n_vocab = vocab.n_tokens();
|
1173
1212
|
const auto n_embd = hparams.n_embd;
|
1174
1213
|
|
1175
|
-
|
1176
|
-
bool
|
1177
|
-
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
1214
|
+
bool has_logits = true;
|
1215
|
+
bool has_embd = cparams.embeddings;
|
1178
1216
|
|
1179
1217
|
// TODO: hacky enc-dec support
|
1180
1218
|
if (model.arch == LLM_ARCH_T5) {
|
@@ -1228,8 +1266,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
1228
1266
|
// set all ids as invalid (negative)
|
1229
1267
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
1230
1268
|
|
1231
|
-
this->n_outputs
|
1232
|
-
this->n_outputs_max = n_outputs_max;
|
1269
|
+
this->n_outputs = 0;
|
1233
1270
|
|
1234
1271
|
return n_outputs_max;
|
1235
1272
|
}
|
@@ -1254,11 +1291,52 @@ ggml_cgraph * llama_context::graph_init() {
|
|
1254
1291
|
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
1255
1292
|
}
|
1256
1293
|
|
1294
|
+
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
|
1295
|
+
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
1296
|
+
|
1297
|
+
if (n_tokens % n_seqs != 0) {
|
1298
|
+
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
|
1299
|
+
n_outputs = std::min(n_outputs, n_tokens);
|
1300
|
+
|
1301
|
+
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
1302
|
+
}
|
1303
|
+
|
1304
|
+
// store the n_outputs as it is, and restore it afterwards
|
1305
|
+
// TODO: not sure if needed, might simplify in the future by removing this
|
1306
|
+
const auto save_n_outputs = this->n_outputs;
|
1307
|
+
|
1308
|
+
this->n_outputs = n_outputs;
|
1309
|
+
|
1310
|
+
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
|
1311
|
+
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
1312
|
+
|
1313
|
+
auto * gf = graph_init();
|
1314
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
|
1315
|
+
|
1316
|
+
this->n_outputs = save_n_outputs;
|
1317
|
+
|
1318
|
+
if (!res) {
|
1319
|
+
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
|
1320
|
+
return nullptr;
|
1321
|
+
}
|
1322
|
+
|
1323
|
+
ggml_backend_sched_reset(sched.get());
|
1324
|
+
|
1325
|
+
// initialize scheduler with the specified graph
|
1326
|
+
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
1327
|
+
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
1328
|
+
return nullptr;
|
1329
|
+
}
|
1330
|
+
|
1331
|
+
return gf;
|
1332
|
+
}
|
1333
|
+
|
1257
1334
|
llm_graph_result_ptr llama_context::graph_build(
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1335
|
+
ggml_context * ctx,
|
1336
|
+
ggml_cgraph * gf,
|
1337
|
+
const llama_ubatch & ubatch,
|
1338
|
+
llm_graph_type gtype,
|
1339
|
+
const llama_memory_context_i * mctx) {
|
1262
1340
|
return model.build_graph(
|
1263
1341
|
{
|
1264
1342
|
/*.ctx =*/ ctx,
|
@@ -1270,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
|
1270
1348
|
/*.backend_cpu =*/ backend_cpu,
|
1271
1349
|
/*.cvec =*/ &cvec,
|
1272
1350
|
/*.loras =*/ &loras,
|
1273
|
-
/*.
|
1351
|
+
/*.mctx =*/ mctx,
|
1274
1352
|
/*.cross =*/ &cross,
|
1275
1353
|
/*.n_outputs =*/ n_outputs,
|
1276
1354
|
/*.cb =*/ graph_get_cb(),
|
@@ -1679,14 +1757,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
1679
1757
|
|
1680
1758
|
std::vector<int32_t> w_output_pos;
|
1681
1759
|
|
1682
|
-
GGML_ASSERT(n_outputs <= n_outputs_max);
|
1683
|
-
|
1684
1760
|
w_output_pos.resize(n_outputs);
|
1685
1761
|
|
1686
1762
|
// build a more compact representation of the output ids
|
1687
1763
|
for (size_t i = 0; i < n_batch(); ++i) {
|
1688
1764
|
// map an output id to a position in the batch
|
1689
|
-
|
1765
|
+
int64_t pos = output_ids[i];
|
1690
1766
|
if (pos >= 0) {
|
1691
1767
|
GGML_ASSERT(pos < n_outputs);
|
1692
1768
|
w_output_pos[pos] = i;
|
@@ -1726,11 +1802,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
1726
1802
|
}
|
1727
1803
|
}
|
1728
1804
|
|
1729
|
-
|
1730
|
-
|
1731
|
-
if (kv_self != nullptr) {
|
1805
|
+
if (memory != nullptr) {
|
1732
1806
|
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
1733
|
-
|
1807
|
+
memory->state_write(io);
|
1734
1808
|
}
|
1735
1809
|
|
1736
1810
|
return io.n_bytes();
|
@@ -1817,9 +1891,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
1817
1891
|
if (memory) {
|
1818
1892
|
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
1819
1893
|
|
1820
|
-
|
1821
|
-
|
1822
|
-
kv_self->state_read(io);
|
1894
|
+
memory->state_read(io);
|
1823
1895
|
}
|
1824
1896
|
|
1825
1897
|
return io.n_bytes();
|
@@ -1829,9 +1901,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
|
|
1829
1901
|
GGML_UNUSED(seq_id);
|
1830
1902
|
|
1831
1903
|
if (memory) {
|
1832
|
-
|
1833
|
-
|
1834
|
-
kv_self->state_write(io, seq_id);
|
1904
|
+
memory->state_write(io, seq_id);
|
1835
1905
|
}
|
1836
1906
|
|
1837
1907
|
return io.n_bytes();
|
@@ -1841,9 +1911,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
|
|
1841
1911
|
GGML_UNUSED(seq_id);
|
1842
1912
|
|
1843
1913
|
if (memory) {
|
1844
|
-
|
1845
|
-
|
1846
|
-
kv_self->state_read(io, seq_id);
|
1914
|
+
memory->state_read(io, seq_id);
|
1847
1915
|
}
|
1848
1916
|
|
1849
1917
|
return io.n_bytes();
|
@@ -1948,10 +2016,7 @@ void llama_context::opt_epoch_iter(
|
|
1948
2016
|
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
1949
2017
|
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
1950
2018
|
|
1951
|
-
|
1952
|
-
|
1953
|
-
kv_self->clear();
|
1954
|
-
llama_kv_cache_guard kv_guard(kv_self);
|
2019
|
+
memory->clear(true);
|
1955
2020
|
|
1956
2021
|
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
1957
2022
|
batch.n_tokens = n_batch;
|
@@ -1963,39 +2028,44 @@ void llama_context::opt_epoch_iter(
|
|
1963
2028
|
batch.logits [pos_batch] = true;
|
1964
2029
|
}
|
1965
2030
|
|
1966
|
-
|
2031
|
+
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
|
2032
|
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
2033
|
+
return;
|
2034
|
+
}
|
2035
|
+
|
2036
|
+
const uint32_t n_tokens_all = balloc->get_n_tokens();
|
1967
2037
|
|
1968
2038
|
n_queued_tokens += n_tokens_all;
|
1969
2039
|
|
1970
|
-
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
1971
|
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
1972
|
-
|
1973
2040
|
embd_seq.clear();
|
1974
2041
|
|
1975
|
-
|
2042
|
+
uint32_t n_outputs_all = n_tokens_all;
|
1976
2043
|
|
1977
|
-
|
2044
|
+
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
|
2045
|
+
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
2046
|
+
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
2047
|
+
break;
|
2048
|
+
}
|
1978
2049
|
|
1979
2050
|
// reserve output buffer
|
1980
2051
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
1981
|
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
2052
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
1982
2053
|
GGML_ABORT("TODO: handle this error");
|
1983
2054
|
};
|
1984
2055
|
|
1985
|
-
|
1986
|
-
|
2056
|
+
uint32_t pos_batch = 0;
|
2057
|
+
do {
|
2058
|
+
const auto & ubatch = mctx->get_ubatch();
|
1987
2059
|
|
1988
2060
|
n_outputs = ubatch.n_tokens;
|
1989
2061
|
|
1990
|
-
|
1991
|
-
|
1992
|
-
|
1993
|
-
|
1994
|
-
GGML_ABORT("TODO: handle this error");
|
2062
|
+
if (!mctx->apply()) {
|
2063
|
+
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
|
2064
|
+
break;
|
1995
2065
|
}
|
1996
2066
|
|
1997
2067
|
auto * gf = graph_init();
|
1998
|
-
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
2068
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
|
1999
2069
|
|
2000
2070
|
struct ggml_context * ctx_compute_opt;
|
2001
2071
|
{
|
@@ -2010,6 +2080,7 @@ void llama_context::opt_epoch_iter(
|
|
2010
2080
|
}
|
2011
2081
|
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
2012
2082
|
ggml_opt_alloc(opt_ctx, train);
|
2083
|
+
|
2013
2084
|
res->set_inputs(&ubatch);
|
2014
2085
|
{
|
2015
2086
|
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
@@ -2027,10 +2098,10 @@ void llama_context::opt_epoch_iter(
|
|
2027
2098
|
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
2028
2099
|
}
|
2029
2100
|
ggml_free(ctx_compute_opt);
|
2030
|
-
}
|
2031
|
-
}
|
2032
2101
|
|
2033
|
-
|
2102
|
+
pos_batch += ubatch.n_tokens;
|
2103
|
+
} while (mctx->next());
|
2104
|
+
}
|
2034
2105
|
}
|
2035
2106
|
|
2036
2107
|
void llama_context::opt_epoch(
|
@@ -2190,12 +2261,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
|
2190
2261
|
return &ctx->get_model();
|
2191
2262
|
}
|
2192
2263
|
|
2264
|
+
// deprecated
|
2193
2265
|
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
2194
|
-
return ctx->
|
2266
|
+
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
|
2195
2267
|
}
|
2196
2268
|
|
2269
|
+
// deprecated
|
2197
2270
|
void llama_kv_self_update(llama_context * ctx) {
|
2198
|
-
ctx->kv_self_update();
|
2271
|
+
ctx->kv_self_update(false);
|
2199
2272
|
}
|
2200
2273
|
|
2201
2274
|
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
@@ -2310,13 +2383,118 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }

+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem, bool data) {
+    if (!mem) {
+        return;
+    }
+
+    mem->clear(data);
+}
+
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1) {
+    if (!mem) {
+        return true;
+    }
+
+    return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+          llama_seq_id seq_id_src,
+          llama_seq_id seq_id_dst,
+             llama_pos p0,
+             llama_pos p1) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1,
+             llama_pos delta) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1,
+                   int d) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+    if (!mem) {
+        return false;
+    }
+
+    return mem->get_can_shift();
+}
+
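The block above is the new public surface: every `llama_memory_*` call takes the `llama_memory_t` handle returned by `llama_get_memory()`, null-checks it, and forwards to the memory module. A minimal usage sketch (not part of the diff; `prune_and_fork` is a hypothetical helper name, and it assumes a valid `llama_context * ctx` plus the usual llama.cpp convention that negative `p0`/`p1` mean the whole range):

```cpp
#include "llama.h"

// Sketch: sequence bookkeeping through the new handle-based memory API (hypothetical helper).
static void prune_and_fork(llama_context * ctx) {
    llama_memory_t mem = llama_get_memory(ctx);

    // drop positions [32, 64) of sequence 0; false means the range could not be removed
    if (!llama_memory_seq_rm(mem, /*seq_id=*/0, /*p0=*/32, /*p1=*/64)) {
        return;
    }

    // duplicate sequence 0 into sequence 1 over its full range (-1/-1 = whole range)
    llama_memory_seq_cp(mem, /*seq_id_src=*/0, /*seq_id_dst=*/1, -1, -1);

    // shift sequence 1 back by 16 positions, but only if this memory type supports shifting
    if (llama_memory_can_shift(mem)) {
        llama_memory_seq_add(mem, /*seq_id=*/1, /*p0=*/0, /*p1=*/-1, /*delta=*/-16);
    }

    // wipe everything, including the underlying data buffers
    llama_memory_clear(mem, /*data=*/true);
}
```

The deprecated `llama_kv_self_*` wrappers that follow now reduce to exactly this pattern: fetch the handle via `llama_get_memory(ctx)` and forward.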
 //
 // kv cache
 //

 // deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = ctx
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@@ -2338,7 +2516,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
 // deprecated
 // note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = ctx
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@@ -2357,114 +2535,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     return res;
 }

+// deprecated
 void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv
+    llama_memory_clear(kv, true);
 }

+// deprecated
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
          llama_seq_id seq_id,
             llama_pos p0,
             llama_pos p1) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return true;
     }

-    return kv
+    return llama_memory_seq_rm(kv, seq_id, p0, p1);
 }

+// deprecated
 void llama_kv_self_seq_cp(
         llama_context * ctx,
          llama_seq_id seq_id_src,
          llama_seq_id seq_id_dst,
             llama_pos p0,
             llama_pos p1) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv
+    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
 }

+// deprecated
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv
+    llama_memory_seq_keep(kv, seq_id);
 }

+// deprecated
 void llama_kv_self_seq_add(
         llama_context * ctx,
          llama_seq_id seq_id,
             llama_pos p0,
             llama_pos p1,
             llama_pos delta) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv
+    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
 }

+// deprecated
 void llama_kv_self_seq_div(
         llama_context * ctx,
          llama_seq_id seq_id,
             llama_pos p0,
             llama_pos p1,
                   int d) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv
+    llama_memory_seq_div(kv, seq_id, p0, p1, d);
 }

+// deprecated
 llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }

-    return kv
+    return llama_memory_seq_pos_min(kv, seq_id);
 }

+// deprecated
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }

-    return kv
+    return llama_memory_seq_pos_max(kv, seq_id);
 }

+// deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-
+    ctx->kv_self_defrag_sched();
 }

+// deprecated
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return false;
     }

-    return kv
+    return llama_memory_can_shift(kv);
 }

 // llama state API
@@ -2589,22 +2772,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }

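The last hunk drops the internal defrag-and-retry path from `llama_decode`: a return value of 1 (no free memory slot for the batch) is now handed straight back to the caller, and only codes other than 0 and 1 are logged as errors. A minimal caller-side sketch (not part of the diff; `decode_with_retry` is a hypothetical helper, and freeing sequence 0 is just one illustrative recovery strategy):

```cpp
#include <cstdio>
#include "llama.h"

// Sketch: decode with a single, caller-driven retry when no memory slot is available.
static int32_t decode_with_retry(llama_context * ctx, llama_batch batch) {
    int32_t ret = llama_decode(ctx, batch);

    if (ret == 1) {
        // 1 = no free slot for this batch. Freeing sequence 0 is purely illustrative;
        // a real application would pick a victim sequence (or shrink the batch) itself.
        llama_memory_seq_rm(llama_get_memory(ctx), /*seq_id=*/0, /*p0=*/-1, /*p1=*/-1);
        ret = llama_decode(ctx, batch);
    }

    if (ret != 0) {
        fprintf(stderr, "llama_decode failed with code %d\n", ret);
    }

    return ret;
}
```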