llama-cpp-capacitor 0.0.5 → 0.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/LICENSE +21 -0
- package/cpp/README.md +4 -0
- package/cpp/anyascii.c +22223 -0
- package/cpp/anyascii.h +42 -0
- package/cpp/chat-parser.cpp +393 -0
- package/cpp/chat-parser.h +120 -0
- package/cpp/chat.cpp +2315 -0
- package/cpp/chat.h +221 -0
- package/cpp/common.cpp +1619 -0
- package/cpp/common.h +744 -0
- package/cpp/ggml-alloc.c +1028 -0
- package/cpp/ggml-alloc.h +76 -0
- package/cpp/ggml-backend-impl.h +255 -0
- package/cpp/ggml-backend-reg.cpp +600 -0
- package/cpp/ggml-backend.cpp +2118 -0
- package/cpp/ggml-backend.h +354 -0
- package/cpp/ggml-common.h +1878 -0
- package/cpp/ggml-cpp.h +39 -0
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2512 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +3650 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +1891 -0
- package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml-cpu/arch/x86/quants.c +3820 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +6307 -0
- package/cpp/ggml-cpu/arch-fallback.h +215 -0
- package/cpp/ggml-cpu/binary-ops.cpp +158 -0
- package/cpp/ggml-cpu/binary-ops.h +16 -0
- package/cpp/ggml-cpu/common.h +73 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +525 -0
- package/cpp/ggml-cpu/ggml-cpu.c +3578 -0
- package/cpp/ggml-cpu/ggml-cpu.cpp +672 -0
- package/cpp/ggml-cpu/ops.cpp +10587 -0
- package/cpp/ggml-cpu/ops.h +114 -0
- package/cpp/ggml-cpu/quants.c +1193 -0
- package/cpp/ggml-cpu/quants.h +97 -0
- package/cpp/ggml-cpu/repack.cpp +1982 -0
- package/cpp/ggml-cpu/repack.h +120 -0
- package/cpp/ggml-cpu/simd-mappings.h +1184 -0
- package/cpp/ggml-cpu/traits.cpp +36 -0
- package/cpp/ggml-cpu/traits.h +38 -0
- package/cpp/ggml-cpu/unary-ops.cpp +186 -0
- package/cpp/ggml-cpu/unary-ops.h +28 -0
- package/cpp/ggml-cpu/vec.cpp +348 -0
- package/cpp/ggml-cpu/vec.h +1121 -0
- package/cpp/ggml-cpu.h +145 -0
- package/cpp/ggml-impl.h +622 -0
- package/cpp/ggml-metal-impl.h +688 -0
- package/cpp/ggml-metal.h +66 -0
- package/cpp/ggml-metal.m +6833 -0
- package/cpp/ggml-opt.cpp +1093 -0
- package/cpp/ggml-opt.h +256 -0
- package/cpp/ggml-quants.c +5324 -0
- package/cpp/ggml-quants.h +106 -0
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +14 -0
- package/cpp/ggml.c +7108 -0
- package/cpp/ggml.h +2492 -0
- package/cpp/gguf.cpp +1358 -0
- package/cpp/gguf.h +202 -0
- package/cpp/json-partial.cpp +256 -0
- package/cpp/json-partial.h +38 -0
- package/cpp/json-schema-to-grammar.cpp +985 -0
- package/cpp/json-schema-to-grammar.h +21 -0
- package/cpp/llama-adapter.cpp +388 -0
- package/cpp/llama-adapter.h +76 -0
- package/cpp/llama-arch.cpp +2355 -0
- package/cpp/llama-arch.h +499 -0
- package/cpp/llama-batch.cpp +875 -0
- package/cpp/llama-batch.h +160 -0
- package/cpp/llama-chat.cpp +783 -0
- package/cpp/llama-chat.h +65 -0
- package/cpp/llama-context.cpp +2748 -0
- package/cpp/llama-context.h +306 -0
- package/cpp/llama-cparams.cpp +5 -0
- package/cpp/llama-cparams.h +41 -0
- package/cpp/llama-cpp.h +30 -0
- package/cpp/llama-grammar.cpp +1229 -0
- package/cpp/llama-grammar.h +173 -0
- package/cpp/llama-graph.cpp +1891 -0
- package/cpp/llama-graph.h +810 -0
- package/cpp/llama-hparams.cpp +180 -0
- package/cpp/llama-hparams.h +233 -0
- package/cpp/llama-impl.cpp +167 -0
- package/cpp/llama-impl.h +61 -0
- package/cpp/llama-io.cpp +15 -0
- package/cpp/llama-io.h +35 -0
- package/cpp/llama-kv-cache-iswa.cpp +318 -0
- package/cpp/llama-kv-cache-iswa.h +135 -0
- package/cpp/llama-kv-cache.cpp +2059 -0
- package/cpp/llama-kv-cache.h +374 -0
- package/cpp/llama-kv-cells.h +491 -0
- package/cpp/llama-memory-hybrid.cpp +258 -0
- package/cpp/llama-memory-hybrid.h +137 -0
- package/cpp/llama-memory-recurrent.cpp +1146 -0
- package/cpp/llama-memory-recurrent.h +179 -0
- package/cpp/llama-memory.cpp +59 -0
- package/cpp/llama-memory.h +119 -0
- package/cpp/llama-mmap.cpp +600 -0
- package/cpp/llama-mmap.h +68 -0
- package/cpp/llama-model-loader.cpp +1164 -0
- package/cpp/llama-model-loader.h +170 -0
- package/cpp/llama-model-saver.cpp +282 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +19042 -0
- package/cpp/llama-model.h +491 -0
- package/cpp/llama-sampling.cpp +2575 -0
- package/cpp/llama-sampling.h +32 -0
- package/cpp/llama-vocab.cpp +3792 -0
- package/cpp/llama-vocab.h +176 -0
- package/cpp/llama.cpp +358 -0
- package/cpp/llama.h +1373 -0
- package/cpp/log.cpp +427 -0
- package/cpp/log.h +103 -0
- package/cpp/minja/chat-template.hpp +550 -0
- package/cpp/minja/minja.hpp +3009 -0
- package/cpp/nlohmann/json.hpp +25526 -0
- package/cpp/nlohmann/json_fwd.hpp +187 -0
- package/cpp/regex-partial.cpp +204 -0
- package/cpp/regex-partial.h +56 -0
- package/cpp/rn-completion.cpp +681 -0
- package/cpp/rn-completion.h +116 -0
- package/cpp/rn-llama.cpp +345 -0
- package/cpp/rn-llama.h +149 -0
- package/cpp/rn-mtmd.hpp +602 -0
- package/cpp/rn-tts.cpp +591 -0
- package/cpp/rn-tts.h +59 -0
- package/cpp/sampling.cpp +579 -0
- package/cpp/sampling.h +107 -0
- package/cpp/tools/mtmd/clip-impl.h +473 -0
- package/cpp/tools/mtmd/clip.cpp +4322 -0
- package/cpp/tools/mtmd/clip.h +106 -0
- package/cpp/tools/mtmd/miniaudio/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +769 -0
- package/cpp/tools/mtmd/mtmd-audio.h +47 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +460 -0
- package/cpp/tools/mtmd/mtmd-helper.h +91 -0
- package/cpp/tools/mtmd/mtmd.cpp +1066 -0
- package/cpp/tools/mtmd/mtmd.h +298 -0
- package/cpp/tools/mtmd/stb/stb_image.h +7988 -0
- package/cpp/unicode-data.cpp +7034 -0
- package/cpp/unicode-data.h +20 -0
- package/cpp/unicode.cpp +1061 -0
- package/cpp/unicode.h +68 -0
- package/package.json +2 -1
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
#include "llama-kv-cache-iswa.h"
|
|
2
|
+
|
|
3
|
+
#include "llama-impl.h"
|
|
4
|
+
#include "llama-batch.h"
|
|
5
|
+
#include "llama-model.h"
|
|
6
|
+
|
|
7
|
+
#include <algorithm>
|
|
8
|
+
#include <cassert>
|
|
9
|
+
|
|
10
|
+
//
|
|
11
|
+
// llama_kv_cache_iswa
|
|
12
|
+
//
|
|
13
|
+
|
|
14
|
+
// Construct an interleaved-SWA cache: two llama_kv_cache instances, one for the
// non-SWA layers (kv_base, full-size) and one for the SWA layers (kv_swa, which
// only needs to retain a sliding window of positions).
//
// Parameters:
//   model     - source of hparams and the per-layer is_swa() classification
//   type_k/v  - tensor types for the K and V caches
//   v_trans   - whether V is stored transposed
//   offload   - whether to offload cache buffers
//   swa_full  - if true, size the SWA cache the same as the base cache
//   unified   - unified (shared-across-sequences) cache layout
//   kv_size   - cell count of the base cache
//   n_seq_max / n_ubatch / n_pad - sizing inputs for the SWA cache
//   filter    - optional user layer filter, chained with the SWA/non-SWA split
//   reuse     - layer-reuse callback, forwarded to both sub-caches
llama_kv_cache_iswa::llama_kv_cache_iswa(
        const llama_model & model,
        lm_ggml_type type_k,
        lm_ggml_type type_v,
        bool v_trans,
        bool offload,
        bool swa_full,
        bool unified,
        uint32_t kv_size,
        uint32_t n_seq_max,
        uint32_t n_ubatch,
        uint32_t n_pad,
        const layer_filter_cb & filter,
        const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {

    // chain filters: base cache takes only the layers that are NOT SWA,
    // and the SWA cache takes only the SWA layers; a user-provided filter
    // can veto a layer for both caches.
    // NOTE(review): these lambdas capture `filter` and `model` by reference;
    // assumed to be invoked only during llama_kv_cache construction below — confirm.
    const layer_filter_cb filter_base = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }

        return !model.hparams.is_swa(il);
    };

    const layer_filter_cb filter_swa = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }

        return model.hparams.is_swa(il);
    };

    const uint32_t size_base = kv_size;

    // SWA cache needs n_swa positions (per sequence when unified) plus one ubatch,
    // padded to n_pad — but never more than the base cache size
    uint32_t size_swa = std::min(size_base, LM_GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));

    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
    if (swa_full) {
        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

        size_swa = size_base;
    }

    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

    // base cache: no sliding window (n_swa = 0, LLAMA_SWA_TYPE_NONE)
    kv_base = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
            v_trans, offload, unified, size_base, n_seq_max, n_pad,
            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);

    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

    // SWA cache: inherits the model's window size and SWA type
    kv_swa = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}
