@novastera-oss/llamarn 0.2.4 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/RNLlamaCpp.podspec +3 -2
- package/android/CMakeLists.txt +6 -3
- package/android/src/main/cpp/include/llama.h +12 -8
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +46 -65
- package/cpp/LlamaCppModel.h +5 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/README.md +1 -0
- package/cpp/llama.cpp/common/CMakeLists.txt +5 -8
- package/cpp/llama.cpp/common/arg.cpp +8 -6
- package/cpp/llama.cpp/common/chat-parser.cpp +4 -3
- package/cpp/llama.cpp/common/chat-parser.h +2 -1
- package/cpp/llama.cpp/common/chat.cpp +4 -4
- package/cpp/llama.cpp/common/common.cpp +2 -0
- package/cpp/llama.cpp/common/json-partial.cpp +5 -4
- package/cpp/llama.cpp/common/json-partial.h +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +31 -28
- package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +23 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +19 -8
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml.c +9 -2
- package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
- package/cpp/llama.cpp/include/llama.h +12 -8
- package/cpp/llama.cpp/src/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +19 -12
- package/cpp/llama.cpp/src/llama-batch.h +15 -10
- package/cpp/llama.cpp/src/llama-context.cpp +226 -151
- package/cpp/llama.cpp/src/llama-context.h +25 -8
- package/cpp/llama.cpp/src/llama-graph.cpp +50 -47
- package/cpp/llama.cpp/src/llama-graph.h +25 -24
- package/cpp/llama.cpp/src/llama-kv-cache-recurrent.cpp +1132 -0
- package/cpp/llama.cpp/src/llama-kv-cache-recurrent.h +191 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +249 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +136 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1717 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +278 -0
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2746
- package/cpp/llama.cpp/src/llama-kv-cache.h +14 -472
- package/cpp/llama.cpp/src/llama-kv-cells.h +37 -6
- package/cpp/llama.cpp/src/llama-memory.h +44 -0
- package/cpp/llama.cpp/src/llama-model.cpp +23 -16
- package/cpp/llama.cpp/src/llama-vocab.cpp +7 -2
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
- package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
- package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
- package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
- package/cpp/rn-completion.cpp +101 -52
- package/cpp/rn-utils.hpp +8 -1
- package/ios/include/common/minja/chat-template.hpp +1 -1
- package/ios/include/common/minja/minja.hpp +1 -1
- package/ios/include/json-schema-to-grammar.h +4 -4
- package/ios/include/llama.h +12 -8
- package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
- package/ios/libs/llama.xcframework/Info.plist +22 -22
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4617
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3557
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3624 -3559
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4616
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4637
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3556
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4725 -4653
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4746 -4674
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3652 -3587
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
|
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"

#include <cstdint>  // uint32_t, int32_t (used directly below; do not rely on transitive includes)
#include <set>
#include <utility>  // std::pair (cell_ranges)
#include <vector>

//
// llama_kv_cache_recurrent
//
// KV cache for recurrent architectures: one cell per sequence state rather than
// one cell per token position.
//
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
//       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
    llama_kv_cache_recurrent(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   offload,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max);

    ~llama_kv_cache_recurrent() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    // check that all the ubatches can be placed in the cache (does not modify it)
    bool prepare(const std::vector<llama_ubatch> & ubatches);

    // find a contiguous slot of kv cells and emplace the ubatch there
    bool find_slot(const llama_ubatch & ubatch);

    bool get_can_shift() const override;

    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
    int32_t s_copy(int i) const;
    float   s_mask(int i) const;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    // TODO: optimize for recurrent state needs
    struct kv_cell {
        llama_pos pos  = -1;
        int32_t   src  = -1; // used to copy states
        int32_t   tail = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const kv_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    std::vector<kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;

private:
    //const llama_model & model;
    const llama_hparams & hparams;

    const uint32_t n_seq_max = 1;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

// memory state over a (possibly multi-ubatch) batch for the recurrent cache
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
public:
    // used for errors
    llama_kv_cache_recurrent_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_recurrent_state(
            llama_memory_status status,
            llama_kv_cache_recurrent * kv);

    // used to create a state from a batch
    llama_kv_cache_recurrent_state(
            llama_memory_status status,
            llama_kv_cache_recurrent * kv,
            llama_sbatch sbatch,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_recurrent_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_recurrent_state specific API
    //

    uint32_t get_n_kv () const;
    uint32_t get_head() const;
    uint32_t get_size() const;

    ggml_tensor * get_k_l(int32_t il) const;
    ggml_tensor * get_v_l(int32_t il) const;

    int32_t s_copy(int i) const;
    float   s_mask(int i) const;

private:
    const llama_memory_status status;

    llama_kv_cache_recurrent * kv;

    llama_sbatch sbatch;

    size_t i_next = 0; // index of the next ubatch to process

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    // TODO: extract all the state like `head` and `n` here
    //

    const bool is_full = false;
};
#include "llama-kv-cache-unified-iswa.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"

#include <algorithm>
#include <cassert>

//
// llama_kv_cache_unified_iswa
//

llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
        const llama_model & model,
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                     bool   offload,
                     bool   swa_full,
                 uint32_t   kv_size,
                 uint32_t   n_seq_max,
                 uint32_t   n_ubatch,
                 uint32_t   n_pad) : hparams(model.hparams) {
    // route non-SWA layers into the base cache and SWA layers into the SWA cache
    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

    const uint32_t size_base = kv_size;

    // the SWA cache only needs to hold the sliding window per sequence (plus one ubatch), padded
    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));

    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
    if (swa_full) {
        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

        size_swa = size_base;
    }

    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

    kv_base = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_base), type_k, type_v,
            v_trans, offload, size_base, n_seq_max, n_pad,
            0, LLAMA_SWA_TYPE_NONE);

    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

    kv_swa = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_swa), type_k, type_v,
            v_trans, offload, size_swa, n_seq_max, n_pad,
            hparams.n_swa, hparams.swa_type);
}

