@fugood/llama.node 1.1.8 → 1.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. package/lib/binding.ts +9 -0
  2. package/lib/index.js +9 -2
  3. package/lib/index.ts +57 -30
  4. package/lib/version.js +2 -2
  5. package/lib/version.ts +2 -2
  6. package/package.json +14 -14
  7. package/scripts/llama.cpp.patch +15 -5
  8. package/src/LlamaCompletionWorker.cpp +12 -3
  9. package/src/LlamaCompletionWorker.h +3 -1
  10. package/src/LlamaContext.cpp +14 -1
  11. package/src/llama.cpp/common/arg.cpp +6 -4
  12. package/src/llama.cpp/common/chat.cpp +34 -3
  13. package/src/llama.cpp/common/common.cpp +0 -15
  14. package/src/llama.cpp/common/common.h +1 -2
  15. package/src/llama.cpp/ggml/CMakeLists.txt +0 -1
  16. package/src/llama.cpp/ggml/include/ggml.h +25 -0
  17. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +316 -0
  18. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -2
  19. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  20. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
  21. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +142 -0
  22. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  23. package/src/llama.cpp/include/llama.h +1 -110
  24. package/src/llama.cpp/src/CMakeLists.txt +2 -2
  25. package/src/llama.cpp/src/llama-arch.cpp +19 -0
  26. package/src/llama.cpp/src/llama-arch.h +1 -0
  27. package/src/llama.cpp/src/llama-chat.cpp +13 -2
  28. package/src/llama.cpp/src/llama-chat.h +1 -0
  29. package/src/llama.cpp/src/llama-context.cpp +5 -192
  30. package/src/llama.cpp/src/llama-context.h +2 -7
  31. package/src/llama.cpp/src/llama-cparams.h +0 -1
  32. package/src/llama.cpp/src/llama-graph.cpp +35 -57
  33. package/src/llama.cpp/src/llama-graph.h +36 -46
  34. package/src/llama.cpp/src/llama-hparams.cpp +25 -0
  35. package/src/llama.cpp/src/llama-hparams.h +6 -0
  36. package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +69 -52
  37. package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +28 -26
  38. package/src/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +123 -474
  39. package/src/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +34 -59
  40. package/src/llama.cpp/src/llama-kv-cells.h +21 -21
  41. package/src/llama.cpp/src/llama-memory-hybrid.cpp +34 -33
  42. package/src/llama.cpp/src/llama-memory-hybrid.h +24 -28
  43. package/src/llama.cpp/src/llama-memory-recurrent.cpp +7 -7
  44. package/src/llama.cpp/src/llama-memory-recurrent.h +8 -12
  45. package/src/llama.cpp/src/llama-memory.h +11 -8
  46. package/src/llama.cpp/src/llama-model.cpp +396 -187
  47. package/src/llama.cpp/src/llama-model.h +1 -0
@@ -19,8 +19,8 @@ struct llama_cparams;
19
19
 
20
20
  struct llama_memory_context_i;
21
21
 
22
- class llama_kv_cache_unified_context;
23
- class llama_kv_cache_unified_iswa_context;
22
+ class llama_kv_cache_context;
23
+ class llama_kv_cache_iswa_context;
24
24
  class llama_memory_recurrent_context;
25
25
  class llama_memory_hybrid_context;
26
26
 
@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
152
152
  public:
153
153
  llm_graph_input_pos_bucket_kv(
154
154
  const llama_hparams & hparams,
155
- const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
155
+ const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
156
156
  virtual ~llm_graph_input_pos_bucket_kv() = default;
157
157
 
158
158
  void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ public:
161
161
 
162
162
  const llama_hparams hparams;
163
163
 
164
- const llama_kv_cache_unified_context * mctx;
164
+ const llama_kv_cache_context * mctx;
165
165
  };
166
166
 
167
167
  class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ public:
257
257
  const llama_cparams cparams;
258
258
  };
259
259
 