|
|
72
|
+
|
|
73
|
+
void llama_kv_cache_iswa::clear(bool data) {
|
|
74
|
+
kv_base->clear(data);
|
|
75
|
+
kv_swa ->clear(data);
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
// Remove a position range of a sequence from both caches.
// Both removals are always attempted (no short-circuit); the result is
// true only if both succeed.
bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    const bool ok_base = kv_base->seq_rm(seq_id, p0, p1);
    const bool ok_swa  = kv_swa ->seq_rm(seq_id, p0, p1);

    return ok_base && ok_swa;
}
|
|
86
|
+
|
|
87
|
+
// Copy a position range from one sequence to another, mirrored in both caches
// so they stay in sync.
void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    for (auto * kv : { kv_base.get(), kv_swa.get() }) {
        kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    }
}
|
|
91
|
+
|
|
92
|
+
// Keep only the given sequence, applied to both caches.
void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
    for (auto * kv : { kv_base.get(), kv_swa.get() }) {
        kv->seq_keep(seq_id);
    }
}
|
|
96
|
+
|
|
97
|
+
// Shift the positions of a sequence range by `shift`, mirrored in both caches.
void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    for (auto * kv : { kv_base.get(), kv_swa.get() }) {
        kv->seq_add(seq_id, p0, p1, shift);
    }
}
|
|
101
|
+
|
|
102
|
+
// Divide the positions of a sequence range by `d`, mirrored in both caches.
void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    for (auto * kv : { kv_base.get(), kv_swa.get() }) {
        kv->seq_div(seq_id, p0, p1, d);
    }
}
|
|
106
|
+
|
|
107
|
+
// Minimum position stored for a sequence; answered from the SWA cache only.
llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
    return kv_swa->seq_pos_min(seq_id);
}
|
|
111
|
+
|
|
112
|
+
// Maximum position stored for a sequence; answered from the SWA cache
// (same reasoning as seq_pos_min: the base cache is a superset).
llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
    return kv_swa->seq_pos_max(seq_id);
}
|
|
115
|
+
|
|
116
|
+
// Split the incoming batch into ubatches and reserve cache slots in BOTH
// sub-caches. Two strategies are attempted in order:
//   1) simple split  — only valid for the unified layout;
//   2) equal split   — fallback, usable for both layouts.
// Each `do { ... } while (false)` is a single-pass scope so `break` can abort
// a strategy at any step. If both strategies fail, a context carrying
// LLAMA_MEMORY_STATUS_FAILED_PREPARE is returned.
llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    LM_GGML_UNUSED(embd_all);

    // first try simple split
    do {
        if (!unified) {
            // requires equal splits, so we skip the simple split
            break;
        }

        balloc.split_reset();

        // drain the allocator into ubatches of at most n_ubatch tokens
        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = balloc.split_simple(n_ubatch);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        // reserve slots in both caches; failure of either aborts this strategy
        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_swa = kv_swa->prepare(ubatches);
        if (sinfos_swa.empty()) {
            break;
        }

        // one slot-info entry per ubatch in each cache
        assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_iswa_context>(
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // if it fails, try equal split
    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            // for the non-unified layout the split must be sequence-exclusive
            auto ubatch = balloc.split_equal(n_ubatch, !unified);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_swa = kv_swa->prepare(ubatches);
        if (sinfos_swa.empty()) {
            break;
        }

        assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_iswa_context>(
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // TODO: if we fail again, we should attempt different splitting strategies
    // but to do that properly, we first have to refactor the batches to be more flexible

    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
|
|
201
|
+
|
|
202
|
+
// Build a context that spans the full state of both sub-caches.
llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
    auto ctx = std::make_unique<llama_kv_cache_iswa_context>(this);
    return ctx;
}
|
|
205
|
+
|
|
206
|
+
// Build an update (defrag/shift) context covering both sub-caches.
llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
    auto ctx = std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
    return ctx;
}
|
|
209
|
+
|
|
210
|
+
bool llama_kv_cache_iswa::get_can_shift() const {
|
|
211
|
+
return kv_base->get_size() == kv_swa->get_size();
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
// Serialize cache state. With LLAMA_STATE_SEQ_FLAGS_SWA_ONLY set, only the SWA
// cache is written; otherwise both caches are written (base first).
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    const bool swa_only = (flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) != 0;

    if (!swa_only) {
        kv_base->state_write(io, seq_id, flags);
    }

    kv_swa->state_write(io, seq_id, flags);
}
|
|
221
|
+
|
|
222
|
+
// Deserialize cache state; mirrors state_write: the base cache is skipped when
// LLAMA_STATE_SEQ_FLAGS_SWA_ONLY is set.
void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    const bool swa_only = (flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) != 0;

    if (!swa_only) {
        kv_base->state_read(io, seq_id, flags);
    }

    kv_swa->state_read(io, seq_id, flags);
}
|
|
229
|
+
|
|
230
|
+
// Non-owning pointer to the non-SWA (base) cache; ownership stays with this object.
llama_kv_cache * llama_kv_cache_iswa::get_base() const {
    return kv_base.get();
}
|
|
233
|
+
|
|
234
|
+
// Non-owning pointer to the SWA cache; ownership stays with this object.
llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
    return kv_swa.get();
}
|
|
237
|
+
|
|
238
|
+
//
|
|
239
|
+
// llama_kv_cache_iswa_context
|
|
240
|
+
//
|
|
241
|
+
|
|
242
|
+
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
|
|
243
|
+
|
|
244
|
+
// Full-cache context: wraps a full-range context of each sub-cache.
// `status` is initialized last (it is declared last in the class), so both
// sub-contexts already exist when their statuses are combined.
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv) :
    ctx_base(kv->get_base()->init_full()),
    ctx_swa (kv->get_swa ()->init_full()),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