// every mutation below is fanned out to both child caches

void llama_kv_cache_unified_iswa::clear() {
    kv_base->clear();
    kv_swa ->clear();
}

bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // deliberately no short-circuit: both caches must be updated
    bool res = true;

    res &= kv_base->seq_rm(seq_id, p0, p1);
    res &= kv_swa ->seq_rm(seq_id, p0, p1);

    return res;
}

void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
    kv_base->seq_keep(seq_id);
    kv_swa ->seq_keep(seq_id);
}

void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_base->seq_add(seq_id, p0, p1, shift);
    kv_swa ->seq_add(seq_id, p0, p1, shift);
}

void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_base->seq_div(seq_id, p0, p1, d);
    kv_swa ->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
    return kv_swa->seq_pos_min(seq_id);
}

llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
    return kv_swa->seq_pos_max(seq_id);
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
    GGML_UNUSED(embd_pooled);

    // TODO: if we fail with split_simple, we should attempt different splitting strategies
    //       but to do that properly, we first have to refactor the batches to be more flexible

    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);

    std::vector<llama_ubatch> ubatches;
    while (sbatch.n_tokens > 0) {
        ubatches.push_back(sbatch.split_simple(n_ubatch));
    }

    // both caches must be able to accommodate all ubatches before we commit to anything
    auto heads_base = kv_base->prepare(ubatches);
    if (heads_base.empty()) {
        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    auto heads_swa = kv_swa->prepare(ubatches);
    if (heads_swa.empty()) {
        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    assert(heads_base.size() == heads_swa.size());

    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
}

bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
    // no short-circuit: both child caches get a chance to update
    bool res = false;

    res |= kv_base->update(lctx);
    res |= kv_swa ->update(lctx);

    return res;
}

void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
    kv_base->defrag_sched(thold);
    kv_swa ->defrag_sched(thold);
}

bool llama_kv_cache_unified_iswa::get_can_shift() const {
    return kv_base->get_size() == kv_swa->get_size();
}

void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    kv_base->state_write(io, seq_id);
    kv_swa ->state_write(io, seq_id);
}

void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    kv_base->state_read(io, seq_id);
    kv_swa ->state_read(io, seq_id);
}

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
    return kv_base.get();
}

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
    return kv_swa.get();
}

//
// llama_kv_cache_unified_iswa_state
//

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_memory_status status,
        llama_kv_cache_unified_iswa * kv) : status(status) {
    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_memory_status status,
        llama_kv_cache_unified_iswa * kv,
        llama_sbatch sbatch,
        std::vector<uint32_t> heads_base,
        std::vector<uint32_t> heads_swa,
        std::vector<llama_ubatch> ubatches)
    : status(status),
      sbatch(std::move(sbatch)),
      ubatches(std::move(ubatches)) {
    // note: here we copy the ubatches. not sure if this is ideal
    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
}

llama_kv_cache_unified_iswa_state::~llama_kv_cache_unified_iswa_state() = default;

bool llama_kv_cache_unified_iswa_state::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    state_base->next();
    state_swa ->next();

    // false once every ubatch has been consumed
    return ++i_next < ubatches.size();
}

bool llama_kv_cache_unified_iswa_state::apply() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    // deliberately no short-circuit: apply to both sub-states
    bool res = true;

    res &= state_base->apply();
    res &= state_swa ->apply();

    return res;
}

std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return sbatch.out_ids;
}

llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
    return ubatches[i_next];
}

const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return state_base.get();
}

const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return state_swa.get();
}
#pragma once

#include "llama-kv-cache-unified.h"

#include <memory>  // std::unique_ptr members below; do not rely on transitive includes
#include <vector>

//
// llama_kv_cache_unified_iswa
//

// utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers

class llama_kv_cache_unified_iswa : public llama_kv_cache {
public:
    llama_kv_cache_unified_iswa(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   v_trans,
                         bool   offload,
                         bool   swa_full,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max,
                     uint32_t   n_ubatch,
                     uint32_t   n_pad);

    ~llama_kv_cache_unified_iswa() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    //
    // llama_kv_cache_unified_iswa specific API
    //

    llama_kv_cache_unified * get_base() const;
    llama_kv_cache_unified * get_swa () const;

private:
    const llama_hparams & hparams;

    std::unique_ptr<llama_kv_cache_unified> kv_base;
    std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

// memory state pairing one unified-cache state per child cache
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
public:
    // used for errors
    llama_kv_cache_unified_iswa_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_unified_iswa_state(
            llama_memory_status status,
            llama_kv_cache_unified_iswa * kv);

    // used to create a state from a batch
    llama_kv_cache_unified_iswa_state(
            llama_memory_status status,
            llama_kv_cache_unified_iswa * kv,
            llama_sbatch sbatch,
            std::vector<uint32_t> heads_base,
            std::vector<uint32_t> heads_swa,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_iswa_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_iswa_state specific API
    //

    const llama_kv_cache_unified_state * get_base() const;
    const llama_kv_cache_unified_state * get_swa () const;

private:
    const llama_memory_status status;

    //llama_kv_cache_unified_iswa * kv;

    llama_sbatch sbatch;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    std::unique_ptr<llama_kv_cache_unified_state> state_base;
    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
};