260
- class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
260
+ class llm_graph_input_attn_kv : public llm_graph_input_i {
261
261
  public:
262
- llm_graph_input_attn_kv_unified(
262
+ llm_graph_input_attn_kv(
263
263
  const llama_hparams & hparams,
264
264
  const llama_cparams & cparams,
265
- const llama_kv_cache_unified_context * mctx) :
265
+ const llama_kv_cache_context * mctx) :
266
266
  hparams(hparams),
267
267
  cparams(cparams),
268
268
  mctx(mctx) {
269
269
  }
270
- ~llm_graph_input_attn_kv_unified() = default;
270
+ ~llm_graph_input_attn_kv() = default;
271
271
 
272
272
  void set_input(const llama_ubatch * ubatch) override;
273
273
 
@@ -290,20 +290,20 @@ public:
290
290
  const llama_hparams hparams;
291
291
  const llama_cparams cparams;
292
292
 
293
- const llama_kv_cache_unified_context * mctx;
293
+ const llama_kv_cache_context * mctx;
294
294
  };
295
295
 
296
- class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
296
+ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
297
297
  public:
298
- llm_graph_input_attn_kv_unified_iswa(
298
+ llm_graph_input_attn_kv_iswa(
299
299
  const llama_hparams & hparams,
300
300
  const llama_cparams & cparams,
301
- const llama_kv_cache_unified_iswa_context * mctx) :
301
+ const llama_kv_cache_iswa_context * mctx) :
302
302
  hparams(hparams),
303
303
  cparams(cparams),
304
304
  mctx(mctx) {
305
305
  }
306
- ~llm_graph_input_attn_kv_unified_iswa() = default;
306
+ ~llm_graph_input_attn_kv_iswa() = default;
307
307
 
308
308
  void set_input(const llama_ubatch * ubatch) override;
309
309
 
@@ -330,7 +330,7 @@ public:
330
330
  const llama_hparams hparams;
331
331
  const llama_cparams cparams;
332
332
 
333
- const llama_kv_cache_unified_iswa_context * mctx;
333
+ const llama_kv_cache_iswa_context * mctx;
334
334
  };
335
335
 
336
336
  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ public:
351
351
  class llm_graph_input_mem_hybrid : public llm_graph_input_i {
352
352
  public:
353
353
  llm_graph_input_mem_hybrid(
354
- std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
354
+ std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
355
355
  std::unique_ptr<llm_graph_input_rs> inp_rs,
356
356
  const llama_memory_hybrid_context * mctx) :
357
357
  inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ public:
361
361
 
362
362
  void set_input(const llama_ubatch * ubatch) override;
363
363
 
364
- std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
365
- std::unique_ptr<llm_graph_input_rs> inp_rs;
364
+ std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
365
+ std::unique_ptr<llm_graph_input_rs> inp_rs;
366
366
 
367
- llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
368
- llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
367
+ llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
368
+ llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
369
369
 
370
370
  const llama_memory_hybrid_context * mctx;
371
371
  };
@@ -680,14 +680,14 @@ struct llm_graph_context {
680
680
  //
681
681
 
682
682
  ggml_tensor * build_attn_mha(
683
- ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
684
- ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
685
- ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
686
- ggml_tensor * kq_b,
687
- ggml_tensor * kq_mask,
688
- ggml_tensor * sinks,
689
- ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
690
- float kq_scale) const;
683
+ ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
684
+ ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
685
+ ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
686
+ ggml_tensor * kq_b,
687
+ ggml_tensor * kq_mask,
688
+ ggml_tensor * sinks, // [n_head_q]
689
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
690
+ float kq_scale) const;
691
691
 
692
692
  llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
693
693
 
@@ -699,50 +699,39 @@ struct llm_graph_context {
699
699
  ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
700
700
  ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
701
701
  ggml_tensor * kq_b,
702
+ ggml_tensor * sinks, // [n_head_q]
702
703
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
703
704
  float kq_scale,
704
705
  int il) const;
705
706
 
706
- llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
707
+ llm_graph_input_attn_kv * build_attn_inp_kv() const;
707
708
 
708
709
  ggml_tensor * build_attn(
709
- llm_graph_input_attn_kv_unified * inp,
710
+ llm_graph_input_attn_kv * inp,
710
711
  ggml_tensor * wo,
711
712
  ggml_tensor * wo_b,
712
713
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
713
714
  ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
714
715
  ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
715
716
  ggml_tensor * kq_b,
717
+ ggml_tensor * sinks, // [n_head_q]
716
718
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
717
719
  float kq_scale,
718
720
  int il) const;
719
721
 
720
- llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
722
+ llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
721
723
 
722
724
  // note: if k_cur or v_cur are not provided, they will not be stored in the memory