|
|
250
|
+
|
|
251
|
+
// Update context: wraps an update (shift/defrag) context of each sub-cache;
// the combined status reflects both.
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv,
        llama_context * lctx,
        bool optimize) :
    ctx_base(kv->get_base()->init_update(lctx, optimize)),
    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
|
|
259
|
+
|
|
260
|
+
// Batch-processing context: takes ownership of the ubatches and the per-cache
// slot infos produced by init_batch().
// NOTE: member-declaration order is load-bearing — `ubatches` is declared before
// ctx_base/ctx_swa in the class, so it is initialized first and the two
// sub-contexts can safely reference this->ubatches.
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv,
        slot_info_vec_t sinfos_base,
        slot_info_vec_t sinfos_swa,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
|
|
271
|
+
|
|
272
|
+
llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
|
|
273
|
+
|
|
274
|
+
bool llama_kv_cache_iswa_context::next() {
|
|
275
|
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
276
|
+
|
|
277
|
+
ctx_base->next();
|
|
278
|
+
ctx_swa ->next();
|
|
279
|
+
|
|
280
|
+
if (++i_next >= ubatches.size()) {
|
|
281
|
+
return false;
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
return true;
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
// Apply the current ubatch to both sub-contexts. Both applies always run
// (no short-circuit); returns true only if both succeed.
bool llama_kv_cache_iswa_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    const bool ok_base = ctx_base->apply();
    const bool ok_swa  = ctx_swa ->apply();

    return ok_base && ok_swa;
}
|
|
297
|
+
|
|
298
|
+
// Combined status of the two sub-contexts (set once at construction).
llama_memory_status llama_kv_cache_iswa_context::get_status() const {
    return status;
}
|
|
301
|
+
|
|
302
|
+
// The ubatch currently being processed (indexed by the i_next cursor).
const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}
|
|
307
|
+
|
|
308
|
+
// Downcast accessor for the base (non-SWA) sub-context.
// The static_cast is safe by construction: ctx_base is always created as a
// llama_kv_cache context by this class's constructors.
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
}
|
|
313
|
+
|
|
314
|
+
// Downcast accessor for the SWA sub-context (see get_base for the cast rationale).
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
}
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include "llama-kv-cache.h"
|
|
4
|
+
|
|
5
|
+
#include <vector>
|
|
6
|
+
|
|
7
|
+
//
|
|
8
|
+
// llama_kv_cache_iswa
|
|
9
|
+
//
|
|
10
|
+
|
|
11
|
+
// utilizes two instances of llama_kv_cache
|
|
12
|
+
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
|
13
|
+
|
|
14
|
+
// Composite KV cache for models with interleaved sliding-window attention:
// holds two llama_kv_cache instances — one for the non-SWA layers (full-size)
// and one for the SWA layers (window-sized) — and fans every llama_memory_i
// operation out to both.
class llama_kv_cache_iswa : public llama_memory_i {
public:
    llama_kv_cache_iswa(
            const llama_model & model,
            lm_ggml_type type_k,
            lm_ggml_type type_v,
            bool v_trans,
            bool offload,
            bool swa_full,    // size the SWA cache equal to the base cache
            bool unified,
            uint32_t kv_size,
            uint32_t n_seq_max,
            uint32_t n_ubatch,
            uint32_t n_pad,
            const layer_filter_cb & filter,
            const layer_reuse_cb & reuse);

