@novastera-oss/llamarn 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +134 -36
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -2
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +30 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +50 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +134 -36
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
- package/cpp/llama.cpp/src/llama-batch.h +36 -11
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +313 -213
- package/cpp/llama.cpp/src/llama-context.h +16 -12
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
- package/cpp/llama.cpp/src/llama-graph.h +90 -34
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
- package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +64 -23
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +726 -141
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/llama.h +134 -36
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
#include "llama-
|
|
1
|
+
#include "llama-memory-recurrent.h"
|
|
2
2
|
|
|
3
3
|
#include "llama-impl.h"
|
|
4
|
+
#include "llama-io.h"
|
|
4
5
|
#include "llama-batch.h"
|
|
5
6
|
#include "llama-model.h"
|
|
6
7
|
|
|
@@ -11,27 +12,28 @@
|
|
|
11
12
|
#include <stdexcept>
|
|
12
13
|
|
|
13
14
|
//
|
|
14
|
-
//
|
|
15
|
+
// llama_memory_recurrent
|
|
15
16
|
//
|
|
16
17
|
|
|
17
|
-
|
|
18
|
-
const llama_model &
|
|
19
|
-
|
|
20
|
-
ggml_type
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
uint32_t
|
|
18
|
+
llama_memory_recurrent::llama_memory_recurrent(
|
|
19
|
+
const llama_model & model,
|
|
20
|
+
layer_filter_cb && filter,
|
|
21
|
+
ggml_type type_r,
|
|
22
|
+
ggml_type type_s,
|
|
23
|
+
bool offload,
|
|
24
|
+
uint32_t mem_size,
|
|
25
|
+
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
|
24
26
|
const int32_t n_layer = hparams.n_layer;
|
|
25
27
|
|
|
26
|
-
LLAMA_LOG_INFO("%s:
|
|
27
|
-
__func__,
|
|
28
|
+
LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
|
|
29
|
+
__func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
|
|
28
30
|
|
|
29
31
|
head = 0;
|
|
30
|
-
size =
|
|
32
|
+
size = mem_size;
|
|
31
33
|
used = 0;
|
|
32
34
|
|
|
33
35
|
cells.clear();
|
|
34
|
-
cells.resize(
|
|
36
|
+
cells.resize(mem_size);
|
|
35
37
|
|
|
36
38
|
// create a context for each buffer type
|
|
37
39
|
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
@@ -58,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
|
58
60
|
return it->second;
|
|
59
61
|
};
|
|
60
62
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
+
r_l.resize(n_layer);
|
|
64
|
+
s_l.resize(n_layer);
|
|
63
65
|
|
|
64
66
|
for (int i = 0; i < n_layer; i++) {
|
|
65
|
-
|
|
66
|
-
|
|
67
|
+
if (filter && !filter(i)) {
|
|
68
|
+
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
|
|
69
|
+
continue;
|
|
70
|
+
}
|
|
67
71
|
|
|
68
72
|
const char * dev_name = "CPU";
|
|
69
73
|
|
|
@@ -83,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
|
83
87
|
throw std::runtime_error("failed to create ggml context for kv cache");
|
|
84
88
|
}
|
|
85
89
|
|
|
86
|
-
ggml_tensor *
|
|
87
|
-
ggml_tensor *
|
|
88
|
-
ggml_format_name(
|
|
89
|
-
ggml_format_name(
|
|
90
|
-
|
|
91
|
-
|
|
90
|
+
ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
|
|
91
|
+
ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
|
|
92
|
+
ggml_format_name(r, "cache_r_l%d", i);
|
|
93
|
+
ggml_format_name(s, "cache_s_l%d", i);
|
|
94
|
+
r_l[i] = r;
|
|
95
|
+
s_l[i] = s;
|
|
92
96
|
}
|
|
93
97
|
|
|
94
98
|
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
@@ -106,32 +110,35 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
|
106
110
|
}
|
|
107
111
|
|
|
108
112
|
{
|
|
109
|
-
const size_t
|
|
110
|
-
const size_t
|
|
113
|
+
const size_t memory_size_r = size_r_bytes();
|
|
114
|
+
const size_t memory_size_s = size_s_bytes();
|
|
111
115
|
|
|
112
|
-
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB,
|
|
113
|
-
(float)(
|
|
114
|
-
ggml_type_name(
|
|
115
|
-
ggml_type_name(
|
|
116
|
+
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
|
117
|
+
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
|
|
118
|
+
ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
|
|
119
|
+
ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
|
|
116
120
|
}
|
|
117
121
|
}
|
|
118
122
|
|
|
119
|
-
void
|
|
123
|
+
void llama_memory_recurrent::clear(bool data) {
|
|
120
124
|
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
|
121
125
|
cells[i].pos = -1;
|
|
122
126
|
cells[i].seq_id.clear();
|
|
123
127
|
cells[i].src = -1;
|
|
124
128
|
cells[i].tail = -1;
|
|
125
129
|
}
|
|
130
|
+
|
|
126
131
|
head = 0;
|
|
127
132
|
used = 0;
|
|
128
133
|
|
|
129
|
-
|
|
130
|
-
|
|
134
|
+
if (data) {
|
|
135
|
+
for (auto & buf : bufs) {
|
|
136
|
+
ggml_backend_buffer_clear(buf.get(), 0);
|
|
137
|
+
}
|
|
131
138
|
}
|
|
132
139
|
}
|
|
133
140
|
|
|
134
|
-
bool
|
|
141
|
+
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
135
142
|
uint32_t new_head = size;
|
|
136
143
|
|
|
137
144
|
if (p0 < 0) {
|
|
@@ -150,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
|
|
150
157
|
if (0 <= seq_id) {
|
|
151
158
|
int32_t & tail_id = cells[seq_id].tail;
|
|
152
159
|
if (tail_id >= 0) {
|
|
153
|
-
const
|
|
160
|
+
const auto & cell = cells[tail_id];
|
|
154
161
|
// partial intersection is invalid
|
|
155
162
|
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
|
156
163
|
return false;
|
|
@@ -198,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
|
|
198
205
|
return true;
|
|
199
206
|
}
|
|
200
207
|
|
|
201
|
-
void
|
|
208
|
+
void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
|
202
209
|
if (seq_id_src == seq_id_dst) {
|
|
203
210
|
return;
|
|
204
211
|
}
|
|
@@ -212,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
|
212
219
|
}
|
|
213
220
|
|
|
214
221
|
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
|
215
|
-
|
|
216
|
-
|
|
222
|
+
auto & tail_src = cells[seq_id_src];
|
|
223
|
+
auto & tail_dst = cells[seq_id_dst];
|
|
217
224
|
if (tail_dst.tail >= 0) {
|
|
218
225
|
// clear destination seq_id if it wasn't empty
|
|
219
|
-
|
|
226
|
+
auto & cell_dst = cells[tail_dst.tail];
|
|
220
227
|
|
|
221
228
|
cell_dst.seq_id.erase(seq_id_dst);
|
|
222
229
|
tail_dst.tail = -1;
|
|
@@ -227,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
|
227
234
|
}
|
|
228
235
|
}
|
|
229
236
|
if (tail_src.tail >= 0) {
|
|
230
|
-
|
|
237
|
+
auto & cell_src = cells[tail_src.tail];
|
|
231
238
|
|
|
232
239
|
cell_src.seq_id.insert(seq_id_dst);
|
|
233
240
|
tail_dst.tail = tail_src.tail;
|
|
@@ -235,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
|
235
242
|
}
|
|
236
243
|
}
|
|
237
244
|
|
|
238
|
-
void
|
|
245
|
+
void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
|
|
239
246
|
uint32_t new_head = size;
|
|
240
247
|
|
|
241
248
|
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -267,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
|
|
267
274
|
}
|
|
268
275
|
}
|
|
269
276
|
|
|
270
|
-
void
|
|
277
|
+
void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
|
271
278
|
if (shift == 0) {
|
|
272
279
|
return;
|
|
273
280
|
}
|
|
@@ -289,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|
|
289
296
|
if (0 <= seq_id && seq_id < (int64_t) size) {
|
|
290
297
|
const int32_t tail_id = cells[seq_id].tail;
|
|
291
298
|
if (tail_id >= 0) {
|
|
292
|
-
|
|
299
|
+
auto & cell = cells[tail_id];
|
|
293
300
|
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
|
294
301
|
cell.pos += shift;
|
|
295
302
|
}
|
|
@@ -297,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|
|
297
304
|
}
|
|
298
305
|
}
|
|
299
306
|
|
|
300
|
-
void
|
|
307
|
+
void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
301
308
|
if (d == 1) {
|
|
302
309
|
return;
|
|
303
310
|
}
|
|
@@ -319,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|
|
319
326
|
if (0 <= seq_id && seq_id < (int64_t) size) {
|
|
320
327
|
const int32_t tail_id = cells[seq_id].tail;
|
|
321
328
|
if (tail_id >= 0) {
|
|
322
|
-
|
|
329
|
+
auto & cell = cells[tail_id];
|
|
323
330
|
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
|
324
331
|
cell.pos /= d;
|
|
325
332
|
}
|
|
@@ -327,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|
|
327
334
|
}
|
|
328
335
|
}
|
|
329
336
|
|
|
330
|
-
llama_pos
|
|
337
|
+
llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
|
331
338
|
llama_pos result = std::numeric_limits<llama_pos>::max();
|
|
332
339
|
|
|
333
340
|
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -343,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
|
|
343
350
|
return result;
|
|
344
351
|
}
|
|
345
352
|
|
|
346
|
-
llama_pos
|
|
353
|
+
llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|
347
354
|
llama_pos result = -1;
|
|
348
355
|
|
|
349
356
|
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -355,18 +362,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
355
362
|
return result;
|
|
356
363
|
}
|
|
357
364
|
|
|
358
|
-
llama_memory_state_ptr
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
|
|
365
|
+
llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
|
|
366
|
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
|
362
367
|
|
|
363
368
|
std::vector<llama_ubatch> ubatches;
|
|
364
369
|
|
|
365
370
|
while (sbatch.n_tokens > 0) {
|
|
366
371
|
llama_ubatch ubatch;
|
|
367
372
|
|
|
368
|
-
if (
|
|
369
|
-
//
|
|
373
|
+
if (embd_all) {
|
|
374
|
+
// if all tokens are output, split by sequence
|
|
370
375
|
ubatch = sbatch.split_seq(n_ubatch);
|
|
371
376
|
} else {
|
|
372
377
|
ubatch = sbatch.split_equal(n_ubatch);
|
|
@@ -376,17 +381,24 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
|
|
|
376
381
|
}
|
|
377
382
|
|
|
378
383
|
if (!prepare(ubatches)) {
|
|
379
|
-
return std::make_unique<
|
|
384
|
+
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
380
385
|
}
|
|
381
386
|
|
|
382
|
-
return std::make_unique<
|
|
387
|
+
return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
|
|
383
388
|
}
|
|
384
389
|
|
|
385
|
-
llama_memory_state_ptr
|
|
386
|
-
return std::make_unique<
|
|
390
|
+
llama_memory_state_ptr llama_memory_recurrent::init_full() {
|
|
391
|
+
return std::make_unique<llama_memory_recurrent_state>(this);
|
|
387
392
|
}
|
|
388
393
|
|
|
389
|
-
|
|
394
|
+
llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
|
|
395
|
+
GGML_UNUSED(lctx);
|
|
396
|
+
GGML_UNUSED(optimize);
|
|
397
|
+
|
|
398
|
+
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
|
390
402
|
// simply remember the full state because it is very small for this type of cache
|
|
391
403
|
// TODO: optimize
|
|
392
404
|
auto org_cells = cells;
|
|
@@ -395,21 +407,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|
|
395
407
|
|
|
396
408
|
bool success = true;
|
|
397
409
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
// recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
|
|
405
|
-
//
|
|
406
|
-
GGML_UNUSED(ubatches);
|
|
407
|
-
//for (const auto & ubatch : ubatches) {
|
|
408
|
-
// if (!find_slot(ubatch)) {
|
|
409
|
-
// success = false;
|
|
410
|
-
// break;
|
|
411
|
-
// }
|
|
412
|
-
//}
|
|
410
|
+
for (const auto & ubatch : ubatches) {
|
|
411
|
+
if (!find_slot(ubatch)) {
|
|
412
|
+
success = false;
|
|
413
|
+
break;
|
|
414
|
+
}
|
|
415
|
+
}
|
|
413
416
|
|
|
414
417
|
// restore the original state
|
|
415
418
|
cells = std::move(org_cells);
|
|
@@ -419,26 +422,14 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|
|
419
422
|
return success;
|
|
420
423
|
}
|
|
421
424
|
|
|
422
|
-
bool
|
|
423
|
-
|
|
424
|
-
// noop
|
|
425
|
-
return false;
|
|
426
|
-
}
|
|
427
|
-
|
|
428
|
-
void llama_kv_cache_recurrent::defrag_sched(float thold) {
|
|
429
|
-
GGML_UNUSED(thold);
|
|
430
|
-
// noop
|
|
431
|
-
}
|
|
432
|
-
|
|
433
|
-
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
434
|
-
const uint32_t n_tokens = ubatch.n_tokens;
|
|
435
|
-
const uint32_t n_seqs = ubatch.n_seqs;
|
|
425
|
+
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
426
|
+
const uint32_t n_seqs = ubatch.n_seqs;
|
|
436
427
|
|
|
437
428
|
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
438
429
|
|
|
439
430
|
// if we have enough unused cells before the current head ->
|
|
440
431
|
// better to start searching from the beginning of the cache, hoping to fill it
|
|
441
|
-
if (head > used + 2*
|
|
432
|
+
if (head > used + 2*n_seqs) {
|
|
442
433
|
head = 0;
|
|
443
434
|
}
|
|
444
435
|
|
|
@@ -465,9 +456,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
465
456
|
return false;
|
|
466
457
|
}
|
|
467
458
|
if (j > 0) {
|
|
468
|
-
|
|
459
|
+
auto & seq = cells[seq_id];
|
|
469
460
|
if (seq.tail >= 0) {
|
|
470
|
-
|
|
461
|
+
auto & cell = cells[seq.tail];
|
|
471
462
|
// clear cells from seq_ids that become shared
|
|
472
463
|
// (should not normally happen, but let's handle it anyway)
|
|
473
464
|
cell.seq_id.erase(seq_id);
|
|
@@ -487,7 +478,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
487
478
|
std::vector<int32_t> tails_verif;
|
|
488
479
|
tails_verif.assign(size, -1);
|
|
489
480
|
for (uint32_t i = 0; i < size; ++i) {
|
|
490
|
-
|
|
481
|
+
auto & cell = cells[i];
|
|
491
482
|
for (llama_seq_id seq_id : cell.seq_id) {
|
|
492
483
|
if (tails_verif[seq_id] != -1) {
|
|
493
484
|
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
|
@@ -508,7 +499,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
508
499
|
|
|
509
500
|
for (uint32_t i = 0; i < size; ++i) {
|
|
510
501
|
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
|
511
|
-
|
|
502
|
+
auto & cell = cells[next_empty_cell];
|
|
512
503
|
if (cell.is_empty()) { break; }
|
|
513
504
|
next_empty_cell += 1;
|
|
514
505
|
}
|
|
@@ -516,34 +507,34 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
516
507
|
// find usable cell range
|
|
517
508
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
518
509
|
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
519
|
-
|
|
510
|
+
auto & seq_meta = cells[seq_id];
|
|
520
511
|
bool has_cell = false;
|
|
521
512
|
if (seq_meta.tail >= 0) {
|
|
522
|
-
|
|
513
|
+
auto & cell = cells[seq_meta.tail];
|
|
523
514
|
GGML_ASSERT(cell.has_seq_id(seq_id));
|
|
524
515
|
// does this seq_id "own" the cell?
|
|
525
516
|
if (cell.seq_id.size() == 1) { has_cell = true; }
|
|
526
517
|
}
|
|
527
518
|
if (!has_cell) {
|
|
528
|
-
|
|
519
|
+
auto & empty_cell = cells[next_empty_cell];
|
|
529
520
|
GGML_ASSERT(empty_cell.is_empty());
|
|
530
521
|
// copy old tail into the empty cell
|
|
531
522
|
if (seq_meta.tail >= 0) {
|
|
532
|
-
|
|
523
|
+
auto & orig_cell = cells[seq_meta.tail];
|
|
533
524
|
empty_cell.pos = orig_cell.pos;
|
|
534
525
|
empty_cell.src = orig_cell.src;
|
|
535
526
|
orig_cell.seq_id.erase(seq_id);
|
|
536
527
|
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
|
528
|
+
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
|
|
537
529
|
}
|
|
538
530
|
seq_meta.tail = next_empty_cell;
|
|
539
531
|
// find next empty cell
|
|
540
532
|
if (s + 1 < n_seqs) {
|
|
541
|
-
next_empty_cell += 1;
|
|
542
533
|
for (uint32_t i = 0; i < size; ++i) {
|
|
534
|
+
next_empty_cell += 1;
|
|
543
535
|
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
|
544
|
-
|
|
536
|
+
auto & cell = cells[next_empty_cell];
|
|
545
537
|
if (cell.is_empty()) { break; }
|
|
546
|
-
next_empty_cell += 1;
|
|
547
538
|
}
|
|
548
539
|
}
|
|
549
540
|
}
|
|
@@ -553,22 +544,24 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
553
544
|
|
|
554
545
|
// gather and re-order
|
|
555
546
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
556
|
-
int32_t dst_id = s + min;
|
|
557
|
-
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
|
547
|
+
const int32_t dst_id = s + min;
|
|
548
|
+
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
|
558
549
|
if (dst_id != src_id) {
|
|
559
|
-
|
|
560
|
-
|
|
550
|
+
auto & dst_cell = cells[dst_id];
|
|
551
|
+
auto & src_cell = cells[src_id];
|
|
561
552
|
|
|
562
553
|
std::swap(dst_cell.pos, src_cell.pos);
|
|
563
554
|
std::swap(dst_cell.src, src_cell.src);
|
|
564
555
|
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
|
565
556
|
|
|
566
|
-
// swap tails
|
|
567
|
-
for (
|
|
568
|
-
cells[
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
557
|
+
// swap tails
|
|
558
|
+
for (uint32_t i = 0; i < size; ++i) {
|
|
559
|
+
int32_t & tail = cells[i].tail;
|
|
560
|
+
if (tail == src_id) {
|
|
561
|
+
tail = dst_id;
|
|
562
|
+
} else if (tail == dst_id) {
|
|
563
|
+
tail = src_id;
|
|
564
|
+
}
|
|
572
565
|
}
|
|
573
566
|
}
|
|
574
567
|
}
|
|
@@ -576,8 +569,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
576
569
|
// update the pos of the used seqs
|
|
577
570
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
578
571
|
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
|
579
|
-
int32_t cell_id = s + min;
|
|
580
|
-
|
|
572
|
+
const int32_t cell_id = s + min;
|
|
573
|
+
auto & cell = cells[cell_id];
|
|
581
574
|
|
|
582
575
|
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
|
583
576
|
// What should happen when the pos backtracks or skips a value?
|
|
@@ -594,61 +587,54 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
594
587
|
}
|
|
595
588
|
}
|
|
596
589
|
|
|
590
|
+
// Find first cell without src refs, to use as the zero-ed state
|
|
591
|
+
{
|
|
592
|
+
// TODO: bake-in src refcounts in the cell metadata
|
|
593
|
+
std::vector<int32_t> refcounts(size, 0);
|
|
594
|
+
for (size_t i = 0; i < size; ++i) {
|
|
595
|
+
const int32_t src = cells[i].src;
|
|
596
|
+
if (src >= 0) {
|
|
597
|
+
refcounts[src] += 1;
|
|
598
|
+
}
|
|
599
|
+
}
|
|
600
|
+
|
|
601
|
+
rs_z = -1;
|
|
602
|
+
for (int i = min; i <= max; ++i) {
|
|
603
|
+
if (refcounts[i] == 0) {
|
|
604
|
+
rs_z = i;
|
|
605
|
+
break;
|
|
606
|
+
}
|
|
607
|
+
}
|
|
608
|
+
|
|
609
|
+
for (int i = min; i <= max; ++i) {
|
|
610
|
+
if (cells[i].src < 0) {
|
|
611
|
+
GGML_ASSERT(rs_z >= 0);
|
|
612
|
+
cells[i].src0 = rs_z;
|
|
613
|
+
} else {
|
|
614
|
+
// Stage the source ids for all used cells to allow correct seq_* behavior
|
|
615
|
+
// and still make these values available when setting the inputs
|
|
616
|
+
cells[i].src0 = cells[i].src;
|
|
617
|
+
}
|
|
618
|
+
cells[i].src = i; // avoid moving or clearing twice
|
|
619
|
+
}
|
|
620
|
+
}
|
|
621
|
+
|
|
597
622
|
// allow getting the range of used cells, from head to head + n
|
|
598
623
|
head = min;
|
|
599
624
|
n = max - min + 1;
|
|
600
625
|
used = std::count_if(cells.begin(), cells.end(),
|
|
601
|
-
[](const
|
|
626
|
+
[](const mem_cell & cell){ return !cell.is_empty(); });
|
|
602
627
|
|
|
603
628
|
// sanity check
|
|
604
629
|
return n >= n_seqs;
|
|
605
630
|
}
|
|
606
631
|
|
|
607
|
-
bool
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
int32_t llama_kv_cache_recurrent::s_copy(int i) const {
|
|
612
|
-
const uint32_t cell_id = i + head;
|
|
613
|
-
|
|
614
|
-
//////////////////////////////////////////////
|
|
615
|
-
// TODO: this should not mutate the KV cache !
|
|
616
|
-
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
|
617
|
-
|
|
618
|
-
// prevent out-of-bound sources
|
|
619
|
-
if (cell.src < 0 || (uint32_t) cell.src >= size) {
|
|
620
|
-
cell.src = cell_id;
|
|
621
|
-
}
|
|
622
|
-
|
|
623
|
-
int32_t res = cell.src;
|
|
624
|
-
|
|
625
|
-
// TODO: do not mutate the KV cache
|
|
626
|
-
// ensure copy only happens once
|
|
627
|
-
if (cell.src != (int32_t) cell_id) {
|
|
628
|
-
cell.src = cell_id;
|
|
629
|
-
}
|
|
630
|
-
|
|
631
|
-
return res;
|
|
632
|
-
}
|
|
633
|
-
|
|
634
|
-
float llama_kv_cache_recurrent::s_mask(int i) const {
|
|
635
|
-
const uint32_t cell_id = i + head;
|
|
636
|
-
|
|
637
|
-
//////////////////////////////////////////////
|
|
638
|
-
// TODO: this should not mutate the KV cache !
|
|
639
|
-
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
|
640
|
-
|
|
641
|
-
float res = (float) (cell.src >= 0);
|
|
642
|
-
|
|
643
|
-
// only clear once
|
|
644
|
-
if (cell.src < 0) {
|
|
645
|
-
cell.src = cell_id;
|
|
646
|
-
}
|
|
647
|
-
|
|
648
|
-
return res;
|
|
632
|
+
bool llama_memory_recurrent::get_can_shift() const {
|
|
633
|
+
// shifting the pos is trivial for recurrent models
|
|
634
|
+
return true;
|
|
649
635
|
}
|
|
650
636
|
|
|
651
|
-
size_t
|
|
637
|
+
size_t llama_memory_recurrent::total_size() const {
|
|
652
638
|
size_t size = 0;
|
|
653
639
|
for (const auto & buf : bufs) {
|
|
654
640
|
size += ggml_backend_buffer_get_size(buf.get());
|
|
@@ -657,27 +643,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
|
|
|
657
643
|
return size;
|
|
658
644
|
}
|
|
659
645
|
|
|
660
|
-
size_t
|
|
661
|
-
size_t
|
|
646
|
+
size_t llama_memory_recurrent::size_r_bytes() const {
|
|
647
|
+
size_t size_r_bytes = 0;
|
|
662
648
|
|
|
663
|
-
for (const auto &
|
|
664
|
-
|
|
649
|
+
for (const auto & r : r_l) {
|
|
650
|
+
if (r != nullptr) {
|
|
651
|
+
size_r_bytes += ggml_nbytes(r);
|
|
652
|
+
}
|
|
665
653
|
}
|
|
666
654
|
|
|
667
|
-
return
|
|
655
|
+
return size_r_bytes;
|
|
668
656
|
}
|
|
669
657
|
|
|
670
|
-
size_t
|
|
671
|
-
size_t
|
|
658
|
+
size_t llama_memory_recurrent::size_s_bytes() const {
|
|
659
|
+
size_t size_s_bytes = 0;
|
|
672
660
|
|
|
673
|
-
for (const auto &
|
|
674
|
-
|
|
661
|
+
for (const auto & s : s_l) {
|
|
662
|
+
if (s != nullptr) {
|
|
663
|
+
size_s_bytes += ggml_nbytes(s);
|
|
664
|
+
}
|
|
675
665
|
}
|
|
676
666
|
|
|
677
|
-
return
|
|
667
|
+
return size_s_bytes;
|
|
678
668
|
}
|
|
679
669
|
|
|
680
|
-
void
|
|
670
|
+
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
|
681
671
|
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
|
682
672
|
uint32_t cell_count = 0;
|
|
683
673
|
|
|
@@ -715,7 +705,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
|
|
|
715
705
|
state_write_data(io, cell_ranges);
|
|
716
706
|
}
|
|
717
707
|
|
|
718
|
-
void
|
|
708
|
+
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
|
719
709
|
uint32_t cell_count;
|
|
720
710
|
io.read_to(&cell_count, sizeof(cell_count));
|
|
721
711
|
|
|
@@ -726,7 +716,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
|
|
726
716
|
|
|
727
717
|
if (!res) {
|
|
728
718
|
if (seq_id == -1) {
|
|
729
|
-
clear();
|
|
719
|
+
clear(true);
|
|
730
720
|
} else {
|
|
731
721
|
seq_rm(seq_id, -1, -1);
|
|
732
722
|
}
|
|
@@ -734,7 +724,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
|
|
734
724
|
}
|
|
735
725
|
}
|
|
736
726
|
|
|
737
|
-
void
|
|
727
|
+
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
|
738
728
|
for (const auto & range : cell_ranges) {
|
|
739
729
|
for (uint32_t i = range.first; i < range.second; ++i) {
|
|
740
730
|
const auto & cell = cells[i];
|
|
@@ -753,87 +743,85 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
|
|
|
753
743
|
}
|
|
754
744
|
}
|
|
755
745
|
|
|
756
|
-
void
|
|
757
|
-
const uint32_t
|
|
746
|
+
void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
|
747
|
+
const uint32_t s_trans = 0;
|
|
758
748
|
const uint32_t n_layer = hparams.n_layer;
|
|
759
749
|
|
|
760
|
-
io.write(&
|
|
761
|
-
io.write(&n_layer,
|
|
750
|
+
io.write(&s_trans, sizeof(s_trans));
|
|
751
|
+
io.write(&n_layer, sizeof(n_layer));
|
|
762
752
|
|
|
763
753
|
std::vector<uint8_t> tmp_buf;
|
|
764
754
|
|
|
765
755
|
// Iterate and write all the keys first, each row is a cell
|
|
766
756
|
// Get whole range at a time
|
|
767
757
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
768
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
|
769
758
|
|
|
770
759
|
// Write key type
|
|
771
|
-
const int32_t
|
|
772
|
-
io.write(&
|
|
760
|
+
const int32_t r_type_i = (int32_t)r_l[il]->type;
|
|
761
|
+
io.write(&r_type_i, sizeof(r_type_i));
|
|
773
762
|
|
|
774
763
|
// Write row size of key
|
|
775
|
-
const uint64_t
|
|
776
|
-
io.write(&
|
|
764
|
+
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
|
765
|
+
io.write(&r_size_row, sizeof(r_size_row));
|
|
777
766
|
|
|
778
767
|
// Read each range of cells of k_size length each into tmp_buf and write out
|
|
779
768
|
for (const auto & range : cell_ranges) {
|
|
780
769
|
const size_t range_size = range.second - range.first;
|
|
781
|
-
const size_t buf_size = range_size *
|
|
782
|
-
io.write_tensor(
|
|
770
|
+
const size_t buf_size = range_size * r_size_row;
|
|
771
|
+
io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
|
|
783
772
|
}
|
|
784
773
|
}
|
|
785
774
|
|
|
786
|
-
if (!
|
|
775
|
+
if (!s_trans) {
|
|
787
776
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
788
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
789
777
|
|
|
790
778
|
// Write value type
|
|
791
|
-
const int32_t
|
|
792
|
-
io.write(&
|
|
779
|
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
|
780
|
+
io.write(&s_type_i, sizeof(s_type_i));
|
|
793
781
|
|
|
794
782
|
// Write row size of value
|
|
795
|
-
const uint64_t
|
|
796
|
-
io.write(&
|
|
783
|
+
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
|
784
|
+
io.write(&s_size_row, sizeof(s_size_row));
|
|
797
785
|
|
|
798
|
-
// Read each range of cells of
|
|
786
|
+
// Read each range of cells of s_size length each into tmp_buf and write out
|
|
799
787
|
for (const auto & range : cell_ranges) {
|
|
800
788
|
const size_t range_size = range.second - range.first;
|
|
801
|
-
const size_t buf_size = range_size *
|
|
802
|
-
io.write_tensor(
|
|
789
|
+
const size_t buf_size = range_size * s_size_row;
|
|
790
|
+
io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
|
|
803
791
|
}
|
|
804
792
|
}
|
|
805
793
|
} else {
|
|
806
794
|
// When v is transposed, we also need the element size and get the element ranges from each row
|
|
807
|
-
const uint32_t
|
|
795
|
+
const uint32_t mem_size = size;
|
|
808
796
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
809
|
-
const uint32_t
|
|
797
|
+
const uint32_t n_embd_s = hparams.n_embd_s();
|
|
810
798
|
|
|
811
799
|
// Write value type
|
|
812
|
-
const int32_t
|
|
813
|
-
io.write(&
|
|
800
|
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
|
801
|
+
io.write(&s_type_i, sizeof(s_type_i));
|
|
814
802
|
|
|
815
803
|
// Write element size
|
|
816
|
-
const uint32_t
|
|
817
|
-
io.write(&
|
|
804
|
+
const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
|
|
805
|
+
io.write(&s_size_el, sizeof(s_size_el));
|
|
818
806
|
|
|
819
807
|
// Write GQA embedding size
|
|
820
|
-
io.write(&
|
|
808
|
+
io.write(&n_embd_s, sizeof(n_embd_s));
|
|
821
809
|
|
|
822
810
|
// For each row, we get the element values of each cell
|
|
823
|
-
for (uint32_t j = 0; j <
|
|
811
|
+
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
|
824
812
|
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
|
825
813
|
for (const auto & range : cell_ranges) {
|
|
826
814
|
const size_t range_size = range.second - range.first;
|
|
827
|
-
const size_t src_offset = (range.first + j *
|
|
828
|
-
const size_t buf_size = range_size *
|
|
829
|
-
io.write_tensor(
|
|
815
|
+
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
|
|
816
|
+
const size_t buf_size = range_size * s_size_el;
|
|
817
|
+
io.write_tensor(s_l[il], src_offset, buf_size);
|
|
830
818
|
}
|
|
831
819
|
}
|
|
832
820
|
}
|
|
833
821
|
}
|
|
834
822
|
}
|
|
835
823
|
|
|
836
|
-
bool
|
|
824
|
+
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
|
837
825
|
if (dest_seq_id != -1) {
|
|
838
826
|
// single sequence
|
|
839
827
|
|
|
@@ -883,10 +871,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
|
883
871
|
return false;
|
|
884
872
|
}
|
|
885
873
|
|
|
886
|
-
clear();
|
|
874
|
+
clear(true);
|
|
887
875
|
|
|
888
876
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
889
|
-
|
|
877
|
+
auto & cell = cells[i];
|
|
890
878
|
|
|
891
879
|
llama_pos pos;
|
|
892
880
|
uint32_t n_seq_id;
|
|
@@ -900,7 +888,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
|
900
888
|
llama_seq_id seq_id;
|
|
901
889
|
io.read_to(&seq_id, sizeof(seq_id));
|
|
902
890
|
|
|
903
|
-
// TODO:
|
|
891
|
+
// TODO: llama_memory_recurrent should have a notion of max sequences
|
|
904
892
|
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
|
905
893
|
if (seq_id < 0) {
|
|
906
894
|
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
|
@@ -932,10 +920,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
|
932
920
|
return true;
|
|
933
921
|
}
|
|
934
922
|
|
|
935
|
-
bool
|
|
936
|
-
uint32_t
|
|
923
|
+
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
|
924
|
+
uint32_t s_trans;
|
|
937
925
|
uint32_t n_layer;
|
|
938
|
-
io.read_to(&
|
|
926
|
+
io.read_to(&s_trans, sizeof(s_trans));
|
|
939
927
|
io.read_to(&n_layer, sizeof(n_layer));
|
|
940
928
|
|
|
941
929
|
if (n_layer != hparams.n_layer) {
|
|
@@ -946,102 +934,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
|
|
946
934
|
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
|
947
935
|
return false;
|
|
948
936
|
}
|
|
949
|
-
if (false != (bool)
|
|
950
|
-
LLAMA_LOG_ERROR("%s: incompatible
|
|
937
|
+
if (false != (bool) s_trans) {
|
|
938
|
+
LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
|
|
951
939
|
return false;
|
|
952
940
|
}
|
|
953
941
|
|
|
954
942
|
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
|
955
943
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
956
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
|
957
944
|
|
|
958
945
|
// Read type of key
|
|
959
|
-
int32_t
|
|
960
|
-
io.read_to(&
|
|
961
|
-
const int32_t
|
|
962
|
-
if (
|
|
963
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
946
|
+
int32_t r_type_i_ref;
|
|
947
|
+
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
|
948
|
+
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
|
949
|
+
if (r_type_i != r_type_i_ref) {
|
|
950
|
+
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
|
964
951
|
return false;
|
|
965
952
|
}
|
|
966
953
|
|
|
967
954
|
// Read row size of key
|
|
968
|
-
uint64_t
|
|
969
|
-
io.read_to(&
|
|
970
|
-
const size_t
|
|
971
|
-
if (
|
|
972
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
955
|
+
uint64_t r_size_row_ref;
|
|
956
|
+
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
|
957
|
+
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
|
958
|
+
if (r_size_row != r_size_row_ref) {
|
|
959
|
+
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
|
973
960
|
return false;
|
|
974
961
|
}
|
|
975
962
|
|
|
976
963
|
if (cell_count) {
|
|
977
964
|
// Read and set the keys for the whole cell range
|
|
978
|
-
ggml_backend_tensor_set(
|
|
965
|
+
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
|
|
979
966
|
}
|
|
980
967
|
}
|
|
981
968
|
|
|
982
|
-
if (!
|
|
969
|
+
if (!s_trans) {
|
|
983
970
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
984
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
985
971
|
|
|
986
972
|
// Read type of value
|
|
987
|
-
int32_t
|
|
988
|
-
io.read_to(&
|
|
989
|
-
const int32_t
|
|
990
|
-
if (
|
|
991
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
973
|
+
int32_t s_type_i_ref;
|
|
974
|
+
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
|
975
|
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
|
976
|
+
if (s_type_i != s_type_i_ref) {
|
|
977
|
+
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
|
992
978
|
return false;
|
|
993
979
|
}
|
|
994
980
|
|
|
995
981
|
// Read row size of value
|
|
996
|
-
uint64_t
|
|
997
|
-
io.read_to(&
|
|
998
|
-
const size_t
|
|
999
|
-
if (
|
|
1000
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
982
|
+
uint64_t s_size_row_ref;
|
|
983
|
+
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
|
984
|
+
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
|
985
|
+
if (s_size_row != s_size_row_ref) {
|
|
986
|
+
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
|
1001
987
|
return false;
|
|
1002
988
|
}
|
|
1003
989
|
|
|
1004
990
|
if (cell_count) {
|
|
1005
991
|
// Read and set the values for the whole cell range
|
|
1006
|
-
ggml_backend_tensor_set(
|
|
992
|
+
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
|
|
1007
993
|
}
|
|
1008
994
|
}
|
|
1009
995
|
} else {
|
|
1010
996
|
// For each layer, read the values for each cell (transposed)
|
|
1011
997
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
1012
|
-
const uint32_t
|
|
998
|
+
const uint32_t n_embd_s = hparams.n_embd_s();
|
|
1013
999
|
|
|
1014
1000
|
// Read type of value
|
|
1015
|
-
int32_t
|
|
1016
|
-
io.read_to(&
|
|
1017
|
-
const int32_t
|
|
1018
|
-
if (
|
|
1019
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
1001
|
+
int32_t s_type_i_ref;
|
|
1002
|
+
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
|
1003
|
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
|
1004
|
+
if (s_type_i != s_type_i_ref) {
|
|
1005
|
+
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
|
1020
1006
|
return false;
|
|
1021
1007
|
}
|
|
1022
1008
|
|
|
1023
1009
|
// Read element size of value
|
|
1024
|
-
uint32_t
|
|
1025
|
-
io.read_to(&
|
|
1026
|
-
const size_t
|
|
1027
|
-
if (
|
|
1028
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
1010
|
+
uint32_t s_size_el_ref;
|
|
1011
|
+
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
|
|
1012
|
+
const size_t s_size_el = ggml_type_size(s_l[il]->type);
|
|
1013
|
+
if (s_size_el != s_size_el_ref) {
|
|
1014
|
+
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
|
|
1029
1015
|
return false;
|
|
1030
1016
|
}
|
|
1031
1017
|
|
|
1032
|
-
// Read
|
|
1033
|
-
uint32_t
|
|
1034
|
-
io.read_to(&
|
|
1035
|
-
if (
|
|
1036
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
1018
|
+
// Read state embedding size
|
|
1019
|
+
uint32_t n_embd_s_ref;
|
|
1020
|
+
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
|
1021
|
+
if (n_embd_s != n_embd_s_ref) {
|
|
1022
|
+
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
|
|
1037
1023
|
return false;
|
|
1038
1024
|
}
|
|
1039
1025
|
|
|
1040
1026
|
if (cell_count) {
|
|
1041
1027
|
// For each row in the transposed matrix, read the values for the whole cell range
|
|
1042
|
-
for (uint32_t j = 0; j <
|
|
1043
|
-
const size_t dst_offset = (head + j * size) *
|
|
1044
|
-
ggml_backend_tensor_set(
|
|
1028
|
+
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
|
1029
|
+
const size_t dst_offset = (head + j * size) * s_size_el;
|
|
1030
|
+
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
|
|
1045
1031
|
}
|
|
1046
1032
|
}
|
|
1047
1033
|
}
|
|
@@ -1051,25 +1037,23 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
|
|
1051
1037
|
}
|
|
1052
1038
|
|
|
1053
1039
|
//
|
|
1054
|
-
//
|
|
1040
|
+
// llama_memory_recurrent_state
|
|
1055
1041
|
//
|
|
1056
1042
|
|
|
1057
|
-
|
|
1043
|
+
llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
|
|
1058
1044
|
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
|
|
1045
|
+
llama_memory_recurrent_state::llama_memory_recurrent_state(
|
|
1046
|
+
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
|
|
1062
1047
|
}
|
|
1063
1048
|
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
llama_kv_cache_recurrent * kv,
|
|
1049
|
+
llama_memory_recurrent_state::llama_memory_recurrent_state(
|
|
1050
|
+
llama_memory_recurrent * mem,
|
|
1067
1051
|
llama_sbatch sbatch,
|
|
1068
|
-
std::vector<llama_ubatch> ubatches) : status(
|
|
1052
|
+
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
|
|
1069
1053
|
|
|
1070
|
-
|
|
1054
|
+
llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
|
|
1071
1055
|
|
|
1072
|
-
bool
|
|
1056
|
+
bool llama_memory_recurrent_state::next() {
|
|
1073
1057
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1074
1058
|
|
|
1075
1059
|
if (++i_next >= ubatches.size()) {
|
|
@@ -1079,54 +1063,54 @@ bool llama_kv_cache_recurrent_state::next() {
|
|
|
1079
1063
|
return true;
|
|
1080
1064
|
}
|
|
1081
1065
|
|
|
1082
|
-
bool
|
|
1066
|
+
bool llama_memory_recurrent_state::apply() {
|
|
1083
1067
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1084
1068
|
|
|
1085
|
-
|
|
1069
|
+
mem->find_slot(ubatches[i_next]);
|
|
1086
1070
|
|
|
1087
1071
|
return true;
|
|
1088
1072
|
}
|
|
1089
1073
|
|
|
1090
|
-
std::vector<int64_t> &
|
|
1074
|
+
std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
|
|
1091
1075
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1092
1076
|
|
|
1093
1077
|
return sbatch.out_ids;
|
|
1094
1078
|
}
|
|
1095
1079
|
|
|
1096
|
-
llama_memory_status
|
|
1080
|
+
llama_memory_status llama_memory_recurrent_state::get_status() const {
|
|
1097
1081
|
return status;
|
|
1098
1082
|
}
|
|
1099
1083
|
|
|
1100
|
-
const llama_ubatch &
|
|
1084
|
+
const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
|
|
1101
1085
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1102
1086
|
|
|
1103
1087
|
return ubatches[i_next];
|
|
1104
1088
|
}
|
|
1105
1089
|
|
|
1106
|
-
uint32_t
|
|
1107
|
-
return is_full ?
|
|
1090
|
+
uint32_t llama_memory_recurrent_state::get_n_rs() const {
|
|
1091
|
+
return is_full ? mem->size : mem->n;
|
|
1108
1092
|
}
|
|
1109
1093
|
|
|
1110
|
-
uint32_t
|
|
1111
|
-
return is_full ? 0 :
|
|
1094
|
+
uint32_t llama_memory_recurrent_state::get_head() const {
|
|
1095
|
+
return is_full ? 0 : mem->head;
|
|
1112
1096
|
}
|
|
1113
1097
|
|
|
1114
|
-
|
|
1115
|
-
return
|
|
1098
|
+
int32_t llama_memory_recurrent_state::get_rs_z() const {
|
|
1099
|
+
return is_full ? 0 : mem->rs_z;
|
|
1116
1100
|
}
|
|
1117
1101
|
|
|
1118
|
-
|
|
1119
|
-
return
|
|
1102
|
+
uint32_t llama_memory_recurrent_state::get_size() const {
|
|
1103
|
+
return mem->size;
|
|
1120
1104
|
}
|
|
1121
1105
|
|
|
1122
|
-
ggml_tensor *
|
|
1123
|
-
return
|
|
1106
|
+
ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
|
|
1107
|
+
return mem->r_l[il];
|
|
1124
1108
|
}
|
|
1125
1109
|
|
|
1126
|
-
|
|
1127
|
-
return
|
|
1110
|
+
ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
|
|
1111
|
+
return mem->s_l[il];
|
|
1128
1112
|
}
|
|
1129
1113
|
|
|
1130
|
-
|
|
1131
|
-
return
|
|
1114
|
+
int32_t llama_memory_recurrent_state::s_copy(int i) const {
|
|
1115
|
+
return mem->cells[i + mem->head].src0;
|
|
1132
1116
|
}
|