723
725
  ggml_tensor * build_attn(
724
- llm_graph_input_attn_kv_unified_iswa * inp,
725
- ggml_tensor * wo,
726
- ggml_tensor * wo_b,
727
- ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
728
- ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
729
- ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
730
- ggml_tensor * kq_b,
731
- ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
732
- float kq_scale,
733
- int il) const;
734
-
735
- // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
736
- ggml_tensor * build_attn_with_sinks(
737
- llm_graph_input_attn_kv_unified_iswa * inp,
726
+ llm_graph_input_attn_kv_iswa * inp,
738
727
  ggml_tensor * wo,
739
728
  ggml_tensor * wo_b,
740
729
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
741
730
  ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
742
731
  ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
743
732
  ggml_tensor * kq_b,
744
- ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
745
733
  ggml_tensor * sinks, // [n_head_q]
734
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
746
735
  float kq_scale,
747
736
  int il) const;
748
737
 
@@ -756,6 +745,7 @@ struct llm_graph_context {
756
745
  ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
757
746
  ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
758
747
  ggml_tensor * kq_b,
748
+ ggml_tensor * sinks, // [n_head_q]
759
749
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
760
750
  float kq_scale,
761
751
  int il) const;
@@ -765,7 +755,7 @@ struct llm_graph_context {
765
755
  //
766
756
 
767
757
  // TODO: move this implementation to llama_memory_recurrent.
768
- // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
758
+ // this is analogous to llama_kv_cache::cpy_k / cpy_v
769
759
  // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
770
760
  // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
771
761
  // `llama_memory_recurrent`
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
153
153
 
154
154
  GGML_ABORT("fatal error");
155
155
  }
156
+
157
+ bool llama_hparams::has_kv(uint32_t il) const {
158
+ if (n_layer_kv_from_start >= 0) {
159
+ if (il < (uint32_t) n_layer_kv_from_start) {
160
+ return true;
161
+ }
162
+
163
+ return false;
164
+ }
165
+
166
+ // by default, all layers have kv
167
+ return true;
168
+ }
169
+
170
+ uint32_t llama_hparams::n_layer_kv() const {
171
+ uint32_t res = 0;
172
+
173
+ for (uint32_t il = 0; il < n_layer; ++il) {
174
+ if (has_kv(il)) {
175
+ res++;
176
+ }
177
+ }
178
+
179
+ return res;
180
+ }
@@ -41,6 +41,7 @@ struct llama_hparams {
41
41
  uint32_t n_embd;
42
42
  uint32_t n_embd_features = 0;
43
43
  uint32_t n_layer;
44
+ int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
44
45
  uint32_t n_rot;
45
46
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
46
47
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
221
222
  uint32_t n_pos_per_embd() const;
222
223
 
223
224
  bool is_swa(uint32_t il) const;
225
+
226
+ bool has_kv(uint32_t il) const;
227
+
228
+ // number of layers for which has_kv() returns true
229
+ uint32_t n_layer_kv() const;
224
230
  };
225
231
 
226
232
  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
@@ -1,4 +1,4 @@
1
- #include "llama-kv-cache-unified-iswa.h"
1
+ #include "llama-kv-cache-iswa.h"
2
2
 
3
3
  #include "llama-impl.h"
4
4
  #include "llama-batch.h"
@@ -8,10 +8,10 @@
8
8
  #include <cassert>
9
9
 
10
10
  //
11
- // llama_kv_cache_unified_iswa
11
+ // llama_kv_cache_iswa
12
12
  //
13
13
 
