@fugood/llama.node 1.1.8 → 1.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. package/lib/binding.ts +9 -0
  2. package/lib/index.js +9 -2
  3. package/lib/index.ts +57 -30
  4. package/lib/version.js +2 -2
  5. package/lib/version.ts +2 -2
  6. package/package.json +14 -14
  7. package/scripts/llama.cpp.patch +15 -5
  8. package/src/LlamaCompletionWorker.cpp +12 -3
  9. package/src/LlamaCompletionWorker.h +3 -1
  10. package/src/LlamaContext.cpp +14 -1
  11. package/src/llama.cpp/common/arg.cpp +6 -4
  12. package/src/llama.cpp/common/chat.cpp +34 -3
  13. package/src/llama.cpp/common/common.cpp +0 -15
  14. package/src/llama.cpp/common/common.h +1 -2
  15. package/src/llama.cpp/ggml/CMakeLists.txt +0 -1
  16. package/src/llama.cpp/ggml/include/ggml.h +25 -0
  17. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +316 -0
  18. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -2
  19. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  20. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
  21. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +142 -0
  22. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  23. package/src/llama.cpp/include/llama.h +1 -110
  24. package/src/llama.cpp/src/CMakeLists.txt +2 -2
  25. package/src/llama.cpp/src/llama-arch.cpp +19 -0
  26. package/src/llama.cpp/src/llama-arch.h +1 -0
  27. package/src/llama.cpp/src/llama-chat.cpp +13 -2
  28. package/src/llama.cpp/src/llama-chat.h +1 -0
  29. package/src/llama.cpp/src/llama-context.cpp +5 -192
  30. package/src/llama.cpp/src/llama-context.h +2 -7
  31. package/src/llama.cpp/src/llama-cparams.h +0 -1
  32. package/src/llama.cpp/src/llama-graph.cpp +35 -57
  33. package/src/llama.cpp/src/llama-graph.h +36 -46
  34. package/src/llama.cpp/src/llama-hparams.cpp +25 -0
  35. package/src/llama.cpp/src/llama-hparams.h +6 -0
  36. package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +69 -52
  37. package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +28 -26
  38. package/src/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +123 -474
  39. package/src/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +34 -59
  40. package/src/llama.cpp/src/llama-kv-cells.h +21 -21
  41. package/src/llama.cpp/src/llama-memory-hybrid.cpp +34 -33
  42. package/src/llama.cpp/src/llama-memory-hybrid.h +24 -28
  43. package/src/llama.cpp/src/llama-memory-recurrent.cpp +7 -7
  44. package/src/llama.cpp/src/llama-memory-recurrent.h +8 -12
  45. package/src/llama.cpp/src/llama-memory.h +11 -8
  46. package/src/llama.cpp/src/llama-model.cpp +396 -187
  47. package/src/llama.cpp/src/llama-model.h +1 -0
package/src/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h}

@@ -14,27 +14,13 @@ struct llama_model;
 struct llama_context;
 
 //
-// llama_kv_cache_unified
+// llama_kv_cache
 //
 
-class llama_kv_cache_unified : public llama_memory_i {
+class llama_kv_cache : public llama_memory_i {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);
 
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
-    struct defrag_info {
-        bool empty() const {
-            return ids.empty();
-        }
-
-        // contains information about which cell moves where:
-        //  - cell i moves to ids[i]
-        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
-        std::vector<uint32_t> ids;
-    };
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());
@@ -92,21 +78,22 @@ public:
     using slot_info_vec_t = std::vector<slot_info>;
 
-    llama_kv_cache_unified(
-            const llama_model & model,
-            layer_filter_cb && filter,
-            ggml_type type_k,
-            ggml_type type_v,
-            bool v_trans,
-            bool offload,
-            bool unified,
-            uint32_t kv_size,
-            uint32_t n_seq_max,
-            uint32_t n_pad,
-            uint32_t n_swa,
-            llama_swa_type swa_type);
-
-    ~llama_kv_cache_unified() = default;
+    llama_kv_cache(
+            const llama_model & model,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool v_trans,
+            bool offload,
+            bool unified,
+            uint32_t kv_size,
+            uint32_t n_seq_max,
+            uint32_t n_pad,
+            uint32_t n_swa,
+            llama_swa_type swa_type,
+            const layer_filter_cb & filter,
+            const layer_reuse_cb & reuse);
+
+    ~llama_kv_cache() = default;
 
     //
     // llama_memory_i
@@ -140,7 +127,7 @@ public:
     void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
-    // llama_kv_cache_unified specific API
+    // llama_kv_cache specific API
     //
 
     uint32_t get_size() const;
@@ -173,7 +160,7 @@ public:
     // return empty vector on failure
     slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
+    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
 
     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
@@ -241,7 +228,7 @@ private:
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
     std::vector<uint32_t> v_heads;
 
-    std::vector<llama_kv_cells_unified> v_cells;
+    std::vector<llama_kv_cells> v_cells;
 
     // maps from a sequence id to a stream id
     std::vector<uint32_t> seq_to_stream;
@@ -254,9 +241,6 @@ private:
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-    // return non-empty vector if cells have been moved
-    defrag_info defrag_prepare(int32_t n_max_nodes) const;
-
     size_t total_size() const;
 
     size_t size_k_bytes() const;
@@ -277,11 +261,6 @@ private:
             llm_graph_result * res,
             llama_context * lctx) const;
 
-    ggml_cgraph * build_graph_defrag(
-            llm_graph_result * res,
-            llama_context * lctx,
-            const defrag_info & dinfo) const;
-
     struct cell_ranges_t {
         uint32_t strm;
 
@@ -295,35 +274,33 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
 };
 
-class llama_kv_cache_unified_context : public llama_memory_context_i {
+class llama_kv_cache_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
-    using defrag_info = llama_kv_cache_unified::defrag_info;
-    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
+    using stream_copy_info = llama_kv_cache::stream_copy_info;
 
     // used for errors
-    llama_kv_cache_unified_context(llama_memory_status status);
+    llama_kv_cache_context(llama_memory_status status);
 
     // used to create a full-cache context
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv);
+    llama_kv_cache_context(
+            llama_kv_cache * kv);
 
     // used to create an update context
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv,
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
             llama_context * lctx,
             bool do_shift,
-            defrag_info dinfo,
             stream_copy_info sc_info);
 
     // used to create a batch procesing context from a batch
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv,
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
             slot_info_vec_t sinfos,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_context();
+    virtual ~llama_kv_cache_context();
 
     //
     // llama_memory_context_i
@@ -336,7 +313,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_context specific API
+    // llama_kv_cache_context specific API
     //
 
     uint32_t get_n_kv() const;
@@ -365,7 +342,7 @@ public:
 private:
     llama_memory_status status;
 
-    llama_kv_cache_unified * kv;
+    llama_kv_cache * kv;
     llama_context * lctx;
 
     //
@@ -374,8 +351,6 @@ private:
 
     bool do_shift = false;
 
-    defrag_info dinfo;
-
     stream_copy_info sc_info;
 
     //
package/src/llama.cpp/src/llama-kv-cells.h

@@ -11,7 +11,7 @@
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
-class llama_kv_cells_unified {
+class llama_kv_cells {
 public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
@@ -77,30 +77,30 @@ public:
     }
 
     // move cell isrc to idst (used during defrag)
-    void mv(uint32_t isrc, uint32_t idst) {
-        assert(isrc < pos.size());
-        assert(idst < pos.size());
+    //void mv(uint32_t isrc, uint32_t idst) {
+    //    assert(isrc < pos.size());
+    //    assert(idst < pos.size());
 
-        assert(pos[idst] == -1);
-        assert(pos[isrc] != -1);
+    //    assert(pos[idst] == -1);
+    //    assert(pos[isrc] != -1);
 
-        pos  [idst] = pos  [isrc];
-        shift[idst] = shift[isrc];
-        seq  [idst] = seq  [isrc];
+    //    pos  [idst] = pos  [isrc];
+    //    shift[idst] = shift[isrc];
+    //    seq  [idst] = seq  [isrc];
 
-        pos  [isrc] = -1;
-        shift[isrc] = 0;
-        seq  [isrc].reset();
+    //    pos  [isrc] = -1;
+    //    shift[isrc] = 0;
+    //    seq  [isrc].reset();
 
-        used.erase (isrc);
-        used.insert(idst);
-    }
+    //    used.erase (isrc);
+    //    used.insert(idst);
+    //}
 
     // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
-    llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
+    llama_kv_cells cp(uint32_t i, uint32_t n) const {
         assert(i + n <= pos.size());
 
-        llama_kv_cells_unified res;
+        llama_kv_cells res;
 
         res.resize(n);
 
@@ -117,8 +117,8 @@ public:
     }
 
     // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
-        llama_kv_cells_unified res;
+    llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells res;
 
         res.resize(idxs.size());
 
@@ -135,7 +135,7 @@ public:
     }
 
     // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
-    void set(uint32_t i, const llama_kv_cells_unified & other) {
+    void set(uint32_t i, const llama_kv_cells & other) {
         assert(i + other.pos.size() <= pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
@@ -165,7 +165,7 @@ public:
     }
 
     // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
         assert(idxs.size() == other.pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
package/src/llama.cpp/src/llama-memory-hybrid.cpp

@@ -9,32 +9,29 @@
 //
 
 llama_memory_hybrid::llama_memory_hybrid(
-        const llama_model & model,
-        /* attn */
-        ggml_type type_k,
-        ggml_type type_v,
-        bool v_trans,
-        uint32_t kv_size,
-        uint32_t n_pad,
-        uint32_t n_swa,
-        llama_swa_type swa_type,
-        /* recurrent */
-        ggml_type type_r,
-        ggml_type type_s,
-        uint32_t rs_size,
-        /* common */
-        uint32_t n_seq_max,
-        bool offload,
-        bool unified,
-        /* layer filters */
-        layer_filter_cb && filter_attn,
-        layer_filter_cb && filter_recr) :
+        const llama_model & model,
+        /* attn */
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        bool unified,
+        /* layer filters */
+        const layer_filter_cb & filter_attn,
+        const layer_filter_cb & filter_recr) :
     hparams(model.hparams),
-    mem_attn(new llama_kv_cache_unified(
+    mem_attn(new llama_kv_cache(
         model,
-        filter_attn == nullptr ?
-            [&](int32_t il) { return !hparams.is_recurrent(il); }
-            : filter_attn,
         type_k,
         type_v,
         v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max,
         n_pad,
         n_swa,
-        swa_type
+        swa_type,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
+            : filter_attn,
+        nullptr
     )),
     mem_recr(new llama_memory_recurrent(
         model,
-        filter_recr == nullptr ?
-            [&](int32_t il) { return hparams.is_recurrent(il); }
-            : filter_recr,
         type_r,
         type_s,
         offload,
         rs_size,
-        n_seq_max
+        n_seq_max,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return hparams.is_recurrent(il); }
+            : filter_recr
     )) {}
 
 llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
@@ -179,7 +180,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
     mem_recr->state_read(io, seq_id);
 }
 
-llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
     return mem_attn.get();
 }
 
@@ -210,7 +211,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
@@ -248,8 +249,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
+const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
 }
 
 const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
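
The initializer list above now builds the default attention and recurrent filters inline whenever the caller passes nullptr, and hands them to the caches as trailing arguments. A minimal standalone sketch of that fallback pattern follows; it only mirrors the real callback type with a local std::function alias, and the is_recurrent predicate is a hypothetical stand-in for hparams.is_recurrent(il), not code from this package.

#include <cstdint>
#include <functional>
#include <iostream>

// stand-in for llama_memory_i::layer_filter_cb (illustrative only)
using layer_filter_cb = std::function<bool(int32_t il)>;

// mirrors the "filter == nullptr ? default lambda : filter" pattern from the
// llama_memory_hybrid initializer list above
static layer_filter_cb attn_filter_or_default(const layer_filter_cb & filter,
                                              std::function<bool(int32_t)> is_recurrent) {
    return filter == nullptr
        ? layer_filter_cb([is_recurrent](int32_t il) { return !is_recurrent(il); })
        : filter;
}

int main() {
    // hypothetical model: layers 1 and 3 are recurrent, the rest use attention
    auto is_recurrent = [](int32_t il) { return il == 1 || il == 3; };

    // no explicit filter given, so the attention cache keeps only non-recurrent layers
    layer_filter_cb filter = attn_filter_or_default(nullptr, is_recurrent);

    for (int32_t il = 0; il < 4; ++il) {
        std::cout << "layer " << il << " cached for attention: " << filter(il) << "\n";
    }
    return 0;
}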
package/src/llama.cpp/src/llama-memory-hybrid.h

@@ -2,7 +2,7 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 #include "llama-memory.h"
 #include "llama-memory-recurrent.h"
 
@@ -13,36 +13,32 @@
 // llama_memory_hybrid
 //
 
-// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+// utilizes instances of llama_memory_recurrent and llama_kv_cache to
 // support models where each layer may be either attention-based or recurrent
 
 class llama_memory_hybrid : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_hybrid(
         const llama_model & model,
         /* attn */
-        ggml_type type_k,
-        ggml_type type_v,
-        bool v_trans,
-        uint32_t kv_size,
-        uint32_t n_pad,
-        uint32_t n_swa,
-        llama_swa_type swa_type,
-        /* recurrent */
-        ggml_type type_r,
-        ggml_type type_s,
-        uint32_t rs_size,
-        /* common */
-        uint32_t n_seq_max,
-        bool offload,
-        bool unified,
-        /* layer filters */
-        layer_filter_cb && filter_attn = nullptr,
-        layer_filter_cb && filter_recr = nullptr);
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        bool unified,
+        /* layer filters */
+        const layer_filter_cb & filter_attn = nullptr,
+        const layer_filter_cb & filter_recr = nullptr);
 
     ~llama_memory_hybrid() = default;
 
@@ -81,19 +77,19 @@ public:
     // llama_memory_hybrid specific API
     //
 
-    llama_kv_cache_unified * get_mem_attn() const;
+    llama_kv_cache * get_mem_attn() const;
     llama_memory_recurrent * get_mem_recr() const;
 
 private:
     const llama_hparams & hparams;
 
-    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_kv_cache> mem_attn;
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -125,7 +121,7 @@ public:
     // llama_memory_hybrid_context
    //
 
-    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_kv_cache_context * get_attn() const;
     const llama_memory_recurrent_context * get_recr() const;
 
 private:
package/src/llama.cpp/src/llama-memory-recurrent.cpp

@@ -16,13 +16,13 @@
 //
 
 llama_memory_recurrent::llama_memory_recurrent(
-        const llama_model & model,
-        layer_filter_cb && filter,
-        ggml_type type_r,
-        ggml_type type_s,
-        bool offload,
-        uint32_t mem_size,
-        uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model & model,
+        ggml_type type_r,
+        ggml_type type_s,
+        bool offload,
+        uint32_t mem_size,
+        uint32_t n_seq_max,
+        const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     head = 0;
package/src/llama.cpp/src/llama-memory-recurrent.h

@@ -12,21 +12,17 @@
 //
 
 // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
-// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
+// see the implementation of llama_kv_cache_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_recurrent(
-            const llama_model & model,
-            layer_filter_cb && filter,
-            ggml_type type_r,
-            ggml_type type_s,
-            bool offload,
-            uint32_t mem_size,
-            uint32_t n_seq_max);
+            const llama_model & model,
+            ggml_type type_r,
+            ggml_type type_s,
+            bool offload,
+            uint32_t mem_size,
+            uint32_t n_seq_max,
+            const layer_filter_cb & filter);
 
     ~llama_memory_recurrent() = default;
 
package/src/llama.cpp/src/llama-memory.h

@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <memory>
+#include <functional>
 
 struct llama_ubatch;
 
@@ -36,8 +37,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
 
 // the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
-//   - llama_kv_cache_unified_context
-//   - llama_kv_cache_unified_iswa_context
+//   - llama_kv_cache_context
+//   - llama_kv_cache_iswa_context
 //   ...
 //
 // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
 struct llama_memory_i {
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    // this callback is used to specify which layers should reuse memory from other layers
+    // return negative value to indicate that the layer il should not reuse memory
+    using layer_reuse_cb = std::function<int32_t(int32_t il)>;
+
     virtual ~llama_memory_i() = default;
 
     // split the input batch into a set of ubatches and verify that they can fit into the cache
@@ -77,7 +85,7 @@ struct llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_context_ptr init_full() = 0;
 
-    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // prepare for any pending memory updates, such as shifts, copies, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
     virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
@@ -109,8 +117,3 @@ struct llama_memory_i {
 };
 
 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
-
-// TODO: temporary until the llama_kv_cache is removed from the public API
-struct llama_kv_cache : public llama_memory_i {
-    virtual ~llama_kv_cache() = default;
-};
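
The two callback types added to llama_memory_i above are the ones the reworked llama_kv_cache constructor now takes as its trailing parameters. A hedged sketch of what a caller could pass is below; the local aliases only mirror the typedefs shown in the hunk, and the layer policies in the lambdas are purely illustrative assumptions rather than code from this package.

#include <cstdint>
#include <functional>
#include <iostream>

// local aliases matching the typedefs added to llama_memory_i above
using layer_filter_cb = std::function<bool(int32_t il)>;
using layer_reuse_cb  = std::function<int32_t(int32_t il)>;

int main() {
    // illustrative filter: keep every layer in the cache except layer 0
    layer_filter_cb filter = [](int32_t il) { return il != 0; };

    // illustrative reuse policy: layers 4 and above reuse the KV memory of the
    // layer 4 positions earlier; a negative return value means "do not reuse",
    // per the comment on layer_reuse_cb in llama-memory.h
    layer_reuse_cb reuse = [](int32_t il) { return il >= 4 ? il - 4 : -1; };

    for (int32_t il = 0; il < 6; ++il) {
        std::cout << "layer " << il
                  << ": cached=" << filter(il)
                  << " reuses=" << reuse(il) << "\n";
    }
    return 0;
}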