    ~llama_kv_cache_iswa() = default;

    //
    // llama_memory_i
    //

    // split a batch into ubatches and reserve slots in both caches
    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    // sequence ops — each is applied to both sub-caches
    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_kv_cache_iswa specific API
    //

    // non-owning accessors for the two sub-caches
    llama_kv_cache * get_base() const;
    llama_kv_cache * get_swa () const;

private:
    const llama_hparams & hparams;

    // unified (shared-across-sequences) layout flag, fixed at construction
    const bool unified;

    std::unique_ptr<llama_kv_cache> kv_base;  // non-SWA layers
    std::unique_ptr<llama_kv_cache> kv_swa;   // SWA layers
};
|
|
79
|
+
|
|
80
|
+
// Memory context pairing one sub-context per sub-cache; steps and applies
// both in lock-step.
// NOTE: the member declaration order below is load-bearing — the batch
// constructor initializes `ubatches` first, then the two contexts (which
// reference this->ubatches), then the combined `status`.
class llama_kv_cache_iswa_context : public llama_memory_context_i {
public:
    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;

    // used for errors
    llama_kv_cache_iswa_context(llama_memory_status status);

    // used to create a full-cache context
    llama_kv_cache_iswa_context(
            llama_kv_cache_iswa * kv);

    // used to create an update context
    llama_kv_cache_iswa_context(
            llama_kv_cache_iswa * kv,
            llama_context * lctx,
            bool optimize);

    // used to create a batch processing context from a batch
    llama_kv_cache_iswa_context(
            llama_kv_cache_iswa * kv,
            slot_info_vec_t sinfos_base,
            slot_info_vec_t sinfos_swa,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_iswa_context();

    //
    // llama_memory_context_i
    //

    bool next() override;
    bool apply() override;

    llama_memory_status get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_iswa_context specific API
    //

    // downcast accessors for the two sub-contexts
    const llama_kv_cache_context * get_base() const;
    const llama_kv_cache_context * get_swa() const;

private:
    //llama_kv_cache_iswa * kv;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    const llama_memory_context_ptr ctx_base;
    const llama_memory_context_ptr ctx_swa;

    // combined status of ctx_base and ctx_swa, fixed at construction
    const llama_memory_status status;
};
|