14
- llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
14
+ llama_kv_cache_iswa::llama_kv_cache_iswa(
15
15
  const llama_model & model,
16
16
  ggml_type type_k,
17
17
  ggml_type type_v,
@@ -22,9 +22,26 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
22
22
  uint32_t kv_size,
23
23
  uint32_t n_seq_max,
24
24
  uint32_t n_ubatch,
25
- uint32_t n_pad) : hparams(model.hparams), unified(unified) {
26
- llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
27
- llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
25
+ uint32_t n_pad,
26
+ const layer_filter_cb & filter,
27
+ const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
28
+
29
+ // chain filters
30
+ const layer_filter_cb filter_base = [&](int32_t il) {
31
+ if (filter && !filter(il)) {
32
+ return false;
33
+ }
34
+
35
+ return !model.hparams.is_swa(il);
36
+ };
37
+
38
+ const layer_filter_cb filter_swa = [&](int32_t il) {
39
+ if (filter && !filter(il)) {
40
+ return false;
41
+ }
42
+
43
+ return model.hparams.is_swa(il);
44
+ };
28
45
 
29
46
  const uint32_t size_base = kv_size;
30
47
 
@@ -40,25 +57,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
40
57
 
41
58
  LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
42
59
 
43
- kv_base = std::make_unique<llama_kv_cache_unified>(
44
- model, std::move(filter_base), type_k, type_v,
60
+ kv_base = std::make_unique<llama_kv_cache>(
61
+ model, type_k, type_v,
45
62
  v_trans, offload, unified, size_base, n_seq_max, n_pad,
46
- 0, LLAMA_SWA_TYPE_NONE);
63
+ 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
47
64
 
48
65
  LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
49
66
 
50
- kv_swa = std::make_unique<llama_kv_cache_unified>(
51
- model, std::move(filter_swa), type_k, type_v,
67
+ kv_swa = std::make_unique<llama_kv_cache>(
68
+ model, type_k, type_v,
52
69
  v_trans, offload, unified, size_swa, n_seq_max, n_pad,
53
- hparams.n_swa, hparams.swa_type);
70
+ hparams.n_swa, hparams.swa_type, filter_swa, reuse);
54
71
  }
55
72
 
56
- void llama_kv_cache_unified_iswa::clear(bool data) {
73
+ void llama_kv_cache_iswa::clear(bool data) {
57
74
  kv_base->clear(data);
58
75
  kv_swa ->clear(data);
59
76
  }
60
77
 
61
- bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
78
+ bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
62
79
  bool res = true;
63
80
 
64
81
  res = res & kv_base->seq_rm(seq_id, p0, p1);
@@ -67,36 +84,36 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam
67
84
  return res;
68
85
  }
69
86
 
70
- void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
87
+ void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
71
88
  kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
72
89
  kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
73
90
  }
74
91
 
75
- void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
92
+ void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
76
93
  kv_base->seq_keep(seq_id);
77
94
  kv_swa ->seq_keep(seq_id);
78
95
  }
79
96
 
80
- void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
97
+ void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
81
98
  kv_base->seq_add(seq_id, p0, p1, shift);
82
99
  kv_swa ->seq_add(seq_id, p0, p1, shift);
83
100
  }
84
101
 
85
- void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
102
+ void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
86
103
  kv_base->seq_div(seq_id, p0, p1, d);
87
104
  kv_swa ->seq_div(seq_id, p0, p1, d);
88
105
  }
89
106
 
90
- llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
107
+ llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
91
108
  // the base cache is a superset of the SWA cache, so we can just check the SWA cache
92
109
  return kv_swa->seq_pos_min(seq_id);
93
110
  }
94
111
 
95
- llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
112
+ llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
96
113
  return kv_swa->seq_pos_max(seq_id);
97
114
  }
98
115
 
99
- llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
116
+ llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
100
117
  GGML_UNUSED(embd_all);
101
118
 
102
119
  // first try simple split
@@ -136,7 +153,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
136
153
 
137
154
  assert(sinfos_base.size() == sinfos_swa.size());
138
155
 
139
- return std::make_unique<llama_kv_cache_unified_iswa_context>(
156
+ return std::make_unique<llama_kv_cache_iswa_context>(
140
157
  this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
141
158
  } while (false);
142
159
 
@@ -172,29 +189,29 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
172
189
 
173
190
  assert(sinfos_base.size() == sinfos_swa.size());
174
191
 
175
- return std::make_unique<llama_kv_cache_unified_iswa_context>(
192
+ return std::make_unique<llama_kv_cache_iswa_context>(
176
193
  this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
177
194
  } while (false);
178
195
 
179
196
  // TODO: if we fail again, we should attempt different splitting strategies
180
197
  // but to do that properly, we first have to refactor the batches to be more flexible
181
198
 
182
- return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
199
+ return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
183
200
  }
184
201
 
185
- llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
186
- return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
202
+ llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
203
+ return std::make_unique<llama_kv_cache_iswa_context>(this);
187
204
  }
188
205
 
189
- llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
190
- return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
206
+ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
207
+ return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
191
208
  }
192
209
 
