@fugood/llama.node 0.3.14 → 0.3.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +37 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +20 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  60. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  61. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  62. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  63. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  64. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
  65. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  66. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
  67. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  68. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  69. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  70. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  71. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  72. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
  73. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  74. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  75. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  76. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  78. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  79. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  82. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
  83. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  84. package/src/llama.cpp/include/llama.h +86 -22
  85. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  86. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  87. package/src/llama.cpp/src/llama-adapter.h +11 -9
  88. package/src/llama.cpp/src/llama-arch.cpp +103 -16
  89. package/src/llama.cpp/src/llama-arch.h +18 -0
  90. package/src/llama.cpp/src/llama-batch.h +2 -2
  91. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  92. package/src/llama.cpp/src/llama-context.h +214 -77
  93. package/src/llama.cpp/src/llama-cparams.h +1 -0
  94. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  95. package/src/llama.cpp/src/llama-graph.h +574 -0
  96. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  97. package/src/llama.cpp/src/llama-hparams.h +9 -0
  98. package/src/llama.cpp/src/llama-io.cpp +15 -0
  99. package/src/llama.cpp/src/llama-io.h +35 -0
  100. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  101. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  102. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  103. package/src/llama.cpp/src/llama-memory.h +21 -0
  104. package/src/llama.cpp/src/llama-model.cpp +8244 -173
  105. package/src/llama.cpp/src/llama-model.h +34 -1
  106. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  107. package/src/llama.cpp/src/llama.cpp +51 -9984
  108. package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
  109. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  110. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
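The hunks below are from package/src/llama.cpp/src/llama-kv-cache.cpp (file 100 in the list above). The main change is a refactor: the free functions that took a struct llama_kv_cache & (llama_kv_cache_init, llama_kv_cache_find_slot, llama_kv_cache_seq_rm, and so on) become member functions of a new llama_kv_cache_unified class, and thin C-style wrappers that accept a possibly-null llama_kv_cache * and forward to those members are layered on top. A minimal caller-side sketch of the new pattern follows; the helper name and the way the cache pointer is obtained are illustrative assumptions, only the wrapper signatures come from the diff itself:

    #include "llama-kv-cache.h" // internal llama.cpp header (assumed include) that declares the wrappers

    // hypothetical helper: drop all cached data for one sequence
    static void reset_sequence(llama_kv_cache * kv, llama_seq_id seq_id) {
        // negative p0/p1 mean "from the start" / "to the end";
        // the wrapper returns early when kv is null
        llama_kv_cache_seq_rm(kv, seq_id, -1, -1);
    }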
@@ -6,86 +6,92 @@
 #include "llama-model.h"
 
 #include <algorithm>
+#include <cassert>
 #include <limits>
 #include <map>
+#include <stdexcept>
 
 static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
 
-uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
+llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
-bool llama_kv_cache_init(
-        struct llama_kv_cache & cache,
-        const llama_model & model,
-        const llama_cparams & cparams,
-        ggml_type type_k,
-        ggml_type type_v,
-        uint32_t kv_size,
-        bool offload) {
-    const struct llama_hparams & hparams = model.hparams;
-
+bool llama_kv_cache_unified::init(
+        const llama_model & model,
+        const llama_cparams & cparams,
+        ggml_type type_k,
+        ggml_type type_v,
+        uint32_t kv_size,
+        bool offload) {
     const int32_t n_layer = hparams.n_layer;
 
-    cache.has_shift = false;
+    has_shift = false;
 
-    cache.recurrent = llama_model_is_recurrent(&model);
-    cache.v_trans = !cache.recurrent && !cparams.flash_attn;
-    cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    recurrent = llama_model_is_recurrent(&model);
+    v_trans = !recurrent && !cparams.flash_attn;
+    can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
 
     LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
-        __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);
+        __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
 
-    cache.head = 0;
-    cache.size = kv_size;
-    cache.used = 0;
+    head = 0;
+    size = kv_size;
+    used = 0;
 
-    cache.type_k = type_k;
-    cache.type_v = type_v;
+    this->type_k = type_k;
+    this->type_v = type_v;
 
-    cache.cells.clear();
-    cache.cells.resize(kv_size);
+    cells.clear();
+    cells.resize(kv_size);
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc =*/ true,
             };
+
             ggml_context * ctx = ggml_init(params);
             if (!ctx) {
                 return nullptr;
             }
+
             ctx_map[buft] = ctx;
-            cache.ctxs.emplace_back(ctx);
+            ctxs.emplace_back(ctx);
+
             return ctx;
         }
+
         return it->second;
     };
 
-    cache.k_l.reserve(n_layer);
-    cache.v_l.reserve(n_layer);
+    k_l.reserve(n_layer);
+    v_l.reserve(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
+        const char * dev_name = "CPU";
 
         ggml_backend_buffer_type_t buft;
         if (offload) {
             auto * dev = model.dev_layer(i);
             buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
         } else {
             buft = ggml_backend_cpu_buffer_type();
         }
-        ggml_context * ctx = ctx_for_buft(buft);
 
+        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
+            i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+
+        ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
             LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
             return false;
@@ -95,8 +101,8 @@ bool llama_kv_cache_init(
         ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
-        cache.k_l.push_back(k);
-        cache.v_l.push_back(v);
+        k_l.push_back(k);
+        v_l.push_back(v);
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -111,20 +117,346 @@ bool llama_kv_cache_init(
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        cache.bufs.emplace_back(buf);
+        bufs.emplace_back(buf);
     }
 
     return true;
 }
 
-struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
-        struct llama_kv_cache & cache,
-        const struct llama_ubatch & ubatch) {
+int32_t llama_kv_cache_unified::get_n_tokens() const {
+    int32_t result = 0;
+
+    for (uint32_t i = 0; i < size; i++) {
+        result += cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+uint32_t llama_kv_cache_unified::get_used_cells() const {
+    return used;
+}
+
+size_t llama_kv_cache_unified::total_size() const {
+    size_t size = 0;
+    for (const auto & buf : bufs) {
+        size += ggml_backend_buffer_get_size(buf.get());
+    }
+
+    return size;
+}
+
+llama_pos llama_kv_cache_unified::pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cell : cells) {
+        pos_max = std::max(pos_max, cell.pos);
+    }
+
+    return pos_max;
+}
+
+void llama_kv_cache_unified::clear() {
+    for (int32_t i = 0; i < (int32_t) size; ++i) {
+        cells[i].pos = -1;
+        cells[i].seq_id.clear();
+        cells[i].src = -1;
+        cells[i].tail = -1;
+    }
+    head = 0;
+    used = 0;
+
+    for (auto & buf : bufs) {
+        ggml_backend_buffer_clear(buf.get(), 0);
+    }
+}
+
+bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    uint32_t new_head = size;
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // models like Mamba or RWKV can't have a state partially erased
+    if (recurrent) {
+        if (seq_id >= (int64_t) size) {
+            // could be fatal
+            return false;
+        }
+        if (0 <= seq_id) {
+            int32_t & tail_id = cells[seq_id].tail;
+            if (tail_id >= 0) {
+                const llama_kv_cell & cell = cells[tail_id];
+                // partial intersection is invalid
+                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                    return false;
+                }
+                // invalidate tails which will be cleared
+                if (p0 <= cell.pos && cell.pos < p1) {
+                    tail_id = -1;
+                }
+            }
+        } else {
+            // seq_id is negative, then the range should include everything or nothing
+            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+                return false;
+            }
+        }
+    }
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].pos >= p0 && cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cells[i].seq_id.clear();
+            } else if (cells[i].has_seq_id(seq_id)) {
+                cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cells[i].is_empty()) {
+                // keep count of the number of used cells
+                if (cells[i].pos >= 0) {
+                    used--;
+                }
+
+                cells[i].pos = -1;
+                cells[i].src = -1;
+
+                if (new_head == size) {
+                    new_head = i;
+                }
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+
+    return true;
+}
+
+void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    if (seq_id_src == seq_id_dst) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    if (recurrent) {
+        if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
+            llama_kv_cell & tail_src = cells[seq_id_src];
+            llama_kv_cell & tail_dst = cells[seq_id_dst];
+            if (tail_dst.tail >= 0) {
+                // clear destination seq_id if it wasn't empty
+                llama_kv_cell & cell_dst = cells[tail_dst.tail];
+
+                cell_dst.seq_id.erase(seq_id_dst);
+                tail_dst.tail = -1;
+                if (cell_dst.seq_id.empty()) {
+                    cell_dst.pos = -1;
+                    cell_dst.delta = -1;
+                    cell_dst.src = -1;
+                    used -= 1;
+                }
+            }
+            if (tail_src.tail >= 0) {
+                llama_kv_cell & cell_src = cells[tail_src.tail];
+
+                cell_src.seq_id.insert(seq_id_dst);
+                tail_dst.tail = tail_src.tail;
+            }
+        }
+
+        return;
+    }
+
+    // otherwise, this is the KV of a Transformer-like model
+    head = 0;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
+            cells[i].seq_id.insert(seq_id_dst);
+        }
+    }
+}
+
+void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+    uint32_t new_head = size;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (recurrent && (llama_seq_id) i != seq_id) {
+            cells[i].tail = -1;
+        }
+
+        if (!cells[i].has_seq_id(seq_id)) {
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+            cells[i].seq_id.clear();
+
+            if (new_head == size){
+                new_head = i;
+            }
+        } else {
+            cells[i].seq_id.clear();
+            cells[i].seq_id.insert(seq_id);
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    if (delta == 0) {
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    if (recurrent) {
+        // for Mamba-like or RWKV models, only the pos needs to be shifted
+        if (0 <= seq_id && seq_id < (int64_t) size) {
+            const int32_t tail_id = cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos += delta;
+                }
+            }
+        }
+        return;
+    }
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
+            has_shift = true;
+            cells[i].pos += delta;
+            cells[i].delta += delta;
+
+            if (cells[i].pos < 0) {
+                if (!cells[i].is_empty()) {
+                    used--;
+                }
+                cells[i].pos = -1;
+                cells[i].seq_id.clear();
+                if (new_head == size) {
+                    new_head = i;
+                }
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    head = new_head != size ? new_head : 0;
+}
+
+void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    if (d == 1) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    if (recurrent) {
+        // for Mamba-like or RWKV models, only the pos needs to be changed
+        if (0 <= seq_id && seq_id < (int64_t) size) {
+            const int32_t tail_id = cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos /= d;
+                }
+            }
+        }
+
+        return;
+    }
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
+            has_shift = true;
+
+            {
+                llama_pos p_old = cells[i].pos;
+                cells[i].pos /= d;
+                cells[i].delta += cells[i].pos - p_old;
+            }
+        }
+    }
+}
+
+llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
+    llama_pos result = 0;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id)) {
+            result = std::max(result, cells[i].pos);
+        }
+    }
+
+    return result;
+}
+
+void llama_kv_cache_unified::defrag() {
+    if (!recurrent) {
+        do_defrag = true;
+    }
+}
+
+bool llama_kv_cache_unified::get_can_shift() const {
+    return can_shift;
+}
+
+llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+        const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
-    if (cache.recurrent) {
+    if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.
@@ -132,7 +464,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
132
464
  // can only process batches with an equal number of new tokens in each sequence
133
465
  GGML_ASSERT(ubatch.equal_seqs);
134
466
 
135
- int32_t min = cache.size - 1;
467
+ int32_t min = size - 1;
136
468
  int32_t max = 0;
137
469
 
138
470
  // everything should fit if all seq_ids are smaller than the max
@@ -141,16 +473,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
141
473
  for (uint32_t j = 0; j < n_seq_id; ++j) {
142
474
  const llama_seq_id seq_id = ubatch.seq_id[s][j];
143
475
 
144
- if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
476
+ if (seq_id < 0 || (uint32_t) seq_id >= size) {
145
477
  // too big seq_id
146
478
  // TODO: would it be possible to resize the cache instead?
147
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
479
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
148
480
  return llama_kv_cache_slot_info_failed;
149
481
  }
150
482
  if (j > 0) {
151
- llama_kv_cell & seq = cache.cells[seq_id];
483
+ llama_kv_cell & seq = cells[seq_id];
152
484
  if (seq.tail >= 0) {
153
- llama_kv_cell & cell = cache.cells[seq.tail];
485
+ llama_kv_cell & cell = cells[seq.tail];
154
486
  // clear cells from seq_ids that become shared
155
487
  // (should not normally happen, but let's handle it anyway)
156
488
  cell.seq_id.erase(seq_id);
@@ -158,7 +490,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
158
490
  if (cell.seq_id.empty()) {
159
491
  cell.pos = -1;
160
492
  cell.src = -1;
161
- cache.used -= 1;
493
+ used -= 1;
162
494
  }
163
495
  }
164
496
  }
@@ -168,9 +500,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
168
500
  #ifndef NDEBUG
169
501
  {
170
502
  std::vector<int32_t> tails_verif;
171
- tails_verif.assign(cache.size, -1);
172
- for (uint32_t i = 0; i < cache.size; ++i) {
173
- llama_kv_cell & cell = cache.cells[i];
503
+ tails_verif.assign(size, -1);
504
+ for (uint32_t i = 0; i < size; ++i) {
505
+ llama_kv_cell & cell = cells[i];
174
506
  for (llama_seq_id seq_id : cell.seq_id) {
175
507
  if (tails_verif[seq_id] != -1) {
176
508
  LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -178,20 +510,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
178
510
  tails_verif[seq_id] = i;
179
511
  }
180
512
  }
181
- for (uint32_t i = 0; i < cache.size; ++i) {
182
- if (tails_verif[i] != cache.cells[i].tail) {
183
- LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
513
+ for (uint32_t i = 0; i < size; ++i) {
514
+ if (tails_verif[i] != cells[i].tail) {
515
+ LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
184
516
  }
185
517
  }
186
518
  }
187
519
  #endif
188
520
 
189
521
  // find next empty cell
190
- uint32_t next_empty_cell = cache.head;
522
+ uint32_t next_empty_cell = head;
191
523
 
192
- for (uint32_t i = 0; i < cache.size; ++i) {
193
- if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
194
- llama_kv_cell & cell = cache.cells[next_empty_cell];
524
+ for (uint32_t i = 0; i < size; ++i) {
525
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
526
+ llama_kv_cell & cell = cells[next_empty_cell];
195
527
  if (cell.is_empty()) { break; }
196
528
  next_empty_cell += 1;
197
529
  }
@@ -199,20 +531,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
199
531
  // find usable cell range
200
532
  for (uint32_t s = 0; s < n_seqs; ++s) {
201
533
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
202
- llama_kv_cell & seq_meta = cache.cells[seq_id];
534
+ llama_kv_cell & seq_meta = cells[seq_id];
203
535
  bool has_cell = false;
204
536
  if (seq_meta.tail >= 0) {
205
- llama_kv_cell & cell = cache.cells[seq_meta.tail];
537
+ llama_kv_cell & cell = cells[seq_meta.tail];
206
538
  GGML_ASSERT(cell.has_seq_id(seq_id));
207
539
  // does this seq_id "own" the cell?
208
540
  if (cell.seq_id.size() == 1) { has_cell = true; }
209
541
  }
210
542
  if (!has_cell) {
211
- llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
543
+ llama_kv_cell & empty_cell = cells[next_empty_cell];
212
544
  GGML_ASSERT(empty_cell.is_empty());
213
545
  // copy old tail into the empty cell
214
546
  if (seq_meta.tail >= 0) {
215
- llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
547
+ llama_kv_cell & orig_cell = cells[seq_meta.tail];
216
548
  empty_cell.pos = orig_cell.pos;
217
549
  empty_cell.src = orig_cell.src;
218
550
  orig_cell.seq_id.erase(seq_id);
@@ -222,9 +554,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
222
554
  // find next empty cell
223
555
  if (s + 1 < n_seqs) {
224
556
  next_empty_cell += 1;
225
- for (uint32_t i = 0; i < cache.size; ++i) {
226
- if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
227
- llama_kv_cell & cell = cache.cells[next_empty_cell];
557
+ for (uint32_t i = 0; i < size; ++i) {
558
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
559
+ llama_kv_cell & cell = cells[next_empty_cell];
228
560
  if (cell.is_empty()) { break; }
229
561
  next_empty_cell += 1;
230
562
  }
@@ -237,10 +569,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
237
569
  // gather and re-order
238
570
  for (uint32_t s = 0; s < n_seqs; ++s) {
239
571
  int32_t dst_id = s + min;
240
- int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
572
+ int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
241
573
  if (dst_id != src_id) {
242
- llama_kv_cell & dst_cell = cache.cells[dst_id];
243
- llama_kv_cell & src_cell = cache.cells[src_id];
574
+ llama_kv_cell & dst_cell = cells[dst_id];
575
+ llama_kv_cell & src_cell = cells[src_id];
244
576
 
245
577
  std::swap(dst_cell.pos, src_cell.pos);
246
578
  std::swap(dst_cell.src, src_cell.src);
@@ -248,10 +580,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
248
580
 
249
581
  // swap tails (assuming they NEVER overlap)
250
582
  for (const llama_seq_id seq_id : src_cell.seq_id) {
251
- cache.cells[seq_id].tail = src_id;
583
+ cells[seq_id].tail = src_id;
252
584
  }
253
585
  for (const llama_seq_id seq_id : dst_cell.seq_id) {
254
- cache.cells[seq_id].tail = dst_id;
586
+ cells[seq_id].tail = dst_id;
255
587
  }
256
588
  }
257
589
  }
@@ -260,7 +592,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
260
592
  for (uint32_t s = 0; s < n_seqs; ++s) {
261
593
  const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
262
594
  int32_t cell_id = s + min;
263
- llama_kv_cell & cell = cache.cells[cell_id];
595
+ llama_kv_cell & cell = cells[cell_id];
264
596
 
265
597
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
266
598
  // What should happen when the pos backtracks or skips a value?
@@ -273,41 +605,42 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
273
605
  for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
274
606
  const llama_seq_id seq_id = ubatch.seq_id[s][j];
275
607
  cell.seq_id.insert(seq_id);
276
- cache.cells[seq_id].tail = cell_id;
608
+ cells[seq_id].tail = cell_id;
277
609
  }
278
610
  }
279
611
 
280
612
  // allow getting the range of used cells, from head to head + n
281
- cache.head = min;
282
- cache.n = max - min + 1;
283
- cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
613
+ head = min;
614
+ n = max - min + 1;
615
+ used = std::count_if(cells.begin(), cells.end(),
284
616
  [](const llama_kv_cell& cell){ return !cell.is_empty(); });
285
617
 
286
618
  // sanity check
287
- return llama_kv_cache_slot_info(cache.n >= n_seqs);
619
+ return llama_kv_cache_slot_info(n >= n_seqs);
288
620
  }
621
+
289
622
  // otherwise, one cell per token.
290
623
 
291
- if (n_tokens > cache.size) {
292
- LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
624
+ if (n_tokens > size) {
625
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
293
626
  return llama_kv_cache_slot_info_failed;
294
627
  }
295
628
 
296
629
  uint32_t n_tested = 0;
297
630
 
298
631
  while (true) {
299
- if (cache.head + n_tokens > cache.size) {
300
- n_tested += cache.size - cache.head;
301
- cache.head = 0;
632
+ if (head + n_tokens > size) {
633
+ n_tested += size - head;
634
+ head = 0;
302
635
  continue;
303
636
  }
304
637
 
305
638
  bool found = true;
306
639
  for (uint32_t i = 0; i < n_tokens; i++) {
307
- if (cache.cells[cache.head + i].pos >= 0) {
640
+ if (cells[head + i].pos >= 0) {
308
641
  found = false;
309
- cache.head += i + 1;
310
- n_tested += i + 1;
642
+ head += i + 1;
643
+ n_tested += i + 1;
311
644
  break;
312
645
  }
313
646
  }
@@ -316,7 +649,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
316
649
  break;
317
650
  }
318
651
 
319
- if (n_tested >= cache.size) {
652
+ if (n_tested >= size) {
320
653
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
321
654
  return llama_kv_cache_slot_info_failed;
322
655
  }
@@ -325,22 +658,27 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
325
658
  for (uint32_t s = 0; s < n_seqs; s++) {
326
659
  for (uint32_t i = 0; i < n_seq_tokens; ++i) {
327
660
  uint32_t k = s*n_seq_tokens + i;
328
- cache.cells[cache.head + k].pos = ubatch.pos[k];
661
+ cells[head + k].pos = ubatch.pos[k];
329
662
 
330
663
  for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
331
- cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
664
+ cells[head + k].seq_id.insert(ubatch.seq_id[s][j]);
332
665
  }
333
666
  }
334
667
  }
335
668
 
336
- cache.used += n_tokens;
669
+ used += n_tokens;
337
670
 
338
- return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
671
+ return llama_kv_cache_slot_info(head, head + n_tokens);
339
672
  }
340
673
 
341
- uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
342
- for (uint32_t i = cache.size; i > 0; --i) {
343
- const llama_kv_cell & cell = cache.cells[i - 1];
674
+ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
675
+ // the FA kernels require padding to avoid extra runtime boundary checks
676
+ return cparams.flash_attn ? 256u : 32u;
677
+ }
678
+
679
+ uint32_t llama_kv_cache_unified::cell_max() const {
680
+ for (uint32_t i = size; i > 0; --i) {
681
+ const llama_kv_cell & cell = cells[i - 1];
344
682
 
345
683
  if (cell.pos >= 0 && !cell.is_empty()) {
346
684
  return i;
@@ -350,289 +688,659 @@ uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
350
688
  return 0;
351
689
  }
352
690
 
353
- void llama_kv_cache_clear(struct llama_kv_cache & cache) {
354
- for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
355
- cache.cells[i].pos = -1;
356
- cache.cells[i].seq_id.clear();
357
- cache.cells[i].src = -1;
358
- cache.cells[i].tail = -1;
691
+ size_t llama_kv_cache_unified::size_k_bytes() const {
692
+ size_t size_k_bytes = 0;
693
+
694
+ for (const auto & k : k_l) {
695
+ size_k_bytes += ggml_nbytes(k);
359
696
  }
360
- cache.head = 0;
361
- cache.used = 0;
362
697
 
363
- for (auto & buf : cache.bufs) {
364
- ggml_backend_buffer_clear(buf.get(), 0);
698
+ return size_k_bytes;
699
+ }
700
+
701
+ size_t llama_kv_cache_unified::size_v_bytes() const {
702
+ size_t size_v_bytes = 0;
703
+
704
+ for (const auto & v : v_l) {
705
+ size_v_bytes += ggml_nbytes(v);
365
706
  }
707
+
708
+ return size_v_bytes;
366
709
  }
367
710
 
368
- bool llama_kv_cache_seq_rm(
369
- struct llama_kv_cache & cache,
370
- llama_seq_id seq_id,
371
- llama_pos p0,
372
- llama_pos p1) {
373
- uint32_t new_head = cache.size;
711
+ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
712
+ const uint32_t n_layer = hparams.n_layer;
374
713
 
375
- if (p0 < 0) p0 = 0;
376
- if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
714
+ const uint32_t n_kv = cell_max();
715
+ const uint32_t n_used = used;
377
716
 
378
- // models like Mamba or RWKV can't have a state partially erased
379
- if (cache.recurrent) {
380
- if (seq_id >= (int64_t) cache.size) {
381
- // could be fatal
382
- return false;
717
+ assert(n_used <= n_kv);
718
+
719
+ //const int64_t t_start = ggml_time_us();
720
+
721
+ // number of cells moved
722
+ uint32_t n_moves = 0;
723
+
724
+ // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
725
+ // - source view, destination view, copy operation
726
+ // - x2 for keys and values
727
+ //const uint32_t max_moves = max_nodes()/(6*n_layer);
728
+ // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
729
+ const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
730
+
731
+ // determine which KV cells to move where
732
+ //
733
+ // cell i moves to ids[i]
734
+ //
735
+ // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
736
+ //
737
+ auto & ids = defrag_info.ids;
738
+
739
+ ids.clear();
740
+ ids.resize(n_kv, n_kv);
741
+
742
+ for (uint32_t i0 = 0; i0 < n_used; ++i0) {
743
+ const auto & cell0 = cells[i0];
744
+
745
+ if (!cell0.is_empty()) {
746
+ ids[i0] = i0;
747
+
748
+ continue;
383
749
  }
384
- if (0 <= seq_id) {
385
- int32_t & tail_id = cache.cells[seq_id].tail;
386
- if (tail_id >= 0) {
387
- const llama_kv_cell & cell = cache.cells[tail_id];
388
- // partial intersection is invalid
389
- if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
390
- return false;
391
- }
392
- // invalidate tails which will be cleared
393
- if (p0 <= cell.pos && cell.pos < p1) {
394
- tail_id = -1;
395
- }
750
+
751
+ // found a hole - fill it with data from the end of the cache
752
+
753
+ uint32_t nh = 1;
754
+
755
+ // determine the size of the hole
756
+ while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
757
+ nh++;
758
+ }
759
+
760
+ uint32_t nf = 0;
761
+ uint32_t is = n_kv - 1;
762
+
763
+ // starting from the end, find nh non-empty cells
764
+ for (; is > i0; --is) {
765
+ const auto & cell1 = cells[is];
766
+
767
+ if (cell1.is_empty() || ids[is] != n_kv) {
768
+ continue;
396
769
  }
397
- } else {
398
- // seq_id is negative, then the range should include everything or nothing
399
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
400
- return false;
770
+
771
+ // non-empty cell which is not yet moved
772
+ nf++;
773
+
774
+ if (nf == nh) {
775
+ break;
401
776
  }
402
777
  }
403
- }
404
778
 
405
- for (uint32_t i = 0; i < cache.size; ++i) {
406
- if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
407
- if (seq_id < 0) {
408
- cache.cells[i].seq_id.clear();
409
- } else if (cache.cells[i].has_seq_id(seq_id)) {
410
- cache.cells[i].seq_id.erase(seq_id);
411
- } else {
779
+ // this can only happen if `n_used` is not accurate, which would be a bug
780
+ GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
781
+
782
+ nf = 0;
783
+
784
+ uint32_t i1 = is;
785
+
786
+ // are we moving a continuous block of memory?
787
+ bool cont = false;
788
+
789
+ // should we stop searching for the next move?
790
+ bool stop = false;
791
+
792
+ // go back and move the nf cells to the hole
793
+ for (; i1 < n_kv; ++i1) {
794
+ auto & cell1 = cells[i1];
795
+
796
+ if (cell1.is_empty() || ids[i1] != n_kv) {
797
+ if (n_moves == max_moves) {
798
+ stop = true;
799
+ break;
800
+ }
801
+
802
+ cont = false;
412
803
  continue;
413
804
  }
414
- if (cache.cells[i].is_empty()) {
415
- // keep count of the number of used cells
416
- if (cache.cells[i].pos >= 0) cache.used--;
417
805
 
418
- cache.cells[i].pos = -1;
419
- cache.cells[i].src = -1;
420
- if (new_head == cache.size) new_head = i;
806
+ // this cell goes to (i0 + nf)
807
+ ids[i1] = i0 + nf;
808
+
809
+ // move the cell meta data
810
+ cells[i0 + nf] = cell1;
811
+
812
+ // clear the old cell and move the head there
813
+ cell1 = llama_kv_cell();
814
+ head = n_used;
815
+
816
+ if (!cont) {
817
+ n_moves++;
818
+ cont = true;
819
+ }
820
+
821
+ nf++;
822
+
823
+ if (nf == nh) {
824
+ break;
421
825
  }
422
826
  }
827
+
828
+ if (stop || n_moves == max_moves) {
829
+ break;
830
+ }
831
+
832
+ //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
833
+
834
+ i0 += nh - 1;
423
835
  }
424
836
 
425
- // If we freed up a slot, set head to it so searching can start there.
426
- if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
837
+ if (n_moves == 0) {
838
+ return false;
839
+ }
840
+
841
+ LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
842
+
843
+ LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer);
427
844
 
428
845
  return true;
429
846
  }
430
847
 
431
- void llama_kv_cache_seq_cp(
432
- struct llama_kv_cache & cache,
433
- llama_seq_id seq_id_src,
434
- llama_seq_id seq_id_dst,
435
- llama_pos p0,
436
- llama_pos p1) {
437
- if (p0 < 0) p0 = 0;
438
- if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
439
-
440
- if (cache.recurrent) {
441
- if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
442
- llama_kv_cell & tail_src = cache.cells[seq_id_src];
443
- llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
444
- if (tail_dst.tail >= 0) {
445
- // clear destination seq_id if it wasn't empty
446
- llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
447
-
448
- cell_dst.seq_id.erase(seq_id_dst);
449
- tail_dst.tail = -1;
450
- if (cell_dst.seq_id.empty()) {
451
- cell_dst.pos = -1;
452
- cell_dst.delta = -1;
453
- cell_dst.src = -1;
454
- cache.used -= 1;
455
- }
848
+ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
849
+ std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
850
+ uint32_t cell_count = 0;
851
+
852
+ // Count the number of cells with the specified seq_id
853
+ // Find all the ranges of cells with this seq id (or all, when -1)
854
+ uint32_t cell_range_begin = size;
855
+ for (uint32_t i = 0; i < size; ++i) {
856
+ const auto & cell = cells[i];
857
+ if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
858
+ ++cell_count;
859
+ if (cell_range_begin == size) {
860
+ cell_range_begin = i;
456
861
  }
457
- if (tail_src.tail >= 0) {
458
- llama_kv_cell & cell_src = cache.cells[tail_src.tail];
459
-
460
- cell_src.seq_id.insert(seq_id_dst);
461
- tail_dst.tail = tail_src.tail;
862
+ } else {
863
+ if (cell_range_begin != size) {
864
+ cell_ranges.emplace_back(cell_range_begin, i);
865
+ cell_range_begin = size;
462
866
  }
463
867
  }
868
+ }
869
+ if (cell_range_begin != size) {
870
+ cell_ranges.emplace_back(cell_range_begin, size);
871
+ }
464
872
 
465
- return;
873
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
874
+ uint32_t cell_count_check = 0;
875
+ for (const auto & range : cell_ranges) {
876
+ cell_count_check += range.second - range.first;
466
877
  }
467
- // otherwise, this is the KV cache of a Transformer-like model
878
+ GGML_ASSERT(cell_count == cell_count_check);
879
+
880
+ io.write(&cell_count, sizeof(cell_count));
881
+
882
+ state_write_meta(io, cell_ranges, seq_id);
883
+ state_write_data(io, cell_ranges);
884
+ }
885
+
886
+ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
887
+ uint32_t cell_count;
888
+ io.read_to(&cell_count, sizeof(cell_count));
468
889
 
469
- cache.head = 0;
890
+ bool res = true;
891
+ res = res && state_read_meta(io, cell_count, seq_id);
892
+ res = res && state_read_data(io, cell_count);
470
893
 
471
- for (uint32_t i = 0; i < cache.size; ++i) {
472
- if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
473
- cache.cells[i].seq_id.insert(seq_id_dst);
894
+ if (!res) {
895
+ if (seq_id == -1) {
896
+ clear();
897
+ } else {
898
+ seq_rm(seq_id, -1, -1);
474
899
  }
900
+ throw std::runtime_error("failed to restore kv cache");
475
901
  }
476
902
  }
477
903
 
478
- void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
479
- uint32_t new_head = cache.size;
904
+ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
905
+ for (const auto & range : cell_ranges) {
906
+ for (uint32_t i = range.first; i < range.second; ++i) {
907
+ const auto & cell = cells[i];
908
+ const llama_pos pos = cell.pos;
909
+ const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
910
+
911
+ io.write(&pos, sizeof(pos));
912
+ io.write(&n_seq_id, sizeof(n_seq_id));
480
913
 
481
- for (uint32_t i = 0; i < cache.size; ++i) {
482
- if (cache.recurrent && (llama_seq_id) i != seq_id) {
483
- cache.cells[i].tail = -1;
914
+ if (n_seq_id) {
915
+ for (auto seq_id : cell.seq_id) {
916
+ io.write(&seq_id, sizeof(seq_id));
917
+ }
918
+ }
484
919
  }
485
- if (!cache.cells[i].has_seq_id(seq_id)) {
486
- if (cache.cells[i].pos >= 0) cache.used--;
487
- cache.cells[i].pos = -1;
488
- cache.cells[i].src = -1;
489
- cache.cells[i].seq_id.clear();
490
- if (new_head == cache.size) new_head = i;
491
- } else {
492
- cache.cells[i].seq_id.clear();
493
- cache.cells[i].seq_id.insert(seq_id);
920
+ }
921
+ }
922
+
923
+ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
924
+ const uint32_t v_trans = this->v_trans ? 1 : 0;
925
+ const uint32_t n_layer = hparams.n_layer;
926
+
927
+ io.write(&v_trans, sizeof(v_trans));
928
+ io.write(&n_layer, sizeof(n_layer));
929
+
930
+ std::vector<uint8_t> tmp_buf;
931
+
932
+ // Iterate and write all the keys first, each row is a cell
933
+ // Get whole range at a time
934
+ for (uint32_t il = 0; il < n_layer; ++il) {
935
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
936
+
937
+ // Write key type
938
+ const int32_t k_type_i = (int32_t)k_l[il]->type;
939
+ io.write(&k_type_i, sizeof(k_type_i));
940
+
941
+ // Write row size of key
942
+ const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
943
+ io.write(&k_size_row, sizeof(k_size_row));
944
+
945
+ // Read each range of cells of k_size length each into tmp_buf and write out
946
+ for (const auto & range : cell_ranges) {
947
+ const size_t range_size = range.second - range.first;
948
+ const size_t buf_size = range_size * k_size_row;
949
+ io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
494
950
  }
495
951
  }
496
952
 
497
- // If we freed up a slot, set head to it so searching can start there.
498
- if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
499
- }
953
+ if (!v_trans) {
954
+ for (uint32_t il = 0; il < n_layer; ++il) {
955
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
500
956
 
501
- void llama_kv_cache_seq_add(
502
- struct llama_kv_cache & cache,
503
- llama_seq_id seq_id,
504
- llama_pos p0,
505
- llama_pos p1,
506
- llama_pos delta) {
507
- uint32_t new_head = cache.size;
508
-
509
- if (p0 < 0) p0 = 0;
510
- if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
511
- // If there is no range then return early to avoid looping over the cache.
512
- if (p0 == p1) return;
957
+ // Write value type
958
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
959
+ io.write(&v_type_i, sizeof(v_type_i));
513
960
 
514
- if (cache.recurrent) {
515
- // for Mamba-like or RWKV models, only the pos needs to be shifted
516
- if (0 <= seq_id && seq_id < (int64_t) cache.size) {
517
- const int32_t tail_id = cache.cells[seq_id].tail;
518
- if (tail_id >= 0) {
519
- llama_kv_cell & cell = cache.cells[tail_id];
520
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
521
- cell.pos += delta;
961
+ // Write row size of value
962
+ const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
963
+ io.write(&v_size_row, sizeof(v_size_row));
964
+
965
+ // Read each range of cells of v_size length each into tmp_buf and write out
966
+ for (const auto & range : cell_ranges) {
967
+ const size_t range_size = range.second - range.first;
968
+ const size_t buf_size = range_size * v_size_row;
969
+ io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
970
+ }
971
+ }
972
+ } else {
973
+ // When v is transposed, we also need the element size and get the element ranges from each row
974
+ const uint32_t kv_size = size;
975
+ for (uint32_t il = 0; il < n_layer; ++il) {
976
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
977
+
978
+ // Write value type
979
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
980
+ io.write(&v_type_i, sizeof(v_type_i));
981
+
982
+ // Write element size
983
+ const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
984
+ io.write(&v_size_el, sizeof(v_size_el));
985
+
986
+ // Write GQA embedding size
987
+ io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
988
+
989
+ // For each row, we get the element values of each cell
990
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
991
+ // Read each range of cells of v_size_el length each into tmp_buf and write out
992
+ for (const auto & range : cell_ranges) {
993
+ const size_t range_size = range.second - range.first;
994
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
995
+ const size_t buf_size = range_size * v_size_el;
996
+ io.write_tensor(v_l[il], src_offset, buf_size);
522
997
  }
523
998
  }
524
999
  }
525
- return;
526
1000
  }
1001
+ }
1002
+
1003
+ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
1004
+ if (dest_seq_id != -1) {
1005
+ // single sequence
527
1006
 
528
- for (uint32_t i = 0; i < cache.size; ++i) {
529
- if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
530
- cache.has_shift = true;
531
- cache.cells[i].pos += delta;
532
- cache.cells[i].delta += delta;
1007
+ seq_rm(dest_seq_id, -1, -1);
1008
+
1009
+ llama_sbatch sbatch;
1010
+ llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1011
+
1012
+ batch.n_tokens = cell_count;
1013
+ batch.n_seq_tokens = cell_count;
1014
+ batch.n_seqs = 1;
1015
+
1016
+ for (uint32_t i = 0; i < cell_count; ++i) {
1017
+ llama_pos pos;
1018
+ uint32_t n_seq_id;
1019
+
1020
+ io.read_to(&pos, sizeof(pos));
1021
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
1022
+
1023
+ if (n_seq_id != 0) {
1024
+ LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1025
+ return false;
1026
+ }
533
1027
 
534
- if (cache.cells[i].pos < 0) {
535
- if (!cache.cells[i].is_empty()) {
536
- cache.used--;
1028
+ batch.pos[i] = pos;
1029
+ }
1030
+ batch.n_seq_id[0] = 1;
1031
+ batch.seq_id[0] = &dest_seq_id;
1032
+ if (!find_slot(batch)) {
1033
+ LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1034
+ return false;
1035
+ }
1036
+
1037
+ // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1038
+ // Assume that this is one contiguous block of cells
1039
+ GGML_ASSERT(head + cell_count <= size);
1040
+ GGML_ASSERT(cells[head].pos == batch.pos[0]);
1041
+ GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
1042
+ GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
1043
+ GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
1044
+ } else {
1045
+ // whole KV cache restore
1046
+
1047
+ if (cell_count > size) {
1048
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1049
+ return false;
1050
+ }
1051
+
1052
+ clear();
1053
+
1054
+ for (uint32_t i = 0; i < cell_count; ++i) {
1055
+ llama_kv_cell & cell = cells[i];
1056
+
1057
+ llama_pos pos;
1058
+ uint32_t n_seq_id;
1059
+
1060
+ io.read_to(&pos, sizeof(pos));
1061
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
1062
+
1063
+ cell.pos = pos;
1064
+
1065
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
1066
+ llama_seq_id seq_id;
1067
+ io.read_to(&seq_id, sizeof(seq_id));
1068
+
1069
+ // TODO: llama_kv_cache_unified should have a notion of max sequences
1070
+ //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
1071
+ if (seq_id < 0) {
1072
+ //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
1073
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
1074
+ return false;
537
1075
  }
538
- cache.cells[i].pos = -1;
539
- cache.cells[i].seq_id.clear();
540
- if (new_head == cache.size) {
541
- new_head = i;
1076
+
1077
+ cell.seq_id.insert(seq_id);
1078
+
1079
+ if (recurrent) {
1080
+ int32_t & tail = cells[seq_id].tail;
1081
+ if (tail != -1) {
1082
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
1083
+ return false;
1084
+ }
1085
+ tail = i;
542
1086
  }
543
1087
  }
544
1088
  }
1089
+
1090
+ head = 0;
1091
+ used = cell_count;
545
1092
  }
546
1093
 
547
- // If we freed up a slot, set head to it so searching can start there.
548
- // Otherwise we just start the next search from the beginning.
549
- cache.head = new_head != cache.size ? new_head : 0;
1094
+ if (recurrent) {
1095
+ for (uint32_t i = 0; i < cell_count; ++i) {
1096
+ uint32_t cell_id = head + i;
1097
+ // make sure the recurrent states will keep their restored state
1098
+ cells[cell_id].src = cell_id;
1099
+ }
1100
+ }
1101
+
1102
+ return true;
550
1103
  }
551
1104
 
552
- void llama_kv_cache_seq_div(
553
- struct llama_kv_cache & cache,
554
- llama_seq_id seq_id,
555
- llama_pos p0,
556
- llama_pos p1,
557
- int d) {
558
- if (p0 < 0) p0 = 0;
559
- if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
560
- // If there is no range then return early to avoid looping over the cache.
561
- if (p0 == p1) return;
1105
+ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1106
+ uint32_t v_trans;
1107
+ uint32_t n_layer;
1108
+ io.read_to(&v_trans, sizeof(v_trans));
1109
+ io.read_to(&n_layer, sizeof(n_layer));
562
1110
 
563
- if (cache.recurrent) {
564
- // for Mamba-like or RWKV models, only the pos needs to be changed
565
- if (0 <= seq_id && seq_id < (int64_t) cache.size) {
566
- const int32_t tail_id = cache.cells[seq_id].tail;
567
- if (tail_id >= 0) {
568
- llama_kv_cell & cell = cache.cells[tail_id];
569
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
570
- cell.pos /= d;
1111
+ if (n_layer != hparams.n_layer) {
1112
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
1113
+ return false;
1114
+ }
1115
+ if (cell_count > size) {
1116
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
1117
+ return false;
1118
+ }
1119
+ if (v_trans != (bool) v_trans) {
1120
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1121
+ return false;
1122
+ }
1123
+
1124
+ // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1125
+ for (uint32_t il = 0; il < n_layer; ++il) {
1126
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1127
+
1128
+ // Read type of key
1129
+ int32_t k_type_i_ref;
1130
+ io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1131
+ const int32_t k_type_i = (int32_t) k_l[il]->type;
1132
+ if (k_type_i != k_type_i_ref) {
1133
+ LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1134
+ return false;
1135
+ }
1136
+
1137
+ // Read row size of key
1138
+ uint64_t k_size_row_ref;
1139
+ io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1140
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
1141
+        if (k_size_row != k_size_row_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+            return false;
+        }
+
+        if (cell_count) {
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+        }
+    }
+
+    if (!v_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of value
+            uint64_t v_size_row_ref;
+            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            if (v_size_row != v_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+            }
+        }
+    } else {
+        // For each layer, read the values for each cell (transposed)
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read element size of value
+            uint32_t v_size_el_ref;
+            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+            const size_t v_size_el = ggml_type_size(v_l[il]->type);
+            if (v_size_el != v_size_el_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                return false;
+            }
+
+            // Read GQA embedding size
+            uint32_t n_embd_v_gqa_ref;
+            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (head + j * size) * v_size_el;
+                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
             }
         }
+    }
+
+    return true;
+}
+
+//
+// interface implementation
+//
+
+int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_n_tokens();
+}
+
+int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_used_cells();
+}
+
+void llama_kv_cache_clear(llama_kv_cache * kv) {
+    if (!kv) {
         return;
     }
 
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.has_shift = true;
+    kv->clear();
+}
 
-            {
-                llama_pos p_old = cache.cells[i].pos;
-                cache.cells[i].pos /= d;
-                cache.cells[i].delta += cache.cells[i].pos - p_old;
-            }
-        }
+bool llama_kv_cache_seq_rm(
+        llama_kv_cache * kv,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1) {
+    if (!kv) {
+        return true;
     }
+
+    return kv->seq_rm(seq_id, p0, p1);
 }
 
-llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
-    llama_pos result = 0;
+void llama_kv_cache_seq_cp(
+        llama_kv_cache * kv,
+        llama_seq_id seq_id_src,
+        llama_seq_id seq_id_dst,
+        llama_pos p0,
+        llama_pos p1) {
+    if (!kv) {
+        return;
+    }
 
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id)) {
-            result = std::max(result, cache.cells[i].pos);
-        }
+    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
+    if (!kv) {
+        return;
     }
 
-    return result;
+    kv->seq_keep(seq_id);
 }
 
-void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
-    if (!cache.recurrent) {
-        cache.do_defrag = true;
+void llama_kv_cache_seq_add(
+        llama_kv_cache * kv,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        llama_pos delta) {
+    if (!kv) {
+        return;
     }
+
+    kv->seq_add(seq_id, p0, p1, delta);
 }
 
-int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) {
-    int result = 0;
+void llama_kv_cache_seq_div(
+        llama_kv_cache * kv,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        int d) {
+    if (!kv) {
+        return;
+    }
 
-    for (uint32_t i = 0; i < kv.size; i++) {
-        result += kv.cells[i].seq_id.size();
+    kv->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
+    if (!kv) {
+        return 0;
     }
 
-    return result;
+    return kv->seq_pos_max(seq_id);
 }
 
-int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) {
-    return kv.used;
+void llama_kv_cache_defrag(llama_kv_cache * kv) {
+    if (!kv) {
+        return;
+    }
+
+    kv->defrag();
 }
 
-bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) {
-    return kv.can_shift;
+bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
+    if (!kv) {
+        return false;
+    }
+
+    return kv->get_can_shift();
 }
 
 //
 // kv cache view
 //
 
-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) {
-    struct llama_kv_cache_view result = {
+llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
+    llama_kv_cache_view result = {
         /*.n_cells            = */ 0,
         /*.n_seq_max          = */ n_seq_max,
         /*.token_count        = */ 0,
-        /*.used_cells         = */ llama_get_kv_cache_used_cells(kv),
+        /*.used_cells         = */ llama_kv_cache_used_cells(&kv),
         /*.max_contiguous     = */ 0,
         /*.max_contiguous_idx = */ -1,
         /*.cells              = */ nullptr,
@@ -642,7 +1350,7 @@ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache
     return result;
 }
 
-void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+void llama_kv_cache_view_free(llama_kv_cache_view * view) {
     if (view->cells != nullptr) {
         free(view->cells);
         view->cells = nullptr;
@@ -653,18 +1361,25 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
     }
 }
 
-void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) {
-    if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) {
-        view->n_cells = int32_t(kv.size);
-        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
+    // TODO: rework this in the future, for now quick hack
+    const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
+    if (kvu == nullptr) {
+        LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
+        return;
+    }
+
+    if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
+        view->n_cells = int32_t(kvu->size);
+        void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
         GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
-        view->cells = (struct llama_kv_cache_view_cell *)p;
+        view->cells = (llama_kv_cache_view_cell *)p;
         p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
         GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
         view->cells_sequences = (llama_seq_id *)p;
     }
 
-    const std::vector<llama_kv_cell> & kv_cells = kv.cells;
+    const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
     llama_kv_cache_view_cell * c_curr = view->cells;
     llama_seq_id * cs_curr = view->cells_sequences;
     int32_t used_cells = 0;
@@ -673,7 +1388,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
     uint32_t max_contig = 0;
    int32_t max_contig_idx = -1;
 
-    for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) {
+    for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
         const size_t curr_size = kv_cells[i].seq_id.size();
         token_count += curr_size;
         c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@@ -711,8 +1426,8 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
     view->max_contiguous_idx = max_contig_idx;
     view->token_count = token_count;
     view->used_cells = used_cells;
-    if (uint32_t(used_cells) != kv.used) {
+    if (uint32_t(used_cells) != kvu->used) {
         LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, kv.used, used_cells);
+            __func__, kvu->used, used_cells);
     }
 }
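
Note (illustrative, not part of the diff): the interface functions added above are null-safe wrappers that take a llama_kv_cache pointer instead of a reference. The sketch below shows how they might be called; the helper name prune_sequence is hypothetical, obtaining the llama_kv_cache * handle is outside the scope of this diff, and the include path for the new declarations is assumed rather than confirmed.

    // a minimal sketch, assuming the declarations from this diff are visible
    // (e.g. via the library's internal KV-cache header) and `kv` was obtained elsewhere
    #include <cstdio> // printf

    static void prune_sequence(llama_kv_cache * kv, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
        // remove the cells of `seq_id` in the range [p0, p1); the wrapper
        // returns false when the underlying cache cannot perform the removal
        if (!llama_kv_cache_seq_rm(kv, seq_id, p0, p1)) {
            llama_kv_cache_clear(kv); // fallback: drop the whole cache
        }

        printf("kv cache: %d tokens, %d used cells\n",
               llama_kv_cache_n_tokens(kv),
               llama_kv_cache_used_cells(kv));
    }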