@novastera-oss/llamarn 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +134 -36
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -2
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +30 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +50 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +134 -36
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
- package/cpp/llama.cpp/src/llama-batch.h +36 -11
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +313 -213
- package/cpp/llama.cpp/src/llama-context.h +16 -12
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
- package/cpp/llama.cpp/src/llama-graph.h +90 -34
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
- package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +64 -23
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +726 -141
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/llama.h +134 -36
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
#include "llama-batch.h"
|
|
4
4
|
#include "llama-graph.h"
|
|
5
|
-
#include "llama-kv-cache.h"
|
|
6
5
|
#include "llama-kv-cells.h"
|
|
6
|
+
#include "llama-memory.h"
|
|
7
7
|
|
|
8
8
|
#include <unordered_map>
|
|
9
9
|
#include <vector>
|
|
@@ -17,13 +17,26 @@ struct llama_context;
|
|
|
17
17
|
// llama_kv_cache_unified
|
|
18
18
|
//
|
|
19
19
|
|
|
20
|
-
class llama_kv_cache_unified : public
|
|
20
|
+
class llama_kv_cache_unified : public llama_memory_i {
|
|
21
21
|
public:
|
|
22
22
|
static uint32_t get_padding(const llama_cparams & cparams);
|
|
23
23
|
|
|
24
24
|
// this callback is used to filter out layers that should not be included in the cache
|
|
25
25
|
using layer_filter_cb = std::function<bool(int32_t il)>;
|
|
26
26
|
|
|
27
|
+
using ubatch_heads = std::vector<uint32_t>;
|
|
28
|
+
|
|
29
|
+
struct defrag_info {
|
|
30
|
+
bool empty() const {
|
|
31
|
+
return ids.empty();
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
// contains information about which cell moves where:
|
|
35
|
+
// - cell i moves to ids[i]
|
|
36
|
+
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
|
37
|
+
std::vector<uint32_t> ids;
|
|
38
|
+
};
|
|
39
|
+
|
|
27
40
|
llama_kv_cache_unified(
|
|
28
41
|
const llama_model & model,
|
|
29
42
|
layer_filter_cb && filter,
|
|
@@ -43,7 +56,18 @@ public:
|
|
|
43
56
|
// llama_memory_i
|
|
44
57
|
//
|
|
45
58
|
|
|
46
|
-
|
|
59
|
+
llama_memory_state_ptr init_batch(
|
|
60
|
+
const llama_batch & batch,
|
|
61
|
+
uint32_t n_ubatch,
|
|
62
|
+
bool embd_all) override;
|
|
63
|
+
|
|
64
|
+
llama_memory_state_ptr init_full() override;
|
|
65
|
+
|
|
66
|
+
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
|
67
|
+
|
|
68
|
+
bool get_can_shift() const override;
|
|
69
|
+
|
|
70
|
+
void clear(bool data) override;
|
|
47
71
|
|
|
48
72
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
|
49
73
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
|
@@ -54,24 +78,6 @@ public:
|
|
|
54
78
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
|
55
79
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
|
56
80
|
|
|
57
|
-
//
|
|
58
|
-
// llama_kv_cache
|
|
59
|
-
//
|
|
60
|
-
|
|
61
|
-
llama_memory_state_ptr init_batch(
|
|
62
|
-
const llama_batch & batch,
|
|
63
|
-
uint32_t n_ubatch,
|
|
64
|
-
bool embd_pooled,
|
|
65
|
-
bool logits_all) override;
|
|
66
|
-
|
|
67
|
-
llama_memory_state_ptr init_full() override;
|
|
68
|
-
|
|
69
|
-
bool update(llama_context & lctx) override;
|
|
70
|
-
|
|
71
|
-
void defrag_sched(float thold) override;
|
|
72
|
-
|
|
73
|
-
bool get_can_shift() const override;
|
|
74
|
-
|
|
75
81
|
// state write/load
|
|
76
82
|
|
|
77
83
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
@@ -83,6 +89,8 @@ public:
|
|
|
83
89
|
|
|
84
90
|
uint32_t get_size() const;
|
|
85
91
|
|
|
92
|
+
bool get_has_shift() const;
|
|
93
|
+
|
|
86
94
|
//
|
|
87
95
|
// graph_build API
|
|
88
96
|
//
|
|
@@ -103,7 +111,9 @@ public:
|
|
|
103
111
|
|
|
104
112
|
// find places for the provided ubatches in the cache, returns the head locations
|
|
105
113
|
// return empty vector on failure
|
|
106
|
-
|
|
114
|
+
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
|
115
|
+
|
|
116
|
+
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
|
107
117
|
|
|
108
118
|
// return the cell position where we can insert the ubatch
|
|
109
119
|
// return -1 on failure to find a contiguous slot of kv cells
|
|
@@ -133,8 +143,7 @@ private:
|
|
|
133
143
|
ggml_tensor * v;
|
|
134
144
|
};
|
|
135
145
|
|
|
136
|
-
bool
|
|
137
|
-
bool v_trans = true; // the value tensor is transposed
|
|
146
|
+
bool v_trans = true; // the value tensor is transposed
|
|
138
147
|
|
|
139
148
|
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
|
140
149
|
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
|
@@ -148,6 +157,8 @@ private:
|
|
|
148
157
|
// SWA
|
|
149
158
|
const uint32_t n_swa = 0;
|
|
150
159
|
|
|
160
|
+
int debug = 0;
|
|
161
|
+
|
|
151
162
|
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
|
152
163
|
|
|
153
164
|
std::vector<ggml_context_ptr> ctxs;
|
|
@@ -160,13 +171,8 @@ private:
|
|
|
160
171
|
// model layer id -> KV cache layer id
|
|
161
172
|
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
|
162
173
|
|
|
163
|
-
//
|
|
164
|
-
|
|
165
|
-
std::vector<uint32_t> ids;
|
|
166
|
-
} defrag_info;
|
|
167
|
-
|
|
168
|
-
// return true if cells have been moved
|
|
169
|
-
bool defrag_prepare(int32_t n_max_nodes);
|
|
174
|
+
// return non-empty vector if cells have been moved
|
|
175
|
+
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
|
170
176
|
|
|
171
177
|
size_t total_size() const;
|
|
172
178
|
|
|
@@ -192,7 +198,8 @@ private:
|
|
|
192
198
|
llm_graph_result_ptr build_graph_defrag(
|
|
193
199
|
const llama_cparams & cparams,
|
|
194
200
|
ggml_context * ctx,
|
|
195
|
-
ggml_cgraph * gf
|
|
201
|
+
ggml_cgraph * gf,
|
|
202
|
+
const defrag_info & dinfo) const;
|
|
196
203
|
|
|
197
204
|
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
|
198
205
|
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
|
@@ -203,20 +210,29 @@ private:
|
|
|
203
210
|
|
|
204
211
|
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
|
205
212
|
public:
|
|
213
|
+
// some shorthands
|
|
214
|
+
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
|
215
|
+
using defrag_info = llama_kv_cache_unified::defrag_info;
|
|
216
|
+
|
|
206
217
|
// used for errors
|
|
207
218
|
llama_kv_cache_unified_state(llama_memory_status status);
|
|
208
219
|
|
|
209
220
|
// used to create a full-cache state
|
|
210
221
|
llama_kv_cache_unified_state(
|
|
211
|
-
llama_memory_status status,
|
|
212
222
|
llama_kv_cache_unified * kv);
|
|
213
223
|
|
|
214
|
-
// used to create
|
|
224
|
+
// used to create an update state
|
|
225
|
+
llama_kv_cache_unified_state(
|
|
226
|
+
llama_kv_cache_unified * kv,
|
|
227
|
+
llama_context * lctx,
|
|
228
|
+
bool do_shift,
|
|
229
|
+
defrag_info dinfo);
|
|
230
|
+
|
|
231
|
+
// used to create a decode state from a batch
|
|
215
232
|
llama_kv_cache_unified_state(
|
|
216
|
-
llama_memory_status status,
|
|
217
233
|
llama_kv_cache_unified * kv,
|
|
218
234
|
llama_sbatch sbatch,
|
|
219
|
-
|
|
235
|
+
ubatch_heads heads,
|
|
220
236
|
std::vector<llama_ubatch> ubatches);
|
|
221
237
|
|
|
222
238
|
virtual ~llama_kv_cache_unified_state();
|
|
@@ -253,16 +269,30 @@ public:
|
|
|
253
269
|
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
|
254
270
|
|
|
255
271
|
private:
|
|
256
|
-
|
|
272
|
+
llama_memory_status status;
|
|
257
273
|
|
|
258
274
|
llama_kv_cache_unified * kv;
|
|
275
|
+
llama_context * lctx;
|
|
276
|
+
|
|
277
|
+
//
|
|
278
|
+
// update state
|
|
279
|
+
//
|
|
280
|
+
|
|
281
|
+
bool do_shift = false;
|
|
282
|
+
|
|
283
|
+
defrag_info dinfo;
|
|
284
|
+
|
|
285
|
+
//
|
|
286
|
+
// batch processing state
|
|
287
|
+
//
|
|
259
288
|
|
|
260
289
|
llama_sbatch sbatch;
|
|
261
290
|
|
|
262
291
|
// the index of the next ubatch to process
|
|
263
292
|
size_t i_next = 0;
|
|
264
293
|
|
|
265
|
-
|
|
294
|
+
ubatch_heads heads;
|
|
295
|
+
|
|
266
296
|
std::vector<llama_ubatch> ubatches;
|
|
267
297
|
|
|
268
298
|
//
|
|
@@ -23,7 +23,7 @@ public:
|
|
|
23
23
|
|
|
24
24
|
used.clear();
|
|
25
25
|
|
|
26
|
-
for (uint32_t s = 0; s <
|
|
26
|
+
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
27
27
|
seq_pos[s].clear();
|
|
28
28
|
}
|
|
29
29
|
}
|
|
@@ -80,6 +80,9 @@ public:
|
|
|
80
80
|
assert(isrc < pos.size());
|
|
81
81
|
assert(idst < pos.size());
|
|
82
82
|
|
|
83
|
+
assert(pos[idst] == -1);
|
|
84
|
+
assert(pos[isrc] != -1);
|
|
85
|
+
|
|
83
86
|
pos [idst] = pos [isrc];
|
|
84
87
|
shift[idst] = shift[isrc];
|
|
85
88
|
seq [idst] = seq [isrc];
|
|
@@ -144,9 +147,10 @@ public:
|
|
|
144
147
|
assert(pos[i] != -1);
|
|
145
148
|
|
|
146
149
|
seq_pos_rm(i);
|
|
150
|
+
seq[i].reset();
|
|
147
151
|
|
|
148
152
|
pos[i] = -1;
|
|
149
|
-
|
|
153
|
+
shift[i] = 0;
|
|
150
154
|
|
|
151
155
|
used.erase(i);
|
|
152
156
|
}
|
|
@@ -164,6 +168,7 @@ public:
|
|
|
164
168
|
|
|
165
169
|
if (seq[i].none()) {
|
|
166
170
|
pos[i] = -1;
|
|
171
|
+
shift[i] = 0;
|
|
167
172
|
|
|
168
173
|
used.erase(i);
|
|
169
174
|
|
|
@@ -192,6 +197,7 @@ public:
|
|
|
192
197
|
seq[i].reset();
|
|
193
198
|
|
|
194
199
|
pos[i] = -1;
|
|
200
|
+
shift[i] = 0;
|
|
195
201
|
|
|
196
202
|
used.erase(i);
|
|
197
203
|
|
|
@@ -234,7 +240,7 @@ public:
|
|
|
234
240
|
llama_seq_id seq_get(uint32_t i) const {
|
|
235
241
|
assert(seq[i].count() == 1);
|
|
236
242
|
|
|
237
|
-
for (int s = 0; s <
|
|
243
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
238
244
|
if (seq[i].test(s)) {
|
|
239
245
|
return s;
|
|
240
246
|
}
|
|
@@ -247,7 +253,7 @@ public:
|
|
|
247
253
|
// return -1 if the sequence is not present
|
|
248
254
|
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
|
249
255
|
assert(seq_id >= 0);
|
|
250
|
-
assert(seq_id <
|
|
256
|
+
assert(seq_id < LLAMA_MAX_SEQ);
|
|
251
257
|
|
|
252
258
|
if (seq_pos[seq_id].empty()) {
|
|
253
259
|
return -1;
|
|
@@ -260,7 +266,7 @@ public:
|
|
|
260
266
|
// return -1 if the sequence is not present
|
|
261
267
|
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
|
262
268
|
assert(seq_id >= 0);
|
|
263
|
-
assert(seq_id <
|
|
269
|
+
assert(seq_id < LLAMA_MAX_SEQ);
|
|
264
270
|
|
|
265
271
|
if (seq_pos[seq_id].empty()) {
|
|
266
272
|
return -1;
|
|
@@ -317,21 +323,20 @@ public:
|
|
|
317
323
|
pos[i] += d;
|
|
318
324
|
shift[i] += d;
|
|
319
325
|
|
|
320
|
-
seq_pos_add(i);
|
|
321
|
-
|
|
322
326
|
has_shift = true;
|
|
323
327
|
|
|
324
328
|
if (pos[i] < 0) {
|
|
325
|
-
seq_pos_rm(i);
|
|
326
|
-
|
|
327
329
|
seq[i].reset();
|
|
328
330
|
pos[i] = -1;
|
|
331
|
+
shift[i] = 0;
|
|
329
332
|
|
|
330
333
|
used.erase(i);
|
|
331
334
|
|
|
332
335
|
return true;
|
|
333
336
|
}
|
|
334
337
|
|
|
338
|
+
seq_pos_add(i);
|
|
339
|
+
|
|
335
340
|
return false;
|
|
336
341
|
}
|
|
337
342
|
|
|
@@ -379,20 +384,20 @@ private:
|
|
|
379
384
|
//
|
|
380
385
|
std::vector<llama_pos> shift;
|
|
381
386
|
|
|
382
|
-
using bits_t = std::bitset<
|
|
387
|
+
using bits_t = std::bitset<LLAMA_MAX_SEQ>;
|
|
383
388
|
|
|
384
389
|
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
|
385
390
|
std::vector<bits_t> seq;
|
|
386
391
|
|
|
387
392
|
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
|
388
393
|
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
|
389
|
-
std::set<llama_pos> seq_pos[
|
|
394
|
+
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
|
|
390
395
|
|
|
391
396
|
// helper functions for updating `seq_pos`, once cell at a time:
|
|
392
397
|
|
|
393
398
|
// remove cell i
|
|
394
399
|
void seq_pos_rm(uint32_t i) {
|
|
395
|
-
for (int s = 0; s <
|
|
400
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
396
401
|
if (seq[i].test(s)) {
|
|
397
402
|
seq_pos[s].erase(pos[i]);
|
|
398
403
|
}
|
|
@@ -401,7 +406,7 @@ private:
|
|
|
401
406
|
|
|
402
407
|
// add cell i
|
|
403
408
|
void seq_pos_add(uint32_t i) {
|
|
404
|
-
for (int s = 0; s <
|
|
409
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
405
410
|
if (seq[i].test(s)) {
|
|
406
411
|
seq_pos[s].insert(pos[i]);
|
|
407
412
|
}
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
#include "llama-memory-hybrid.h"
|
|
2
|
+
|
|
3
|
+
#include "llama-impl.h"
|
|
4
|
+
#include "llama-model.h"
|
|
5
|
+
#include "llama-context.h"
|
|
6
|
+
|
|
7
|
+
//
|
|
8
|
+
// llama_memory_hybrid
|
|
9
|
+
//
|
|
10
|
+
|
|
11
|
+
llama_memory_hybrid::llama_memory_hybrid(
|
|
12
|
+
const llama_model & model,
|
|
13
|
+
/* attn */
|
|
14
|
+
ggml_type type_k,
|
|
15
|
+
ggml_type type_v,
|
|
16
|
+
bool v_trans,
|
|
17
|
+
uint32_t kv_size,
|
|
18
|
+
uint32_t n_pad,
|
|
19
|
+
uint32_t n_swa,
|
|
20
|
+
llama_swa_type swa_type,
|
|
21
|
+
/* recurrent */
|
|
22
|
+
ggml_type type_r,
|
|
23
|
+
ggml_type type_s,
|
|
24
|
+
uint32_t rs_size,
|
|
25
|
+
/* common */
|
|
26
|
+
uint32_t n_seq_max,
|
|
27
|
+
bool offload,
|
|
28
|
+
/* layer filters */
|
|
29
|
+
layer_filter_cb && filter_attn,
|
|
30
|
+
layer_filter_cb && filter_recr) :
|
|
31
|
+
hparams(model.hparams),
|
|
32
|
+
mem_attn(new llama_kv_cache_unified(
|
|
33
|
+
model,
|
|
34
|
+
filter_attn == nullptr ?
|
|
35
|
+
[&](int32_t il) { return !model.hparams.is_recurrent(il); }
|
|
36
|
+
: filter_attn,
|
|
37
|
+
type_k,
|
|
38
|
+
type_v,
|
|
39
|
+
v_trans,
|
|
40
|
+
offload,
|
|
41
|
+
kv_size,
|
|
42
|
+
n_seq_max,
|
|
43
|
+
n_pad,
|
|
44
|
+
n_swa,
|
|
45
|
+
swa_type
|
|
46
|
+
)),
|
|
47
|
+
mem_recr(new llama_memory_recurrent(
|
|
48
|
+
model,
|
|
49
|
+
filter_recr == nullptr ?
|
|
50
|
+
[&](int32_t il) { return model.hparams.is_recurrent(il); }
|
|
51
|
+
: filter_recr,
|
|
52
|
+
type_r,
|
|
53
|
+
type_s,
|
|
54
|
+
offload,
|
|
55
|
+
rs_size,
|
|
56
|
+
n_seq_max
|
|
57
|
+
)) {}
|
|
58
|
+
|
|
59
|
+
llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
|
|
60
|
+
|
|
61
|
+
// since this includes a recurrent cache, we cannot use split_simple
|
|
62
|
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
|
63
|
+
|
|
64
|
+
// follow the recurrent pattern for creating the ubatch splits
|
|
65
|
+
std::vector<llama_ubatch> ubatches;
|
|
66
|
+
while (sbatch.n_tokens > 0) {
|
|
67
|
+
llama_ubatch ubatch;
|
|
68
|
+
|
|
69
|
+
if (embd_pooled) {
|
|
70
|
+
// Pooled embeddings cannot be split across ubatches (yet)
|
|
71
|
+
ubatch = sbatch.split_seq(n_ubatch);
|
|
72
|
+
} else {
|
|
73
|
+
ubatch = sbatch.split_equal(n_ubatch);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
ubatches.push_back(ubatch);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
// prepare the recurrent batches first
|
|
80
|
+
if (!mem_recr->prepare(ubatches)) {
|
|
81
|
+
// TODO: will the recurrent cache be in an undefined state at this point?
|
|
82
|
+
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
|
83
|
+
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
// prepare the attention cache
|
|
87
|
+
auto heads_attn = mem_attn->prepare(ubatches);
|
|
88
|
+
if (heads_attn.empty()) {
|
|
89
|
+
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
|
|
90
|
+
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
return std::make_unique<llama_memory_hybrid_state>(
|
|
94
|
+
this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
llama_memory_state_ptr llama_memory_hybrid::init_full() {
|
|
98
|
+
return std::make_unique<llama_memory_hybrid_state>(this);
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
|
|
102
|
+
return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
bool llama_memory_hybrid::get_can_shift() const {
|
|
106
|
+
// Shifting is trivially supported for recurrent
|
|
107
|
+
return mem_attn->get_can_shift();
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
void llama_memory_hybrid::clear(bool data) {
|
|
111
|
+
mem_attn->clear(data);
|
|
112
|
+
mem_recr->clear(data);
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
116
|
+
// Try removing from the recurrent cache first since it may fail. If it does
|
|
117
|
+
// fail, the cache will not have been mutated.
|
|
118
|
+
if (!mem_recr->seq_rm(seq_id, p0, p1)) {
|
|
119
|
+
return false;
|
|
120
|
+
}
|
|
121
|
+
return mem_attn->seq_rm(seq_id, p0, p1);
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
|
125
|
+
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
|
126
|
+
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
|
|
130
|
+
mem_attn->seq_keep(seq_id);
|
|
131
|
+
mem_recr->seq_keep(seq_id);
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
|
135
|
+
mem_attn->seq_add(seq_id, p0, p1, shift);
|
|
136
|
+
mem_recr->seq_add(seq_id, p0, p1, shift);
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
140
|
+
mem_attn->seq_div(seq_id, p0, p1, d);
|
|
141
|
+
mem_recr->seq_div(seq_id, p0, p1, d);
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
|
|
145
|
+
// the min of the total cache is the max of the two caches' min values
|
|
146
|
+
return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
|
|
150
|
+
// the max of the total cache is the min of the two caches' max values
|
|
151
|
+
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
|
155
|
+
mem_attn->state_write(io, seq_id);
|
|
156
|
+
mem_recr->state_write(io, seq_id);
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
|
160
|
+
mem_attn->state_read(io, seq_id);
|
|
161
|
+
mem_recr->state_read(io, seq_id);
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
|
|
165
|
+
return mem_attn.get();
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
|
|
169
|
+
return mem_recr.get();
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
|
|
173
|
+
|
|
174
|
+
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
|
|
175
|
+
state_attn(mem->get_mem_attn()->init_full()),
|
|
176
|
+
state_recr(mem->get_mem_recr()->init_full()),
|
|
177
|
+
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
llama_memory_hybrid_state::llama_memory_hybrid_state(
|
|
181
|
+
llama_memory_hybrid * mem,
|
|
182
|
+
llama_context * lctx,
|
|
183
|
+
bool optimize) :
|
|
184
|
+
state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
|
|
185
|
+
state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
|
|
186
|
+
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
llama_memory_hybrid_state::llama_memory_hybrid_state(
|
|
190
|
+
llama_memory_hybrid * mem,
|
|
191
|
+
llama_sbatch sbatch,
|
|
192
|
+
std::vector<uint32_t> heads_attn,
|
|
193
|
+
std::vector<llama_ubatch> ubatches) :
|
|
194
|
+
sbatch(std::move(sbatch)),
|
|
195
|
+
ubatches(std::move(ubatches)),
|
|
196
|
+
// note: here we copy the ubatches. not sure if this is ideal
|
|
197
|
+
state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
|
|
198
|
+
state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)),
|
|
199
|
+
status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
bool llama_memory_hybrid_state::next() {
|
|
203
|
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
204
|
+
|
|
205
|
+
state_attn->next();
|
|
206
|
+
state_recr->next();
|
|
207
|
+
|
|
208
|
+
if (++i_next >= ubatches.size()) {
|
|
209
|
+
return false;
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
return true;
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
bool llama_memory_hybrid_state::apply() {
|
|
216
|
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
217
|
+
|
|
218
|
+
bool res = true;
|
|
219
|
+
|
|
220
|
+
res = res & state_attn->apply();
|
|
221
|
+
res = res & state_recr->apply();
|
|
222
|
+
|
|
223
|
+
return res;
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
|
|
227
|
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
228
|
+
|
|
229
|
+
return sbatch.out_ids;
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
llama_memory_status llama_memory_hybrid_state::get_status() const {
|
|
233
|
+
return status;
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
|
|
237
|
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
238
|
+
return ubatches[i_next];
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
|
|
242
|
+
return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
|
|
246
|
+
return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
|
|
247
|
+
}
|