@novastera-oss/llamarn 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +134 -36
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -2
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +30 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +50 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +134 -36
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
- package/cpp/llama.cpp/src/llama-batch.h +36 -11
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +313 -213
- package/cpp/llama.cpp/src/llama-context.h +16 -12
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
- package/cpp/llama.cpp/src/llama-graph.h +90 -34
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
- package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +64 -23
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +726 -141
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/llama.h +134 -36
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
#include "llama-context.h"
|
|
2
2
|
|
|
3
3
|
#include "llama-impl.h"
|
|
4
|
+
#include "llama-batch.h"
|
|
4
5
|
#include "llama-io.h"
|
|
6
|
+
#include "llama-memory.h"
|
|
5
7
|
#include "llama-mmap.h"
|
|
6
8
|
#include "llama-model.h"
|
|
7
|
-
#include "llama-kv-cache.h"
|
|
8
9
|
|
|
9
10
|
#include <cinttypes>
|
|
10
11
|
#include <cstring>
|
|
@@ -18,7 +19,8 @@
|
|
|
18
19
|
llama_context::llama_context(
|
|
19
20
|
const llama_model & model,
|
|
20
21
|
llama_context_params params) :
|
|
21
|
-
model(model)
|
|
22
|
+
model(model),
|
|
23
|
+
batch_allocr(std::make_unique<llama_batch_allocr>()) {
|
|
22
24
|
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
|
23
25
|
|
|
24
26
|
t_start_us = model.t_start_us;
|
|
@@ -27,8 +29,8 @@ llama_context::llama_context(
|
|
|
27
29
|
const auto & hparams = model.hparams;
|
|
28
30
|
|
|
29
31
|
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
|
30
|
-
if (cparams.n_seq_max >
|
|
31
|
-
throw std::runtime_error("n_seq_max must be <= " + std::to_string(
|
|
32
|
+
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
|
|
33
|
+
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
|
32
34
|
}
|
|
33
35
|
|
|
34
36
|
cparams.n_threads = params.n_threads;
|
|
@@ -123,7 +125,7 @@ llama_context::llama_context(
|
|
|
123
125
|
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
|
124
126
|
}
|
|
125
127
|
|
|
126
|
-
if (!params.swa_full && cparams.n_seq_max > 1) {
|
|
128
|
+
if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
|
|
127
129
|
LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
|
|
128
130
|
__func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
|
|
129
131
|
}
|
|
@@ -277,10 +279,9 @@ llama_context::llama_context(
|
|
|
277
279
|
int n_nodes_tg = -1;
|
|
278
280
|
|
|
279
281
|
// simulate full KV cache
|
|
280
|
-
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
281
282
|
|
|
282
|
-
const auto
|
|
283
|
-
if (!
|
|
283
|
+
const auto mstate = memory->init_full();
|
|
284
|
+
if (!mstate) {
|
|
284
285
|
throw std::runtime_error("failed to initialize KV cache");
|
|
285
286
|
}
|
|
286
287
|
|
|
@@ -288,7 +289,7 @@ llama_context::llama_context(
|
|
|
288
289
|
|
|
289
290
|
// reserve pp graph first so that buffers are only allocated once
|
|
290
291
|
{
|
|
291
|
-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens,
|
|
292
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
|
292
293
|
if (!gf) {
|
|
293
294
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
294
295
|
}
|
|
@@ -299,7 +300,7 @@ llama_context::llama_context(
|
|
|
299
300
|
|
|
300
301
|
// reserve with tg graph to get the number of splits and nodes
|
|
301
302
|
{
|
|
302
|
-
auto * gf = graph_reserve(1, 1, 1,
|
|
303
|
+
auto * gf = graph_reserve(1, 1, 1, mstate.get());
|
|
303
304
|
if (!gf) {
|
|
304
305
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
|
305
306
|
}
|
|
@@ -310,7 +311,7 @@ llama_context::llama_context(
|
|
|
310
311
|
|
|
311
312
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
|
312
313
|
{
|
|
313
|
-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens,
|
|
314
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
|
314
315
|
if (!gf) {
|
|
315
316
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
316
317
|
}
|
|
@@ -419,40 +420,68 @@ uint32_t llama_context::n_threads_batch() const {
|
|
|
419
420
|
return cparams.n_threads_batch;
|
|
420
421
|
}
|
|
421
422
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
return kv_self;
|
|
423
|
+
llama_memory_t llama_context::get_memory() const {
|
|
424
|
+
return memory.get();
|
|
425
425
|
}
|
|
426
426
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
427
|
+
// deprecated
|
|
428
|
+
void llama_context::kv_self_defrag_sched() {
|
|
429
|
+
if (!memory) {
|
|
430
|
+
return;
|
|
431
|
+
}
|
|
432
|
+
|
|
433
|
+
memory_force_optimize = true;
|
|
430
434
|
}
|
|
431
435
|
|
|
432
|
-
|
|
436
|
+
// deprecated
|
|
437
|
+
bool llama_context::kv_self_update(bool optimize) {
|
|
433
438
|
if (!memory) {
|
|
434
439
|
return false;
|
|
435
440
|
}
|
|
436
441
|
|
|
437
|
-
|
|
442
|
+
{
|
|
443
|
+
// TODO: remove in the future
|
|
444
|
+
optimize |= memory_force_optimize;
|
|
445
|
+
memory_force_optimize = false;
|
|
438
446
|
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
447
|
+
const auto mstate = memory->init_update(this, optimize);
|
|
448
|
+
switch (mstate->get_status()) {
|
|
449
|
+
case LLAMA_MEMORY_STATUS_SUCCESS:
|
|
450
|
+
{
|
|
451
|
+
// noop
|
|
452
|
+
} break;
|
|
453
|
+
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
|
454
|
+
{
|
|
455
|
+
// no updates need to be performed
|
|
456
|
+
return false;
|
|
457
|
+
}
|
|
458
|
+
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
|
459
|
+
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
|
460
|
+
{
|
|
461
|
+
LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
|
|
462
|
+
return false;
|
|
463
|
+
}
|
|
464
|
+
}
|
|
443
465
|
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
throw std::runtime_error("failed to initialize KV cache");
|
|
466
|
+
if (!mstate->apply()) {
|
|
467
|
+
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
|
468
|
+
}
|
|
448
469
|
}
|
|
449
470
|
|
|
450
|
-
|
|
451
|
-
|
|
471
|
+
// if the memory module did any computation, we have to reserve a new worst-case graph
|
|
472
|
+
{
|
|
473
|
+
const auto mstate = memory->init_full();
|
|
474
|
+
if (!mstate) {
|
|
475
|
+
throw std::runtime_error("failed to initialize memory state");
|
|
476
|
+
}
|
|
452
477
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
478
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
|
479
|
+
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
480
|
+
|
|
481
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
|
482
|
+
if (!gf) {
|
|
483
|
+
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
|
484
|
+
}
|
|
456
485
|
}
|
|
457
486
|
|
|
458
487
|
return true;
|
|
@@ -467,7 +496,7 @@ float * llama_context::get_logits() {
|
|
|
467
496
|
}
|
|
468
497
|
|
|
469
498
|
float * llama_context::get_logits_ith(int32_t i) {
|
|
470
|
-
|
|
499
|
+
int64_t j = -1;
|
|
471
500
|
|
|
472
501
|
try {
|
|
473
502
|
if (logits == nullptr) {
|
|
@@ -490,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
|
490
519
|
}
|
|
491
520
|
if (j >= n_outputs) {
|
|
492
521
|
// This should not happen
|
|
493
|
-
throw std::runtime_error(format("corrupt output buffer (j=%
|
|
522
|
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
494
523
|
}
|
|
495
524
|
|
|
496
525
|
return logits + j*model.vocab.n_tokens();
|
|
@@ -509,7 +538,7 @@ float * llama_context::get_embeddings() {
|
|
|
509
538
|
}
|
|
510
539
|
|
|
511
540
|
float * llama_context::get_embeddings_ith(int32_t i) {
|
|
512
|
-
|
|
541
|
+
int64_t j = -1;
|
|
513
542
|
|
|
514
543
|
try {
|
|
515
544
|
if (embd == nullptr) {
|
|
@@ -532,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
|
|
532
561
|
}
|
|
533
562
|
if (j >= n_outputs) {
|
|
534
563
|
// This should not happen
|
|
535
|
-
throw std::runtime_error(format("corrupt output buffer (j=%
|
|
564
|
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
536
565
|
}
|
|
537
566
|
|
|
538
567
|
return embd + j*model.hparams.n_embd;
|
|
@@ -692,52 +721,41 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
|
|
|
692
721
|
return res;
|
|
693
722
|
}
|
|
694
723
|
|
|
695
|
-
int llama_context::encode(llama_batch &
|
|
696
|
-
if (
|
|
724
|
+
int llama_context::encode(const llama_batch & batch_inp) {
|
|
725
|
+
if (batch_inp.n_tokens == 0) {
|
|
697
726
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
698
727
|
return -1;
|
|
699
728
|
}
|
|
700
729
|
|
|
701
|
-
// temporary allocate memory for the input batch if needed
|
|
702
730
|
// note: during encode, we always pass the full sequence starting from pos = 0
|
|
703
|
-
|
|
731
|
+
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
|
|
732
|
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
|
733
|
+
return -1;
|
|
734
|
+
}
|
|
704
735
|
|
|
705
|
-
const llama_batch & batch = batch_allocr
|
|
706
|
-
const int32_t n_tokens = batch.n_tokens;
|
|
736
|
+
const llama_batch & batch = batch_allocr->get_batch();
|
|
707
737
|
|
|
708
|
-
const
|
|
738
|
+
const uint32_t n_tokens = batch.n_tokens;
|
|
709
739
|
|
|
710
740
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
711
741
|
|
|
712
|
-
// TODO: move the validation to the llama_batch_allocr
|
|
713
|
-
if (batch.token) {
|
|
714
|
-
for (int32_t i = 0; i < n_tokens; ++i) {
|
|
715
|
-
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
|
716
|
-
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
|
717
|
-
return -1;
|
|
718
|
-
}
|
|
719
|
-
|
|
720
|
-
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
|
721
|
-
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
722
|
-
throw -1;
|
|
723
|
-
}
|
|
724
|
-
}
|
|
725
|
-
}
|
|
726
|
-
|
|
727
742
|
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
|
728
|
-
GGML_ASSERT(cparams.n_ubatch >=
|
|
743
|
+
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
|
|
729
744
|
|
|
730
745
|
if (t_compute_start_us == 0) {
|
|
731
746
|
t_compute_start_us = ggml_time_us();
|
|
732
747
|
}
|
|
733
748
|
|
|
749
|
+
// TODO: this clear of the buffer can easily be forgotten - need something better
|
|
734
750
|
embd_seq.clear();
|
|
735
751
|
|
|
736
752
|
n_queued_tokens += n_tokens;
|
|
737
753
|
|
|
754
|
+
const auto & hparams = model.hparams;
|
|
755
|
+
|
|
738
756
|
const int64_t n_embd = hparams.n_embd;
|
|
739
757
|
|
|
740
|
-
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true
|
|
758
|
+
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
|
|
741
759
|
|
|
742
760
|
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
|
743
761
|
|
|
@@ -747,7 +765,7 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
747
765
|
return -2;
|
|
748
766
|
};
|
|
749
767
|
|
|
750
|
-
for (
|
|
768
|
+
for (uint32_t i = 0; i < n_tokens; ++i) {
|
|
751
769
|
output_ids[i] = i;
|
|
752
770
|
}
|
|
753
771
|
|
|
@@ -803,7 +821,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
803
821
|
|
|
804
822
|
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
|
805
823
|
|
|
806
|
-
|
|
824
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
825
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
807
826
|
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
808
827
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
809
828
|
continue;
|
|
@@ -814,16 +833,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
814
833
|
} break;
|
|
815
834
|
case LLAMA_POOLING_TYPE_RANK:
|
|
816
835
|
{
|
|
817
|
-
// extract the rerank score -
|
|
836
|
+
// extract the rerank score - n_cls_out floats per sequence
|
|
818
837
|
auto & embd_seq_out = embd_seq;
|
|
838
|
+
const uint32_t n_cls_out = hparams.n_cls_out;
|
|
819
839
|
|
|
840
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
820
841
|
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
|
821
842
|
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
822
843
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
823
844
|
continue;
|
|
824
845
|
}
|
|
825
|
-
embd_seq_out[seq_id].resize(
|
|
826
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
|
846
|
+
embd_seq_out[seq_id].resize(n_cls_out);
|
|
847
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
|
|
827
848
|
}
|
|
828
849
|
} break;
|
|
829
850
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
@@ -850,10 +871,10 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
850
871
|
|
|
851
872
|
// remember the sequence ids used during the encoding - needed for cross attention later
|
|
852
873
|
cross.seq_ids_enc.resize(n_tokens);
|
|
853
|
-
for (
|
|
874
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
854
875
|
cross.seq_ids_enc[i].clear();
|
|
855
|
-
for (int s = 0; s <
|
|
856
|
-
llama_seq_id seq_id =
|
|
876
|
+
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
|
877
|
+
llama_seq_id seq_id = batch.seq_id[i][s];
|
|
857
878
|
cross.seq_ids_enc[i].insert(seq_id);
|
|
858
879
|
}
|
|
859
880
|
}
|
|
@@ -862,53 +883,45 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
862
883
|
return 0;
|
|
863
884
|
}
|
|
864
885
|
|
|
865
|
-
int llama_context::decode(llama_batch &
|
|
886
|
+
int llama_context::decode(const llama_batch & batch_inp) {
|
|
866
887
|
if (!memory) {
|
|
867
888
|
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
|
868
|
-
return encode(
|
|
889
|
+
return encode(batch_inp);
|
|
869
890
|
}
|
|
870
891
|
|
|
871
|
-
if (
|
|
892
|
+
if (batch_inp.n_tokens == 0) {
|
|
872
893
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
873
894
|
return -1;
|
|
874
895
|
}
|
|
875
896
|
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
|
|
879
|
-
return -1;
|
|
880
|
-
}
|
|
881
|
-
}
|
|
882
|
-
|
|
883
|
-
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
897
|
+
// when computing embeddings, all tokens are output
|
|
898
|
+
const bool embd_all = cparams.embeddings;
|
|
884
899
|
|
|
885
|
-
|
|
886
|
-
|
|
900
|
+
if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
|
|
901
|
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
|
902
|
+
return -1;
|
|
903
|
+
}
|
|
887
904
|
|
|
888
|
-
const llama_batch & batch = batch_allocr
|
|
905
|
+
const llama_batch & batch = batch_allocr->get_batch();
|
|
889
906
|
|
|
890
907
|
const auto & vocab = model.vocab;
|
|
891
908
|
const auto & hparams = model.hparams;
|
|
892
909
|
|
|
893
910
|
const int32_t n_vocab = vocab.n_tokens();
|
|
911
|
+
const int64_t n_embd = hparams.n_embd;
|
|
894
912
|
|
|
895
|
-
const
|
|
896
|
-
const int64_t n_embd = hparams.n_embd;
|
|
913
|
+
const uint32_t n_tokens_all = batch.n_tokens;
|
|
897
914
|
|
|
898
915
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
899
916
|
|
|
900
|
-
|
|
901
|
-
if (batch.token) {
|
|
902
|
-
for (int64_t i = 0; i < n_tokens_all; ++i) {
|
|
903
|
-
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
|
904
|
-
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
|
|
905
|
-
return -1;
|
|
906
|
-
}
|
|
917
|
+
const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
|
|
907
918
|
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
919
|
+
if (embd_all) {
|
|
920
|
+
// require that all tokens are output
|
|
921
|
+
if (n_outputs_all != n_tokens_all) {
|
|
922
|
+
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
|
|
923
|
+
__func__, n_outputs_all, n_tokens_all);
|
|
924
|
+
return -1;
|
|
912
925
|
}
|
|
913
926
|
}
|
|
914
927
|
|
|
@@ -921,61 +934,52 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
921
934
|
}
|
|
922
935
|
n_queued_tokens += n_tokens_all;
|
|
923
936
|
|
|
924
|
-
// this
|
|
925
|
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
|
926
|
-
|
|
937
|
+
// TODO: this clear of the buffer can easily be forgotten - need something better
|
|
927
938
|
embd_seq.clear();
|
|
928
939
|
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
// count outputs
|
|
932
|
-
if (batch.logits && !embd_pooled) {
|
|
933
|
-
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
|
934
|
-
n_outputs_all += batch.logits[i] != 0;
|
|
935
|
-
}
|
|
936
|
-
} else if (embd_pooled) {
|
|
937
|
-
n_outputs_all = n_tokens_all;
|
|
938
|
-
} else {
|
|
939
|
-
// keep last output only
|
|
940
|
-
n_outputs_all = 1;
|
|
941
|
-
}
|
|
940
|
+
bool did_optimize = false;
|
|
942
941
|
|
|
943
942
|
// handle any pending defrags/shifts
|
|
944
|
-
kv_self_update();
|
|
943
|
+
kv_self_update(false);
|
|
945
944
|
|
|
946
|
-
llama_memory_state_ptr
|
|
947
|
-
|
|
948
|
-
bool did_defrag = false;
|
|
945
|
+
llama_memory_state_ptr mstate;
|
|
949
946
|
|
|
950
947
|
while (true) {
|
|
951
|
-
|
|
952
|
-
if (!
|
|
948
|
+
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
|
|
949
|
+
if (!mstate) {
|
|
953
950
|
return -2;
|
|
954
951
|
}
|
|
955
952
|
|
|
956
|
-
switch (
|
|
953
|
+
switch (mstate->get_status()) {
|
|
957
954
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
|
958
955
|
{
|
|
959
956
|
} break;
|
|
957
|
+
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
|
958
|
+
{
|
|
959
|
+
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
|
|
960
|
+
|
|
961
|
+
return -2;
|
|
962
|
+
}
|
|
960
963
|
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
|
961
964
|
{
|
|
962
|
-
if (!
|
|
963
|
-
|
|
965
|
+
if (!did_optimize) {
|
|
966
|
+
did_optimize = true;
|
|
964
967
|
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
|
|
968
|
+
if (kv_self_update(true)) {
|
|
969
|
+
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
|
|
968
970
|
|
|
969
971
|
continue;
|
|
970
972
|
}
|
|
971
973
|
}
|
|
972
974
|
|
|
973
|
-
LLAMA_LOG_WARN("%s: failed to find
|
|
975
|
+
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
|
|
974
976
|
|
|
975
977
|
return 1;
|
|
976
978
|
}
|
|
977
979
|
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
|
978
980
|
{
|
|
981
|
+
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
|
|
982
|
+
|
|
979
983
|
return -2;
|
|
980
984
|
}
|
|
981
985
|
}
|
|
@@ -985,16 +989,16 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
985
989
|
|
|
986
990
|
// reserve output buffer
|
|
987
991
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
988
|
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
|
992
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
|
989
993
|
return -2;
|
|
990
994
|
};
|
|
991
995
|
|
|
992
996
|
int64_t n_outputs_prev = 0;
|
|
993
997
|
|
|
994
998
|
do {
|
|
995
|
-
const auto & ubatch =
|
|
999
|
+
const auto & ubatch = mstate->get_ubatch();
|
|
996
1000
|
|
|
997
|
-
// count the outputs in this
|
|
1001
|
+
// count the outputs in this ubatch
|
|
998
1002
|
{
|
|
999
1003
|
int32_t n_outputs_new = 0;
|
|
1000
1004
|
|
|
@@ -1015,26 +1019,30 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1015
1019
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
|
1016
1020
|
|
|
1017
1021
|
ggml_status status;
|
|
1018
|
-
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER,
|
|
1022
|
+
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
|
|
1019
1023
|
|
|
1020
1024
|
if (!res) {
|
|
1021
1025
|
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
|
1022
|
-
llama_pos pos_min[
|
|
1026
|
+
llama_pos pos_min[LLAMA_MAX_SEQ];
|
|
1027
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
1028
|
+
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
|
1029
|
+
}
|
|
1023
1030
|
|
|
1031
|
+
// TODO: fix sequence indexing
|
|
1024
1032
|
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
|
1025
1033
|
const auto & seq_id = ubatch.seq_id[i][0];
|
|
1026
1034
|
|
|
1027
1035
|
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
|
1028
1036
|
}
|
|
1029
1037
|
|
|
1030
|
-
for (int s = 0; s <
|
|
1038
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
1031
1039
|
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
|
1032
1040
|
continue;
|
|
1033
1041
|
}
|
|
1034
1042
|
|
|
1035
1043
|
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
|
1036
1044
|
|
|
1037
|
-
|
|
1045
|
+
memory->seq_rm(s, pos_min[s], -1);
|
|
1038
1046
|
}
|
|
1039
1047
|
|
|
1040
1048
|
switch (status) {
|
|
@@ -1050,7 +1058,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1050
1058
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
|
1051
1059
|
//}
|
|
1052
1060
|
|
|
1053
|
-
auto * t_logits =
|
|
1061
|
+
auto * t_logits = res->get_logits();
|
|
1054
1062
|
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
|
1055
1063
|
|
|
1056
1064
|
if (t_embd && res->get_embd_pooled()) {
|
|
@@ -1128,20 +1136,20 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1128
1136
|
}
|
|
1129
1137
|
|
|
1130
1138
|
n_outputs_prev += n_outputs;
|
|
1131
|
-
} while (
|
|
1139
|
+
} while (mstate->next());
|
|
1132
1140
|
|
|
1133
1141
|
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
|
1134
1142
|
n_outputs = n_outputs_all;
|
|
1135
1143
|
|
|
1136
1144
|
// set output mappings
|
|
1137
|
-
{
|
|
1145
|
+
if (n_outputs > 0) {
|
|
1138
1146
|
bool sorted_output = true;
|
|
1139
1147
|
|
|
1140
|
-
auto & out_ids =
|
|
1148
|
+
auto & out_ids = mstate->out_ids();
|
|
1141
1149
|
|
|
1142
|
-
GGML_ASSERT(out_ids.size() == (size_t)
|
|
1150
|
+
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
|
|
1143
1151
|
|
|
1144
|
-
for (int64_t i = 0; i <
|
|
1152
|
+
for (int64_t i = 0; i < n_outputs; ++i) {
|
|
1145
1153
|
int64_t out_id = out_ids[i];
|
|
1146
1154
|
output_ids[out_id] = i;
|
|
1147
1155
|
if (out_id != i) {
|
|
@@ -1153,20 +1161,22 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1153
1161
|
// note: this is mostly relevant for recurrent models atm
|
|
1154
1162
|
if (!sorted_output) {
|
|
1155
1163
|
const uint32_t n_vocab = model.vocab.n_tokens();
|
|
1156
|
-
const
|
|
1164
|
+
const uint64_t n_embd = model.hparams.n_embd;
|
|
1157
1165
|
|
|
1158
1166
|
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
1159
1167
|
|
|
1160
1168
|
// TODO: is there something more efficient which also minimizes swaps?
|
|
1161
1169
|
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
|
1162
|
-
for (
|
|
1163
|
-
|
|
1164
|
-
for (
|
|
1170
|
+
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
|
|
1171
|
+
uint32_t j_min = i;
|
|
1172
|
+
for (uint32_t j = i + 1; j < n_outputs; ++j) {
|
|
1165
1173
|
if (out_ids[j] < out_ids[j_min]) {
|
|
1166
1174
|
j_min = j;
|
|
1167
1175
|
}
|
|
1168
1176
|
}
|
|
1169
|
-
if (j_min == i) {
|
|
1177
|
+
if (j_min == i) {
|
|
1178
|
+
continue;
|
|
1179
|
+
}
|
|
1170
1180
|
std::swap(out_ids[i], out_ids[j_min]);
|
|
1171
1181
|
if (logits_size > 0) {
|
|
1172
1182
|
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
@@ -1179,8 +1189,10 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1179
1189
|
}
|
|
1180
1190
|
}
|
|
1181
1191
|
}
|
|
1192
|
+
|
|
1182
1193
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1183
|
-
|
|
1194
|
+
|
|
1195
|
+
for (uint32_t i = 0; i < n_outputs; ++i) {
|
|
1184
1196
|
output_ids[out_ids[i]] = i;
|
|
1185
1197
|
}
|
|
1186
1198
|
}
|
|
@@ -1189,11 +1201,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1189
1201
|
// wait for the computation to finish (automatically done when obtaining the model output)
|
|
1190
1202
|
//synchronize();
|
|
1191
1203
|
|
|
1192
|
-
// decide if we need to defrag the kv cache
|
|
1193
|
-
if (cparams.defrag_thold > 0.0f) {
|
|
1194
|
-
kv_self->defrag_sched(cparams.defrag_thold);
|
|
1195
|
-
}
|
|
1196
|
-
|
|
1197
1204
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
|
1198
1205
|
// overlap with device computation.
|
|
1199
1206
|
ggml_backend_sched_reset(sched.get());
|
|
@@ -1205,7 +1212,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1205
1212
|
// output
|
|
1206
1213
|
//
|
|
1207
1214
|
|
|
1208
|
-
|
|
1215
|
+
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
1209
1216
|
const auto & hparams = model.hparams;
|
|
1210
1217
|
const auto & vocab = model.vocab;
|
|
1211
1218
|
|
|
@@ -1215,9 +1222,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1215
1222
|
const auto n_vocab = vocab.n_tokens();
|
|
1216
1223
|
const auto n_embd = hparams.n_embd;
|
|
1217
1224
|
|
|
1218
|
-
|
|
1219
|
-
bool
|
|
1220
|
-
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
|
1225
|
+
bool has_logits = true;
|
|
1226
|
+
bool has_embd = cparams.embeddings;
|
|
1221
1227
|
|
|
1222
1228
|
// TODO: hacky enc-dec support
|
|
1223
1229
|
if (model.arch == LLM_ARCH_T5) {
|
|
@@ -1271,8 +1277,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1271
1277
|
// set all ids as invalid (negative)
|
|
1272
1278
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1273
1279
|
|
|
1274
|
-
this->n_outputs
|
|
1275
|
-
this->n_outputs_max = n_outputs_max;
|
|
1280
|
+
this->n_outputs = 0;
|
|
1276
1281
|
|
|
1277
1282
|
return n_outputs_max;
|
|
1278
1283
|
}
|
|
@@ -1301,7 +1306,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
|
1301
1306
|
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
1302
1307
|
|
|
1303
1308
|
if (n_tokens % n_seqs != 0) {
|
|
1304
|
-
n_tokens = (n_tokens / n_seqs) * n_seqs;
|
|
1309
|
+
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
|
|
1305
1310
|
n_outputs = std::min(n_outputs, n_tokens);
|
|
1306
1311
|
|
|
1307
1312
|
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
@@ -1763,14 +1768,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
|
1763
1768
|
|
|
1764
1769
|
std::vector<int32_t> w_output_pos;
|
|
1765
1770
|
|
|
1766
|
-
GGML_ASSERT(n_outputs <= n_outputs_max);
|
|
1767
|
-
|
|
1768
1771
|
w_output_pos.resize(n_outputs);
|
|
1769
1772
|
|
|
1770
1773
|
// build a more compact representation of the output ids
|
|
1771
1774
|
for (size_t i = 0; i < n_batch(); ++i) {
|
|
1772
1775
|
// map an output id to a position in the batch
|
|
1773
|
-
|
|
1776
|
+
int64_t pos = output_ids[i];
|
|
1774
1777
|
if (pos >= 0) {
|
|
1775
1778
|
GGML_ASSERT(pos < n_outputs);
|
|
1776
1779
|
w_output_pos[pos] = i;
|
|
@@ -1810,11 +1813,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
|
1810
1813
|
}
|
|
1811
1814
|
}
|
|
1812
1815
|
|
|
1813
|
-
|
|
1814
|
-
|
|
1815
|
-
if (kv_self != nullptr) {
|
|
1816
|
+
if (memory != nullptr) {
|
|
1816
1817
|
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
|
1817
|
-
|
|
1818
|
+
memory->state_write(io);
|
|
1818
1819
|
}
|
|
1819
1820
|
|
|
1820
1821
|
return io.n_bytes();
|
|
@@ -1901,9 +1902,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
|
1901
1902
|
if (memory) {
|
|
1902
1903
|
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
|
1903
1904
|
|
|
1904
|
-
|
|
1905
|
-
|
|
1906
|
-
kv_self->state_read(io);
|
|
1905
|
+
memory->state_read(io);
|
|
1907
1906
|
}
|
|
1908
1907
|
|
|
1909
1908
|
return io.n_bytes();
|
|
@@ -1913,9 +1912,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
|
|
|
1913
1912
|
GGML_UNUSED(seq_id);
|
|
1914
1913
|
|
|
1915
1914
|
if (memory) {
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
kv_self->state_write(io, seq_id);
|
|
1915
|
+
memory->state_write(io, seq_id);
|
|
1919
1916
|
}
|
|
1920
1917
|
|
|
1921
1918
|
return io.n_bytes();
|
|
@@ -1925,9 +1922,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
|
|
|
1925
1922
|
GGML_UNUSED(seq_id);
|
|
1926
1923
|
|
|
1927
1924
|
if (memory) {
|
|
1928
|
-
|
|
1929
|
-
|
|
1930
|
-
kv_self->state_read(io, seq_id);
|
|
1925
|
+
memory->state_read(io, seq_id);
|
|
1931
1926
|
}
|
|
1932
1927
|
|
|
1933
1928
|
return io.n_bytes();
|
|
@@ -2032,9 +2027,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2032
2027
|
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
|
2033
2028
|
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
|
2034
2029
|
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
kv_self->clear();
|
|
2030
|
+
memory->clear(true);
|
|
2038
2031
|
|
|
2039
2032
|
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
|
2040
2033
|
batch.n_tokens = n_batch;
|
|
@@ -2050,38 +2043,35 @@ void llama_context::opt_epoch_iter(
|
|
|
2050
2043
|
|
|
2051
2044
|
n_queued_tokens += n_tokens_all;
|
|
2052
2045
|
|
|
2053
|
-
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
|
2054
|
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
|
2055
|
-
|
|
2056
2046
|
embd_seq.clear();
|
|
2057
2047
|
|
|
2058
|
-
|
|
2048
|
+
uint32_t n_outputs_all = n_tokens_all;
|
|
2059
2049
|
|
|
2060
|
-
auto
|
|
2061
|
-
if (!
|
|
2050
|
+
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
|
|
2051
|
+
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
|
2062
2052
|
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
|
2063
2053
|
break;
|
|
2064
2054
|
}
|
|
2065
2055
|
|
|
2066
2056
|
// reserve output buffer
|
|
2067
2057
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
2068
|
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
|
2058
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
|
2069
2059
|
GGML_ABORT("TODO: handle this error");
|
|
2070
2060
|
};
|
|
2071
2061
|
|
|
2072
2062
|
uint32_t pos_batch = 0;
|
|
2073
2063
|
do {
|
|
2074
|
-
const auto & ubatch =
|
|
2064
|
+
const auto & ubatch = mstate->get_ubatch();
|
|
2075
2065
|
|
|
2076
2066
|
n_outputs = ubatch.n_tokens;
|
|
2077
2067
|
|
|
2078
|
-
if (!
|
|
2068
|
+
if (!mstate->apply()) {
|
|
2079
2069
|
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
|
2080
2070
|
break;
|
|
2081
2071
|
}
|
|
2082
2072
|
|
|
2083
2073
|
auto * gf = graph_init();
|
|
2084
|
-
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT,
|
|
2074
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
|
|
2085
2075
|
|
|
2086
2076
|
struct ggml_context * ctx_compute_opt;
|
|
2087
2077
|
{
|
|
@@ -2116,7 +2106,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2116
2106
|
ggml_free(ctx_compute_opt);
|
|
2117
2107
|
|
|
2118
2108
|
pos_batch += ubatch.n_tokens;
|
|
2119
|
-
} while (
|
|
2109
|
+
} while (mstate->next());
|
|
2120
2110
|
}
|
|
2121
2111
|
}
|
|
2122
2112
|
|
|
@@ -2277,13 +2267,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
|
|
2277
2267
|
return &ctx->get_model();
|
|
2278
2268
|
}
|
|
2279
2269
|
|
|
2270
|
+
// deprecated
|
|
2280
2271
|
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
|
2281
|
-
return ctx->
|
|
2272
|
+
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
|
|
2282
2273
|
}
|
|
2283
2274
|
|
|
2284
2275
|
// deprecated
|
|
2285
2276
|
void llama_kv_self_update(llama_context * ctx) {
|
|
2286
|
-
ctx->kv_self_update();
|
|
2277
|
+
ctx->kv_self_update(false);
|
|
2287
2278
|
}
|
|
2288
2279
|
|
|
2289
2280
|
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
|
@@ -2398,13 +2389,118 @@ int32_t llama_apply_adapter_cvec(
|
|
|
2398
2389
|
return res ? 0 : -1;
|
|
2399
2390
|
}
|
|
2400
2391
|
|
|
2392
|
+
//
|
|
2393
|
+
// memory
|
|
2394
|
+
//
|
|
2395
|
+
|
|
2396
|
+
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
|
2397
|
+
return ctx->get_memory();
|
|
2398
|
+
}
|
|
2399
|
+
|
|
2400
|
+
void llama_memory_clear(llama_memory_t mem, bool data) {
|
|
2401
|
+
if (!mem) {
|
|
2402
|
+
return;
|
|
2403
|
+
}
|
|
2404
|
+
|
|
2405
|
+
mem->clear(data);
|
|
2406
|
+
}
|
|
2407
|
+
|
|
2408
|
+
bool llama_memory_seq_rm(
|
|
2409
|
+
llama_memory_t mem,
|
|
2410
|
+
llama_seq_id seq_id,
|
|
2411
|
+
llama_pos p0,
|
|
2412
|
+
llama_pos p1) {
|
|
2413
|
+
if (!mem) {
|
|
2414
|
+
return true;
|
|
2415
|
+
}
|
|
2416
|
+
|
|
2417
|
+
return mem->seq_rm(seq_id, p0, p1);
|
|
2418
|
+
}
|
|
2419
|
+
|
|
2420
|
+
void llama_memory_seq_cp(
|
|
2421
|
+
llama_memory_t mem,
|
|
2422
|
+
llama_seq_id seq_id_src,
|
|
2423
|
+
llama_seq_id seq_id_dst,
|
|
2424
|
+
llama_pos p0,
|
|
2425
|
+
llama_pos p1) {
|
|
2426
|
+
if (!mem) {
|
|
2427
|
+
return;
|
|
2428
|
+
}
|
|
2429
|
+
|
|
2430
|
+
mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
|
2431
|
+
}
|
|
2432
|
+
|
|
2433
|
+
void llama_memory_seq_keep(
|
|
2434
|
+
llama_memory_t mem,
|
|
2435
|
+
llama_seq_id seq_id) {
|
|
2436
|
+
if (!mem) {
|
|
2437
|
+
return;
|
|
2438
|
+
}
|
|
2439
|
+
|
|
2440
|
+
mem->seq_keep(seq_id);
|
|
2441
|
+
}
|
|
2442
|
+
|
|
2443
|
+
void llama_memory_seq_add(
|
|
2444
|
+
llama_memory_t mem,
|
|
2445
|
+
llama_seq_id seq_id,
|
|
2446
|
+
llama_pos p0,
|
|
2447
|
+
llama_pos p1,
|
|
2448
|
+
llama_pos delta) {
|
|
2449
|
+
if (!mem) {
|
|
2450
|
+
return;
|
|
2451
|
+
}
|
|
2452
|
+
|
|
2453
|
+
mem->seq_add(seq_id, p0, p1, delta);
|
|
2454
|
+
}
|
|
2455
|
+
|
|
2456
|
+
void llama_memory_seq_div(
|
|
2457
|
+
llama_memory_t mem,
|
|
2458
|
+
llama_seq_id seq_id,
|
|
2459
|
+
llama_pos p0,
|
|
2460
|
+
llama_pos p1,
|
|
2461
|
+
int d) {
|
|
2462
|
+
if (!mem) {
|
|
2463
|
+
return;
|
|
2464
|
+
}
|
|
2465
|
+
|
|
2466
|
+
mem->seq_div(seq_id, p0, p1, d);
|
|
2467
|
+
}
|
|
2468
|
+
|
|
2469
|
+
llama_pos llama_memory_seq_pos_min(
|
|
2470
|
+
llama_memory_t mem,
|
|
2471
|
+
llama_seq_id seq_id) {
|
|
2472
|
+
if (!mem) {
|
|
2473
|
+
return -1;
|
|
2474
|
+
}
|
|
2475
|
+
|
|
2476
|
+
return mem->seq_pos_min(seq_id);
|
|
2477
|
+
}
|
|
2478
|
+
|
|
2479
|
+
llama_pos llama_memory_seq_pos_max(
|
|
2480
|
+
llama_memory_t mem,
|
|
2481
|
+
llama_seq_id seq_id) {
|
|
2482
|
+
if (!mem) {
|
|
2483
|
+
return -1;
|
|
2484
|
+
}
|
|
2485
|
+
|
|
2486
|
+
return mem->seq_pos_max(seq_id);
|
|
2487
|
+
}
|
|
2488
|
+
|
|
2489
|
+
bool llama_memory_can_shift(llama_memory_t mem) {
|
|
2490
|
+
if (!mem) {
|
|
2491
|
+
return false;
|
|
2492
|
+
}
|
|
2493
|
+
|
|
2494
|
+
return mem->get_can_shift();
|
|
2495
|
+
}
|
|
2496
|
+
|
|
2401
2497
|
//
|
|
2402
2498
|
// kv cache
|
|
2403
2499
|
//
|
|
2404
2500
|
|
|
2405
2501
|
// deprecated
|
|
2406
2502
|
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
|
2407
|
-
const auto * kv = ctx
|
|
2503
|
+
const auto * kv = llama_get_memory(ctx);
|
|
2408
2504
|
if (!kv) {
|
|
2409
2505
|
return 0;
|
|
2410
2506
|
}
|
|
@@ -2426,7 +2522,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
|
|
2426
2522
|
// deprecated
|
|
2427
2523
|
// note: this is the same as above - will be removed anyway, so it's ok
|
|
2428
2524
|
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
|
2429
|
-
const auto * kv = ctx
|
|
2525
|
+
const auto * kv = llama_get_memory(ctx);
|
|
2430
2526
|
if (!kv) {
|
|
2431
2527
|
return 0;
|
|
2432
2528
|
}
|
|
@@ -2445,115 +2541,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
|
|
2445
2541
|
return res;
|
|
2446
2542
|
}
|
|
2447
2543
|
|
|
2544
|
+
// deprecated
|
|
2448
2545
|
void llama_kv_self_clear(llama_context * ctx) {
|
|
2449
|
-
auto * kv = ctx
|
|
2546
|
+
auto * kv = llama_get_memory(ctx);
|
|
2450
2547
|
if (!kv) {
|
|
2451
2548
|
return;
|
|
2452
2549
|
}
|
|
2453
2550
|
|
|
2454
|
-
kv
|
|
2551
|
+
llama_memory_clear(kv, true);
|
|
2455
2552
|
}
|
|
2456
2553
|
|
|
2554
|
+
// deprecated
|
|
2457
2555
|
bool llama_kv_self_seq_rm(
|
|
2458
2556
|
llama_context * ctx,
|
|
2459
2557
|
llama_seq_id seq_id,
|
|
2460
2558
|
llama_pos p0,
|
|
2461
2559
|
llama_pos p1) {
|
|
2462
|
-
auto * kv = ctx
|
|
2560
|
+
auto * kv = llama_get_memory(ctx);
|
|
2463
2561
|
if (!kv) {
|
|
2464
2562
|
return true;
|
|
2465
2563
|
}
|
|
2466
2564
|
|
|
2467
|
-
return kv
|
|
2565
|
+
return llama_memory_seq_rm(kv, seq_id, p0, p1);
|
|
2468
2566
|
}
|
|
2469
2567
|
|
|
2568
|
+
// deprecated
|
|
2470
2569
|
void llama_kv_self_seq_cp(
|
|
2471
2570
|
llama_context * ctx,
|
|
2472
2571
|
llama_seq_id seq_id_src,
|
|
2473
2572
|
llama_seq_id seq_id_dst,
|
|
2474
2573
|
llama_pos p0,
|
|
2475
2574
|
llama_pos p1) {
|
|
2476
|
-
auto * kv = ctx
|
|
2575
|
+
auto * kv = llama_get_memory(ctx);
|
|
2477
2576
|
if (!kv) {
|
|
2478
2577
|
return;
|
|
2479
2578
|
}
|
|
2480
2579
|
|
|
2481
|
-
kv
|
|
2580
|
+
llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
|
|
2482
2581
|
}
|
|
2483
2582
|
|
|
2583
|
+
// deprecated
|
|
2484
2584
|
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
|
2485
|
-
auto * kv = ctx
|
|
2585
|
+
auto * kv = llama_get_memory(ctx);
|
|
2486
2586
|
if (!kv) {
|
|
2487
2587
|
return;
|
|
2488
2588
|
}
|
|
2489
2589
|
|
|
2490
|
-
kv
|
|
2590
|
+
llama_memory_seq_keep(kv, seq_id);
|
|
2491
2591
|
}
|
|
2492
2592
|
|
|
2593
|
+
// deprecated
|
|
2493
2594
|
void llama_kv_self_seq_add(
|
|
2494
2595
|
llama_context * ctx,
|
|
2495
2596
|
llama_seq_id seq_id,
|
|
2496
2597
|
llama_pos p0,
|
|
2497
2598
|
llama_pos p1,
|
|
2498
2599
|
llama_pos delta) {
|
|
2499
|
-
auto * kv = ctx
|
|
2600
|
+
auto * kv = llama_get_memory(ctx);
|
|
2500
2601
|
if (!kv) {
|
|
2501
2602
|
return;
|
|
2502
2603
|
}
|
|
2503
2604
|
|
|
2504
|
-
kv
|
|
2605
|
+
llama_memory_seq_add(kv, seq_id, p0, p1, delta);
|
|
2505
2606
|
}
|
|
2506
2607
|
|
|
2608
|
+
// deprecated
|
|
2507
2609
|
void llama_kv_self_seq_div(
|
|
2508
2610
|
llama_context * ctx,
|
|
2509
2611
|
llama_seq_id seq_id,
|
|
2510
2612
|
llama_pos p0,
|
|
2511
2613
|
llama_pos p1,
|
|
2512
2614
|
int d) {
|
|
2513
|
-
auto * kv = ctx
|
|
2615
|
+
auto * kv = llama_get_memory(ctx);
|
|
2514
2616
|
if (!kv) {
|
|
2515
2617
|
return;
|
|
2516
2618
|
}
|
|
2517
2619
|
|
|
2518
|
-
kv
|
|
2620
|
+
llama_memory_seq_div(kv, seq_id, p0, p1, d);
|
|
2519
2621
|
}
|
|
2520
2622
|
|
|
2623
|
+
// deprecated
|
|
2521
2624
|
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
|
|
2522
|
-
|
|
2625
|
+
auto * kv = llama_get_memory(ctx);
|
|
2523
2626
|
if (!kv) {
|
|
2524
2627
|
return -1;
|
|
2525
2628
|
}
|
|
2526
2629
|
|
|
2527
|
-
return kv
|
|
2630
|
+
return llama_memory_seq_pos_min(kv, seq_id);
|
|
2528
2631
|
}
|
|
2529
2632
|
|
|
2633
|
+
// deprecated
|
|
2530
2634
|
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
|
2531
|
-
|
|
2635
|
+
auto * kv = llama_get_memory(ctx);
|
|
2532
2636
|
if (!kv) {
|
|
2533
2637
|
return -1;
|
|
2534
2638
|
}
|
|
2535
2639
|
|
|
2536
|
-
return kv
|
|
2640
|
+
return llama_memory_seq_pos_max(kv, seq_id);
|
|
2537
2641
|
}
|
|
2538
2642
|
|
|
2539
2643
|
// deprecated
|
|
2540
2644
|
void llama_kv_self_defrag(llama_context * ctx) {
|
|
2541
|
-
auto * kv = ctx->get_kv_self();
|
|
2542
|
-
if (!kv) {
|
|
2543
|
-
return;
|
|
2544
|
-
}
|
|
2545
|
-
|
|
2546
2645
|
// force defrag
|
|
2547
|
-
|
|
2646
|
+
ctx->kv_self_defrag_sched();
|
|
2548
2647
|
}
|
|
2549
2648
|
|
|
2649
|
+
// deprecated
|
|
2550
2650
|
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
|
2551
|
-
|
|
2651
|
+
auto * kv = llama_get_memory(ctx);
|
|
2552
2652
|
if (!kv) {
|
|
2553
2653
|
return false;
|
|
2554
2654
|
}
|
|
2555
2655
|
|
|
2556
|
-
return kv
|
|
2656
|
+
return llama_memory_can_shift(kv);
|
|
2557
2657
|
}
|
|
2558
2658
|
|
|
2559
2659
|
// llama state API
|