193
- bool llama_kv_cache_unified_iswa::get_can_shift() const {
210
+ bool llama_kv_cache_iswa::get_can_shift() const {
194
211
  return kv_base->get_size() == kv_swa->get_size();
195
212
  }
196
213
 
197
- void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
214
+ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
198
215
  if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
199
216
  kv_base->state_write(io, seq_id, flags);
200
217
  }
@@ -202,7 +219,7 @@ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_i
202
219
  kv_swa->state_write(io, seq_id, flags);
203
220
  }
204
221
 
205
- void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
222
+ void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
206
223
  if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
207
224
  kv_base->state_read(io, seq_id, flags);
208
225
  }
@@ -210,29 +227,29 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id
210
227
  kv_swa->state_read(io, seq_id, flags);
211
228
  }
212
229
 
213
- llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
230
+ llama_kv_cache * llama_kv_cache_iswa::get_base() const {
214
231
  return kv_base.get();
215
232
  }
216
233
 
217
- llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
234
+ llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
218
235
  return kv_swa.get();
219
236
  }
220
237
 
221
238
  //
222
- // llama_kv_cache_unified_iswa_context
239
+ // llama_kv_cache_iswa_context
223
240
  //
224
241
 
225
- llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
242
+ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
226
243
 
227
- llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
228
- llama_kv_cache_unified_iswa * kv) :
244
+ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
245
+ llama_kv_cache_iswa * kv) :
229
246
  ctx_base(kv->get_base()->init_full()),
230
247
  ctx_swa (kv->get_swa ()->init_full()),
231
248
  status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
232
249
  }
233
250
 
234
- llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
235
- llama_kv_cache_unified_iswa * kv,
251
+ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
252
+ llama_kv_cache_iswa * kv,
236
253
  llama_context * lctx,
237
254
  bool optimize) :
238
255
  ctx_base(kv->get_base()->init_update(lctx, optimize)),
@@ -240,21 +257,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
240
257
  status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
241
258
  }
242
259
 
243
- llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
244
- llama_kv_cache_unified_iswa * kv,
260
+ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
261
+ llama_kv_cache_iswa * kv,
245
262
  slot_info_vec_t sinfos_base,
246
263
  slot_info_vec_t sinfos_swa,
247
264
  std::vector<llama_ubatch> ubatches) :
248
265
  ubatches(std::move(ubatches)),
249
266
  // note: here we copy the ubatches. not sure if this is ideal
250
- ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
251
- ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
267
+ ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
268
+ ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
252
269
  status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
253
270
  }
254
271
 
255
- llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
272
+ llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
256
273
 
257
- bool llama_kv_cache_unified_iswa_context::next() {
274
+ bool llama_kv_cache_iswa_context::next() {
258
275
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
259
276
 
260
277
  ctx_base->next();
@@ -267,7 +284,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
267
284
  return true;
268
285
  }
269
286
 
270
- bool llama_kv_cache_unified_iswa_context::apply() {
287
+ bool llama_kv_cache_iswa_context::apply() {
271
288
  assert(!llama_memory_status_is_fail(status));
272
289
 
273
290
  bool res = true;
@@ -278,24 +295,24 @@ bool llama_kv_cache_unified_iswa_context::apply() {
278
295
  return res;
279
296
  }
280
297
 
281
- llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
298
+ llama_memory_status llama_kv_cache_iswa_context::get_status() const {
282
299
  return status;
283
300
  }
284
301
 
285
- const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
302
+ const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
286
303
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
287
304
 
288
305
  return ubatches[i_next];
289
306
  }
290
307
 
291
- const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
308
+ const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
292
309
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
293
310
 
294
- return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
311
+ return static_cast<const llama_kv_cache_context *>(ctx_base.get());
295
312
  }
296
313
 
297
- const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
314
+ const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
298
315
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
299
316
 
300
- return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
317
+ return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
301
318
  }
@@ -1,19 +1,19 @@
1
1
  #pragma once
2
2
 
3
- #include "llama-kv-cache-unified.h"
3
+ #include "llama-kv-cache.h"
4
4
 
5
5
  #include <vector>
6
6
 
7
7
  //
8
- // llama_kv_cache_unified_iswa
8
+ // llama_kv_cache_iswa
9
9
  //
10
10
 
11
- // utilizes two instances of llama_kv_cache_unified
11
+ // utilizes two instances of llama_kv_cache
12
12
  // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
