@fugood/llama.node 1.1.8 → 1.1.10
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- package/lib/binding.ts +9 -0
- package/lib/index.js +9 -2
- package/lib/index.ts +57 -30
- package/lib/version.js +2 -2
- package/lib/version.ts +2 -2
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +15 -5
- package/src/LlamaCompletionWorker.cpp +12 -3
- package/src/LlamaCompletionWorker.h +3 -1
- package/src/LlamaContext.cpp +14 -1
- package/src/llama.cpp/common/arg.cpp +6 -4
- package/src/llama.cpp/common/chat.cpp +34 -3
- package/src/llama.cpp/common/common.cpp +0 -15
- package/src/llama.cpp/common/common.h +1 -2
- package/src/llama.cpp/ggml/CMakeLists.txt +0 -1
- package/src/llama.cpp/ggml/include/ggml.h +25 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +316 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +142 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
- package/src/llama.cpp/include/llama.h +1 -110
- package/src/llama.cpp/src/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/llama-arch.cpp +19 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +13 -2
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +5 -192
- package/src/llama.cpp/src/llama-context.h +2 -7
- package/src/llama.cpp/src/llama-cparams.h +0 -1
- package/src/llama.cpp/src/llama-graph.cpp +35 -57
- package/src/llama.cpp/src/llama-graph.h +36 -46
- package/src/llama.cpp/src/llama-hparams.cpp +25 -0
- package/src/llama.cpp/src/llama-hparams.h +6 -0
- package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +69 -52
- package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +28 -26
- package/src/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +123 -474
- package/src/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +34 -59
- package/src/llama.cpp/src/llama-kv-cells.h +21 -21
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +34 -33
- package/src/llama.cpp/src/llama-memory-hybrid.h +24 -28
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +7 -7
- package/src/llama.cpp/src/llama-memory-recurrent.h +8 -12
- package/src/llama.cpp/src/llama-memory.h +11 -8
- package/src/llama.cpp/src/llama-model.cpp +396 -187
- package/src/llama.cpp/src/llama-model.h +1 -0

package/src/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h}

@@ -14,27 +14,13 @@ struct llama_model;
 struct llama_context;
 
 //
-//
+// llama_kv_cache
 //
 
-class
+class llama_kv_cache : public llama_memory_i {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);
 
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
-    struct defrag_info {
-        bool empty() const {
-            return ids.empty();
-        }
-
-        // contains information about which cell moves where:
-        //  - cell i moves to ids[i]
-        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
-        std::vector<uint32_t> ids;
-    };
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());
@@ -92,21 +78,22 @@ public:
 
     using slot_info_vec_t = std::vector<slot_info>;
 
-
-    const llama_model &
-
-    ggml_type
-
-    bool
-    bool
-
-    uint32_t
-    uint32_t
-    uint32_t
-
-
-
-
+    llama_kv_cache(
+            const llama_model & model,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool v_trans,
+            bool offload,
+            bool unified,
+            uint32_t kv_size,
+            uint32_t n_seq_max,
+            uint32_t n_pad,
+            uint32_t n_swa,
+            llama_swa_type swa_type,
+            const layer_filter_cb & filter,
+            const layer_reuse_cb & reuse);
+
+    ~llama_kv_cache() = default;
 
     //
     // llama_memory_i
@@ -140,7 +127,7 @@ public:
     void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
-    //
+    // llama_kv_cache specific API
     //
 
     uint32_t get_size() const;
@@ -173,7 +160,7 @@ public:
     // return empty vector on failure
     slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
-    bool update(llama_context * lctx, bool do_shift, const
+    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
 
     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
@@ -241,7 +228,7 @@ private:
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
     std::vector<uint32_t> v_heads;
 
-    std::vector<
+    std::vector<llama_kv_cells> v_cells;
 
     // maps from a sequence id to a stream id
     std::vector<uint32_t> seq_to_stream;
@@ -254,9 +241,6 @@ private:
     // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-    // return non-empty vector if cells have been moved
-    defrag_info defrag_prepare(int32_t n_max_nodes) const;
-
     size_t total_size() const;
 
     size_t size_k_bytes() const;
@@ -277,11 +261,6 @@ private:
             llm_graph_result * res,
             llama_context * lctx) const;
 
-    ggml_cgraph * build_graph_defrag(
-            llm_graph_result * res,
-            llama_context * lctx,
-            const defrag_info & dinfo) const;
-
     struct cell_ranges_t {
         uint32_t strm;
 
@@ -295,35 +274,33 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
 };
 
-class
+class llama_kv_cache_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using slot_info_vec_t =
-    using
-    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
+    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
+    using stream_copy_info = llama_kv_cache::stream_copy_info;
 
     // used for errors
-
+    llama_kv_cache_context(llama_memory_status status);
 
     // used to create a full-cache context
-
-
+    llama_kv_cache_context(
+            llama_kv_cache * kv);
 
     // used to create an update context
-
-
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
             llama_context * lctx,
             bool do_shift,
-            defrag_info dinfo,
             stream_copy_info sc_info);
 
     // used to create a batch procesing context from a batch
-
-
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
             slot_info_vec_t sinfos,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~
+    virtual ~llama_kv_cache_context();
 
     //
     // llama_memory_context_i
@@ -336,7 +313,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    //
+    // llama_kv_cache_context specific API
     //
 
     uint32_t get_n_kv() const;
@@ -365,7 +342,7 @@ public:
 private:
     llama_memory_status status;
 
-
+    llama_kv_cache * kv;
     llama_context * lctx;
 
     //
@@ -374,8 +351,6 @@ private:
 
     bool do_shift = false;
 
-    defrag_info dinfo;
-
     stream_copy_info sc_info;
 
     //
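
Note: the renamed constructor now takes the layer callbacks (`filter`, `reuse`) as its last parameters. A minimal usage sketch, assuming the internal headers are available; the values below are illustrative only and not taken from the package:

    #include "llama-kv-cache.h"   // internal header, renamed from llama-kv-cache-unified.h
    #include "llama-model.h"

    // Illustrative sketch: construct the renamed cache with no layer filter
    // and no cross-layer reuse. Real callers derive these values from cparams/hparams.
    static llama_memory_ptr make_cache(const llama_model & model) {
        return llama_memory_ptr(new llama_kv_cache(
            model,
            GGML_TYPE_F16,          // type_k
            GGML_TYPE_F16,          // type_v
            /*v_trans   =*/ false,
            /*offload   =*/ true,
            /*unified   =*/ true,
            /*kv_size   =*/ 4096,
            /*n_seq_max =*/ 1,
            /*n_pad     =*/ 32,
            /*n_swa     =*/ 0,
            LLAMA_SWA_TYPE_NONE,
            /*filter =*/ nullptr,   // no layer filter
            /*reuse  =*/ nullptr)); // no cross-layer reuse, as llama_memory_hybrid passes below
    }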

package/src/llama.cpp/src/llama-kv-cells.h

@@ -11,7 +11,7 @@
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
-class
+class llama_kv_cells {
 public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
@@ -77,30 +77,30 @@ public:
     }
 
     // move cell isrc to idst (used during defrag)
-    void mv(uint32_t isrc, uint32_t idst) {
-
-
+    //void mv(uint32_t isrc, uint32_t idst) {
+    //    assert(isrc < pos.size());
+    //    assert(idst < pos.size());
 
-
-
+    //    assert(pos[idst] == -1);
+    //    assert(pos[isrc] != -1);
 
-
-
-
+    //    pos  [idst] = pos  [isrc];
+    //    shift[idst] = shift[isrc];
+    //    seq  [idst] = seq  [isrc];
 
-
-
-
+    //    pos  [isrc] = -1;
+    //    shift[isrc] = 0;
+    //    seq  [isrc].reset();
 
-
-
-    }
+    //    used.erase (isrc);
+    //    used.insert(idst);
+    //}
 
     // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
-
+    llama_kv_cells cp(uint32_t i, uint32_t n) const {
         assert(i + n <= pos.size());
 
-
+        llama_kv_cells res;
 
         res.resize(n);
 
@@ -117,8 +117,8 @@ public:
     }
 
     // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-
-
+    llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells res;
 
         res.resize(idxs.size());
 
@@ -135,7 +135,7 @@ public:
     }
 
     // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
-    void set(uint32_t i, const
+    void set(uint32_t i, const llama_kv_cells & other) {
         assert(i + other.pos.size() <= pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
@@ -165,7 +165,7 @@ public:
     }
 
     // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    void set(const std::vector<uint32_t> & idxs, const
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
         assert(idxs.size() == other.pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
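
The `cp`/`set` pair above is the save/restore mechanism for cell state. A rough round-trip sketch with hypothetical variable names, assuming a populated `llama_kv_cells cells`:

    // snapshot the state of cells [head, head + n) before mutating them
    llama_kv_cells backup = cells.cp(head, n);

    // ... apply a ubatch or otherwise modify cells in [head, head + n) ...

    // roll back by writing the saved state to the same offset
    cells.set(head, backup);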

package/src/llama.cpp/src/llama-memory-hybrid.cpp

@@ -9,32 +9,29 @@
 //
 
 llama_memory_hybrid::llama_memory_hybrid(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        const llama_model & model,
+        /* attn */
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        bool unified,
+        /* layer filters */
+        const layer_filter_cb & filter_attn,
+        const layer_filter_cb & filter_recr) :
     hparams(model.hparams),
-    mem_attn(new
+    mem_attn(new llama_kv_cache(
        model,
-        filter_attn == nullptr ?
-            [&](int32_t il) { return !hparams.is_recurrent(il); }
-            : filter_attn,
        type_k,
        type_v,
        v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
        n_seq_max,
        n_pad,
        n_swa,
-        swa_type
+        swa_type,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
+            : filter_attn,
+        nullptr
    )),
    mem_recr(new llama_memory_recurrent(
        model,
-        filter_recr == nullptr ?
-            [&](int32_t il) { return hparams.is_recurrent(il); }
-            : filter_recr,
        type_r,
        type_s,
        offload,
        rs_size,
-        n_seq_max
+        n_seq_max,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return hparams.is_recurrent(il); }
+            : filter_recr
    )) {}
 
 llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
@@ -179,7 +180,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
     mem_recr->state_read(io, seq_id);
 }
 
-
+llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
     return mem_attn.get();
 }
 
@@ -210,7 +211,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
     std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new
+    ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
@@ -248,8 +249,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     return ubatches[i_next];
 }
 
-const
-    return static_cast<const
+const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
 }
 
 const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {

package/src/llama.cpp/src/llama-memory-hybrid.h

@@ -2,7 +2,7 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache
+#include "llama-kv-cache.h"
 #include "llama-memory.h"
 #include "llama-memory-recurrent.h"
 
@@ -13,36 +13,32 @@
 // llama_memory_hybrid
 //
 
-// utilizes instances of llama_memory_recurrent and
+// utilizes instances of llama_memory_recurrent and llama_kv_cache to
 // support models where each layer may be either attention-based or recurrent
 
 class llama_memory_hybrid : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_hybrid(
         const llama_model & model,
         /* attn */
-        ggml_type
-        ggml_type
-        bool
-        uint32_t
-        uint32_t
-        uint32_t
-        llama_swa_type
-
-        ggml_type
-        ggml_type
-        uint32_t
-
-        uint32_t
-        bool
-        bool
-
-
-
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        bool unified,
+        /* layer filters */
+        const layer_filter_cb & filter_attn = nullptr,
+        const layer_filter_cb & filter_recr = nullptr);
 
     ~llama_memory_hybrid() = default;
 
@@ -81,19 +77,19 @@ public:
     // llama_memory_hybrid specific API
     //
 
-
+    llama_kv_cache * get_mem_attn() const;
     llama_memory_recurrent * get_mem_recr() const;
 
 private:
     const llama_hparams & hparams;
 
-    const std::unique_ptr<
+    const std::unique_ptr<llama_kv_cache> mem_attn;
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t =
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -125,7 +121,7 @@ public:
     // llama_memory_hybrid_context
     //
 
-    const
+    const llama_kv_cache_context * get_attn() const;
     const llama_memory_recurrent_context * get_recr() const;
 
 private:

package/src/llama.cpp/src/llama-memory-recurrent.cpp

@@ -16,13 +16,13 @@
 //
 
 llama_memory_recurrent::llama_memory_recurrent(
-        const llama_model &
-
-        ggml_type
-
-
-        uint32_t
-
+        const llama_model & model,
+        ggml_type type_r,
+        ggml_type type_s,
+        bool offload,
+        uint32_t mem_size,
+        uint32_t n_seq_max,
+        const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     head = 0;

package/src/llama.cpp/src/llama-memory-recurrent.h

@@ -12,21 +12,17 @@
 //
 
 // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
-// see the implementation of
+// see the implementation of llama_kv_cache_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_recurrent(
-            const llama_model &
-
-            ggml_type
-
-
-            uint32_t
-
+            const llama_model & model,
+            ggml_type type_r,
+            ggml_type type_s,
+            bool offload,
+            uint32_t mem_size,
+            uint32_t n_seq_max,
+            const layer_filter_cb & filter);
 
     ~llama_memory_recurrent() = default;
 

package/src/llama.cpp/src/llama-memory.h

@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <memory>
+#include <functional>
 
 struct llama_ubatch;
 
@@ -36,8 +37,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
 
 // the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
-//   -
-//   -
+//   - llama_kv_cache_context
+//   - llama_kv_cache_iswa_context
 //   ...
 //
 // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
 struct llama_memory_i {
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    // this callback is used to specify which layers should reuse memory from other layers
+    // return negative value to indicate that the layer il should not reuse memory
+    using layer_reuse_cb = std::function<int32_t(int32_t il)>;
+
     virtual ~llama_memory_i() = default;
 
     // split the input batch into a set of ubatches and verify that they can fit into the cache
@@ -77,7 +85,7 @@ struct llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_context_ptr init_full() = 0;
 
-    // prepare for any pending memory updates, such as shifts,
+    // prepare for any pending memory updates, such as shifts, copies, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
     virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
@@ -109,8 +117,3 @@ struct llama_memory_i {
 };
 
 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
-
-// TODO: temporary until the llama_kv_cache is removed from the public API
-struct llama_kv_cache : public llama_memory_i {
-    virtual ~llama_kv_cache() = default;
-};
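
With the callbacks promoted to `llama_memory_i`, the memory implementations (`llama_kv_cache`, `llama_memory_recurrent`, `llama_memory_hybrid`) all inherit the same typedefs. A small sketch of what concrete callbacks can look like; the lambdas are illustrative only, not taken from the package:

    #include "llama-memory.h"   // internal header

    // keep only even-numbered layers in the cache
    llama_memory_i::layer_filter_cb filter = [](int32_t il) {
        return il % 2 == 0;
    };

    // let odd layers reuse the memory of the previous layer;
    // a negative return value means the layer does not reuse anything
    llama_memory_i::layer_reuse_cb reuse = [](int32_t il) {
        return il % 2 == 1 ? il - 1 : -1;
    };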