@fugood/llama.node 1.1.8 → 1.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/binding.ts +9 -0
- package/lib/index.js +9 -2
- package/lib/index.ts +57 -30
- package/lib/version.js +2 -2
- package/lib/version.ts +2 -2
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +15 -5
- package/src/LlamaCompletionWorker.cpp +12 -3
- package/src/LlamaCompletionWorker.h +3 -1
- package/src/LlamaContext.cpp +14 -1
- package/src/llama.cpp/common/arg.cpp +6 -4
- package/src/llama.cpp/common/chat.cpp +34 -3
- package/src/llama.cpp/common/common.cpp +0 -15
- package/src/llama.cpp/common/common.h +1 -2
- package/src/llama.cpp/ggml/CMakeLists.txt +0 -1
- package/src/llama.cpp/ggml/include/ggml.h +25 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +316 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +142 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
- package/src/llama.cpp/include/llama.h +1 -110
- package/src/llama.cpp/src/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/llama-arch.cpp +19 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +13 -2
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +5 -192
- package/src/llama.cpp/src/llama-context.h +2 -7
- package/src/llama.cpp/src/llama-cparams.h +0 -1
- package/src/llama.cpp/src/llama-graph.cpp +35 -57
- package/src/llama.cpp/src/llama-graph.h +36 -46
- package/src/llama.cpp/src/llama-hparams.cpp +25 -0
- package/src/llama.cpp/src/llama-hparams.h +6 -0
- package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +69 -52
- package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +28 -26
- package/src/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +123 -474
- package/src/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +34 -59
- package/src/llama.cpp/src/llama-kv-cells.h +21 -21
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +34 -33
- package/src/llama.cpp/src/llama-memory-hybrid.h +24 -28
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +7 -7
- package/src/llama.cpp/src/llama-memory-recurrent.h +8 -12
- package/src/llama.cpp/src/llama-memory.h +11 -8
- package/src/llama.cpp/src/llama-model.cpp +396 -187
- package/src/llama.cpp/src/llama-model.h +1 -0
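
Most of the churn in the vendored llama.cpp sources comes from an upstream refactor that drops the "unified" prefix from the KV-cache types, as the file renames above already hint:

- llama_kv_cache_unified → llama_kv_cache, llama_kv_cache_unified_context → llama_kv_cache_context
- llama_kv_cache_unified_iswa → llama_kv_cache_iswa (and the matching _context type)
- llm_graph_input_attn_kv_unified / llm_graph_input_attn_kv_unified_iswa → llm_graph_input_attn_kv / llm_graph_input_attn_kv_iswa

Beyond the renames, three substantive changes run through the hunks below: build_attn() and build_attn_mha() now take a per-head sinks tensor directly (retiring the temporary build_attn_with_sinks() helper), the iSWA cache constructor accepts layer filter and reuse callbacks, and llama_hparams gains n_layer_kv_from_start, has_kv() and n_layer_kv() so a model can restrict the KV cache to a prefix of its layers.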
--- package/src/llama.cpp/src/llama-graph.h
+++ package/src/llama.cpp/src/llama-graph.h
@@ -19,8 +19,8 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
+class llama_kv_cache_context;
+class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
 
@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ public:
 
     const llama_hparams hparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ public:
     const llama_cparams cparams;
 };
 
-class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+class llm_graph_input_attn_kv : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified(
+    llm_graph_input_attn_kv(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    ~llm_graph_input_attn_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
|
@@ -290,20 +290,20 @@ public:
|
|
|
290
290
|
const llama_hparams hparams;
|
|
291
291
|
const llama_cparams cparams;
|
|
292
292
|
|
|
293
|
-
const
|
|
293
|
+
const llama_kv_cache_context * mctx;
|
|
294
294
|
};
|
|
295
295
|
|
|
296
|
-
class
|
|
296
|
+
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
|
|
297
297
|
public:
|
|
298
|
-
|
|
298
|
+
llm_graph_input_attn_kv_iswa(
|
|
299
299
|
const llama_hparams & hparams,
|
|
300
300
|
const llama_cparams & cparams,
|
|
301
|
-
const
|
|
301
|
+
const llama_kv_cache_iswa_context * mctx) :
|
|
302
302
|
hparams(hparams),
|
|
303
303
|
cparams(cparams),
|
|
304
304
|
mctx(mctx) {
|
|
305
305
|
}
|
|
306
|
-
~
|
|
306
|
+
~llm_graph_input_attn_kv_iswa() = default;
|
|
307
307
|
|
|
308
308
|
void set_input(const llama_ubatch * ubatch) override;
|
|
309
309
|
|
|
@@ -330,7 +330,7 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
             std::unique_ptr<llm_graph_input_rs> inp_rs,
             const llama_memory_hybrid_context * mctx) :
         inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
-    std::unique_ptr<llm_graph_input_rs>              inp_rs;
+    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
-    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
-    llm_graph_input_rs              * get_recr() const { return inp_rs.get(); }
+    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -680,14 +680,14 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-            ggml_tensor * q,     // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k,     // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v,     // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
-            ggml_tensor * kq_b,
-            ggml_tensor * kq_mask,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            ggml_tensor * sinks, // [n_head_q]
-            float kq_scale) const;
+            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            ggml_tensor * sinks,   // [n_head_q]
+            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float         kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
 
@@ -699,50 +699,39 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+    llm_graph_input_attn_kv * build_attn_inp_kv() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale,
-            int il) const;
-
-    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
-    ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             ggml_tensor * sinks, // [n_head_q]
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
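With that, the old build_attn_with_sinks() wrapper, whose own TODO marked it as "temporary to keep the diff small", disappears: every build_attn() overload now takes sinks as a regular parameter, moved ahead of v_mla to match the build_attn_mha() argument order.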
@@ -756,6 +745,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -765,7 +755,7 @@ struct llm_graph_context {
     //
 
     // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //       `llama_memory_recurrent`
--- package/src/llama.cpp/src/llama-hparams.cpp
+++ package/src/llama.cpp/src/llama-hparams.cpp
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
 
     GGML_ABORT("fatal error");
 }
+
+bool llama_hparams::has_kv(uint32_t il) const {
+    if (n_layer_kv_from_start >= 0) {
+        if (il < (uint32_t) n_layer_kv_from_start) {
+            return true;
+        }
+
+        return false;
+    }
+
+    // by default, all layers have kv
+    return true;
+}
+
+uint32_t llama_hparams::n_layer_kv() const {
+    uint32_t res = 0;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (has_kv(il)) {
+            res++;
+        }
+    }
+
+    return res;
+}
--- package/src/llama.cpp/src/llama-hparams.h
+++ package/src/llama.cpp/src/llama-hparams.h
@@ -41,6 +41,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
    uint32_t n_layer;
+    int32_t  n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
     uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
     uint32_t n_pos_per_embd() const;
 
     bool is_swa(uint32_t il) const;
+
+    bool has_kv(uint32_t il) const;
+
+    // number of layers for which has_kv() returns true
+    uint32_t n_layer_kv() const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
--- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp
+++ package/src/llama.cpp/src/llama-kv-cache-iswa.cpp
@@ -1,4 +1,4 @@
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache-iswa.h"
 
 #include "llama-impl.h"
 #include "llama-batch.h"
@@ -8,10 +8,10 @@
 #include <cassert>
 
 //
-// llama_kv_cache_unified_iswa
+// llama_kv_cache_iswa
 //
 
-llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+llama_kv_cache_iswa::llama_kv_cache_iswa(
         const llama_model & model,
                 ggml_type   type_k,
                 ggml_type   type_v,
@@ -22,9 +22,26 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
                  uint32_t   n_ubatch,
-                 uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+                 uint32_t   n_pad,
+    const layer_filter_cb & filter,
+    const layer_reuse_cb  & reuse) : hparams(model.hparams), unified(unified) {
+
+    // chain filters
+    const layer_filter_cb filter_base = [&](int32_t il) {
+        if (filter && !filter(il)) {
+            return false;
+        }
+
+        return !model.hparams.is_swa(il);
+    };
+
+    const layer_filter_cb filter_swa = [&](int32_t il) {
+        if (filter && !filter(il)) {
+            return false;
+        }
+
+        return model.hparams.is_swa(il);
+    };
 
     const uint32_t size_base = kv_size;
 
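The "chain filters" block composes the caller-supplied filter with the SWA split: a layer lands in a given sub-cache only if the external filter (when present) accepts it and its SWA flag matches. A reduced sketch of that composition; layer_filter_cb's exact definition lives in llama-memory.h (also touched by this diff), and here it is assumed to be a std::function<bool(int32_t)>-style callable, which matches how the constructor uses it:

#include <cstdint>
#include <cstdio>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t)>; // assumed shape, matching the usage above

// Compose an optional user filter with a SWA predicate, as the iSWA ctor does.
static layer_filter_cb chain(layer_filter_cb user, std::function<bool(int32_t)> is_swa, bool want_swa) {
    return [=](int32_t il) {
        if (user && !user(il)) {
            return false; // the external filter vetoes the layer outright
        }
        return is_swa(il) == want_swa; // then split by SWA membership
    };
}

int main() {
    // Hypothetical pattern: every other layer is SWA; the user filter drops layer 0.
    auto is_swa = [](int32_t il) { return il % 2 == 1; };
    auto user   = [](int32_t il) { return il != 0; };

    const layer_filter_cb filter_base = chain(user, is_swa, /*want_swa=*/false);
    const layer_filter_cb filter_swa  = chain(user, is_swa, /*want_swa=*/true);

    for (int32_t il = 0; il < 4; ++il) {
        printf("layer %d: base=%d swa=%d\n", il, (int) filter_base(il), (int) filter_swa(il));
    }
    return 0;
}

Defaulting to "no filter" via an empty std::function is what makes the `filter && !filter(il)` guard necessary.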
@@ -40,25 +57,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
 
-    kv_base = std::make_unique<llama_kv_cache_unified>(
-            model, std::move(filter_base), type_k, type_v,
+    kv_base = std::make_unique<llama_kv_cache>(
+            model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, LLAMA_SWA_TYPE_NONE);
+            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
-    kv_swa = std::make_unique<llama_kv_cache_unified>(
-            model, std::move(filter_swa), type_k, type_v,
+    kv_swa = std::make_unique<llama_kv_cache>(
+            model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, hparams.swa_type);
+            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
 }
 
-void llama_kv_cache_unified_iswa::clear(bool data) {
+void llama_kv_cache_iswa::clear(bool data) {
     kv_base->clear(data);
     kv_swa ->clear(data);
 }
 
-bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     bool res = true;
 
     res = res & kv_base->seq_rm(seq_id, p0, p1);
@@ -67,36 +84,36 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     return res;
 }
 
-void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
     kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
-void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
+void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
     kv_base->seq_keep(seq_id);
     kv_swa ->seq_keep(seq_id);
 }
 
-void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     kv_base->seq_add(seq_id, p0, p1, shift);
     kv_swa ->seq_add(seq_id, p0, p1, shift);
 }
 
-void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     kv_base->seq_div(seq_id, p0, p1, d);
     kv_swa ->seq_div(seq_id, p0, p1, d);
 }
 
-llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
     // the base cache is a superset of the SWA cache, so we can just check the SWA cache
     return kv_swa->seq_pos_min(seq_id);
 }
 
-llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
@@ -136,7 +153,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
 
         assert(sinfos_base.size() == sinfos_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+        return std::make_unique<llama_kv_cache_iswa_context>(
                 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
@@ -172,29 +189,29 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
 
         assert(sinfos_base.size() == sinfos_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+        return std::make_unique<llama_kv_cache_iswa_context>(
                 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
+llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_iswa_context>(this);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
+llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
 }
 
-bool llama_kv_cache_unified_iswa::get_can_shift() const {
+bool llama_kv_cache_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
 
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
     }
@@ -202,7 +219,7 @@ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     kv_swa->state_write(io, seq_id, flags);
 }
 
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
         kv_base->state_read(io, seq_id, flags);
     }
@@ -210,29 +227,29 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     kv_swa->state_read(io, seq_id, flags);
 }
 
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
+llama_kv_cache * llama_kv_cache_iswa::get_base() const {
     return kv_base.get();
 }
 
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
+llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
     return kv_swa.get();
 }
 
 //
-// llama_kv_cache_unified_iswa_context
+// llama_kv_cache_iswa_context
 //
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv) :
     ctx_base(kv->get_base()->init_full()),
     ctx_swa (kv->get_swa ()->init_full()),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
         llama_context * lctx,
         bool optimize) :
     ctx_base(kv->get_base()->init_update(lctx, optimize)),
@@ -240,21 +257,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
         slot_info_vec_t sinfos_base,
         slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
+    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context::~llama_kv_cache_unified_iswa_context() = default;
+llama_kv_cache_iswa_context::~llama_kv_cache_iswa_context() = default;
 
-bool llama_kv_cache_unified_iswa_context::next() {
+bool llama_kv_cache_iswa_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     ctx_base->next();
@@ -267,7 +284,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_context::apply() {
+bool llama_kv_cache_iswa_context::apply() {
     assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
@@ -278,24 +295,24 @@ bool llama_kv_cache_unified_iswa_context::apply() {
     return res;
 }
 
-llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
+llama_memory_status llama_kv_cache_iswa_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
+    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
+    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
 }
--- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h
+++ package/src/llama.cpp/src/llama-kv-cache-iswa.h
@@ -1,19 +1,19 @@
 #pragma once
 
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 
 #include <vector>
 
 //
-// llama_kv_cache_unified_iswa
+// llama_kv_cache_iswa
 //
 
-// utilizes two instances of llama_kv_cache_unified
+// utilizes two instances of llama_kv_cache
 //   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
 
-class llama_kv_cache_unified_iswa : public llama_memory_i {
+class llama_kv_cache_iswa : public llama_memory_i {
 public:
-    llama_kv_cache_unified_iswa(
+    llama_kv_cache_iswa(
         const llama_model & model,
                 ggml_type   type_k,
                 ggml_type   type_v,
@@ -24,9 +24,11 @@ public:
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
                  uint32_t   n_ubatch,
-                 uint32_t   n_pad);
+                 uint32_t   n_pad,
+    const layer_filter_cb & filter,
+    const layer_reuse_cb  & reuse);
 
-    ~llama_kv_cache_unified_iswa() = default;
+    ~llama_kv_cache_iswa() = default;
 
     //
     // llama_memory_i
@@ -60,46 +62,46 @@ public:
     void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
-    // llama_kv_cache_unified_iswa specific API
+    // llama_kv_cache_iswa specific API
     //
 
-    llama_kv_cache_unified * get_base() const;
-    llama_kv_cache_unified * get_swa () const;
+    llama_kv_cache * get_base() const;
+    llama_kv_cache * get_swa () const;
 
 private:
     const llama_hparams & hparams;
 
     const bool unified;
 
-    std::unique_ptr<llama_kv_cache_unified> kv_base;
-    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+    std::unique_ptr<llama_kv_cache> kv_base;
+    std::unique_ptr<llama_kv_cache> kv_swa;
 };
 
-class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
+class llama_kv_cache_iswa_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // used for errors
-    llama_kv_cache_unified_iswa_context(llama_memory_status status);
+    llama_kv_cache_iswa_context(llama_memory_status status);
 
     // used to create a full-cache context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv);
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv);
 
     // used to create an update context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
             llama_context * lctx,
             bool optimize);
 
     // used to create a batch processing context from a batch
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
             slot_info_vec_t sinfos_base,
             slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_iswa_context();
+    virtual ~llama_kv_cache_iswa_context();
 
     //
     // llama_memory_context_i
@@ -112,14 +114,14 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_iswa_context specific API
+    // llama_kv_cache_iswa_context specific API
     //
 
-    const llama_kv_cache_unified_context * get_base() const;
-    const llama_kv_cache_unified_context * get_swa () const;
+    const llama_kv_cache_context * get_base() const;
+    const llama_kv_cache_context * get_swa() const;
 
 private:
-    //llama_kv_cache_unified_iswa * kv;
+    //llama_kv_cache_iswa * kv;
 
     // the index of the next ubatch to process
     size_t i_next = 0;