13
13
 
14
- class llama_kv_cache_unified_iswa : public llama_memory_i {
14
+ class llama_kv_cache_iswa : public llama_memory_i {
15
15
  public:
16
- llama_kv_cache_unified_iswa(
16
+ llama_kv_cache_iswa(
17
17
  const llama_model & model,
18
18
  ggml_type type_k,
19
19
  ggml_type type_v,
@@ -24,9 +24,11 @@ public:
24
24
  uint32_t kv_size,
25
25
  uint32_t n_seq_max,
26
26
  uint32_t n_ubatch,
27
- uint32_t n_pad);
27
+ uint32_t n_pad,
28
+ const layer_filter_cb & filter,
29
+ const layer_reuse_cb & reuse);
28
30
 
29
- ~llama_kv_cache_unified_iswa() = default;
31
+ ~llama_kv_cache_iswa() = default;
30
32
 
31
33
  //
32
34
  // llama_memory_i
@@ -60,46 +62,46 @@ public:
60
62
  void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
61
63
 
62
64
  //
63
- // llama_kv_cache_unified_iswa specific API
65
+ // llama_kv_cache_iswa specific API
64
66
  //
65
67
 
66
- llama_kv_cache_unified * get_base() const;
67
- llama_kv_cache_unified * get_swa () const;
68
+ llama_kv_cache * get_base() const;
69
+ llama_kv_cache * get_swa () const;
68
70
 
69
71
  private:
70
72
  const llama_hparams & hparams;
71
73
 
72
74
  const bool unified;
73
75
 
74
- std::unique_ptr<llama_kv_cache_unified> kv_base;
75
- std::unique_ptr<llama_kv_cache_unified> kv_swa;
76
+ std::unique_ptr<llama_kv_cache> kv_base;
77
+ std::unique_ptr<llama_kv_cache> kv_swa;
76
78
  };
77
79
 
78
- class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
80
+ class llama_kv_cache_iswa_context : public llama_memory_context_i {
79
81
  public:
80
- using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
82
+ using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
81
83
 
82
84
  // used for errors
83
- llama_kv_cache_unified_iswa_context(llama_memory_status status);
85
+ llama_kv_cache_iswa_context(llama_memory_status status);
84
86
 
85
87
  // used to create a full-cache context
86
- llama_kv_cache_unified_iswa_context(
87
- llama_kv_cache_unified_iswa * kv);
88
+ llama_kv_cache_iswa_context(
89
+ llama_kv_cache_iswa * kv);
88
90
 
89
91
  // used to create an update context
90
- llama_kv_cache_unified_iswa_context(
91
- llama_kv_cache_unified_iswa * kv,
92
+ llama_kv_cache_iswa_context(
93
+ llama_kv_cache_iswa * kv,
92
94
  llama_context * lctx,
93
95
  bool optimize);
94
96
 
95
97
  // used to create a batch processing context from a batch
96
- llama_kv_cache_unified_iswa_context(
97
- llama_kv_cache_unified_iswa * kv,
98
+ llama_kv_cache_iswa_context(
99
+ llama_kv_cache_iswa * kv,
98
100
  slot_info_vec_t sinfos_base,
99
101
  slot_info_vec_t sinfos_swa,
100
102
  std::vector<llama_ubatch> ubatches);
101
103
 
102
- virtual ~llama_kv_cache_unified_iswa_context();
104
+ virtual ~llama_kv_cache_iswa_context();
103
105
 
104
106
  //
105
107
  // llama_memory_context_i
@@ -112,14 +114,14 @@ public:
112
114
  const llama_ubatch & get_ubatch() const override;
113
115
 
114
116
  //
115
- // llama_kv_cache_unified_iswa_context specific API
117
+ // llama_kv_cache_iswa_context specific API
116
118
  //
117
119
 
118
- const llama_kv_cache_unified_context * get_base() const;
119
- const llama_kv_cache_unified_context * get_swa() const;
120
+ const llama_kv_cache_context * get_base() const;
121
+ const llama_kv_cache_context * get_swa() const;
120
122
 
121
123
  private:
122
- //llama_kv_cache_unified_iswa * kv;
124
+ //llama_kv_cache_iswa * kv;
123
125
 
124
126
  // the index of the next ubatch to process
125
127
  size_t i_next = 0;