@fugood/llama.node 1.0.2 → 1.0.4

This diff reflects the changes between publicly released versions of this package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (50)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  4. package/src/llama.cpp/common/arg.cpp +44 -0
  5. package/src/llama.cpp/common/common.cpp +22 -6
  6. package/src/llama.cpp/common/common.h +15 -1
  7. package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
  8. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  9. package/src/llama.cpp/ggml/include/ggml.h +104 -10
  10. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  11. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  12. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
  19. package/src/llama.cpp/include/llama.h +13 -47
  20. package/src/llama.cpp/src/llama-arch.cpp +298 -3
  21. package/src/llama.cpp/src/llama-arch.h +22 -1
  22. package/src/llama.cpp/src/llama-batch.cpp +103 -71
  23. package/src/llama.cpp/src/llama-batch.h +31 -18
  24. package/src/llama.cpp/src/llama-chat.cpp +59 -1
  25. package/src/llama.cpp/src/llama-chat.h +3 -0
  26. package/src/llama.cpp/src/llama-context.cpp +134 -95
  27. package/src/llama.cpp/src/llama-context.h +13 -16
  28. package/src/llama.cpp/src/llama-cparams.h +3 -2
  29. package/src/llama.cpp/src/llama-graph.cpp +279 -180
  30. package/src/llama.cpp/src/llama-graph.h +183 -122
  31. package/src/llama.cpp/src/llama-hparams.cpp +47 -1
  32. package/src/llama.cpp/src/llama-hparams.h +12 -1
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  34. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  35. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  36. package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  37. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  38. package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  39. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  40. package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
  41. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  42. package/src/llama.cpp/src/llama-memory.h +3 -0
  43. package/src/llama.cpp/src/llama-model.cpp +3373 -743
  44. package/src/llama.cpp/src/llama-model.h +20 -4
  45. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  46. package/src/llama.cpp/src/llama-vocab.cpp +376 -10
  47. package/src/llama.cpp/src/llama-vocab.h +43 -0
  48. package/src/llama.cpp/src/unicode.cpp +207 -0
  49. package/src/llama.cpp/src/unicode.h +2 -0
  50. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
package/src/llama.cpp/src/llama-kv-cache-unified.cpp

@@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  ggml_type type_v,
  bool v_trans,
  bool offload,
+ bool unified,
  uint32_t kv_size,
  uint32_t n_seq_max,
  uint32_t n_pad,
  uint32_t n_swa,
  llama_swa_type swa_type) :
  model(model), hparams(model.hparams), v_trans(v_trans),
- n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+ n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

  GGML_ASSERT(kv_size % n_pad == 0);

@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  auto it = ctx_map.find(buft);
  if (it == ctx_map.end()) {
  ggml_init_params params = {
- /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
+ /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
  /*.mem_buffer =*/ NULL,
  /*.no_alloc =*/ true,
  };
@@ -64,9 +65,33 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  return it->second;
  };

- head = 0;
+ GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);

- cells.resize(kv_size);
+ v_heads.resize(n_stream);
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ v_heads[s] = 0;
+ }
+
+ v_cells.resize(n_stream);
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ v_cells[s].resize(kv_size);
+ }
+
+ // by default, all sequence ids are mapped to the 0th stream
+ seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
+
+ if (n_stream > 1) {
+ seq_to_stream.resize(n_stream, 0);
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ seq_to_stream[s] = s;
+ }
+ }
+
+ // [TAG_V_CACHE_VARIABLE]
+ if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+ LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+ __func__, hparams.n_embd_v_gqa_max());
+ }

  for (uint32_t il = 0; il < n_layer_cache; il++) {
  if (filter && !filter(il)) {
@@ -74,8 +99,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  continue;
  }

- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+ // [TAG_V_CACHE_VARIABLE]
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+ const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();

  const char * dev_name = "CPU";

@@ -98,14 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  ggml_tensor * k;
  ggml_tensor * v;

- k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
- v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+ k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+ v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);

  ggml_format_name(k, "cache_k_l%d", il);
  ggml_format_name(v, "cache_v_l%d", il);

+ std::vector<ggml_tensor *> k_stream;
+ std::vector<ggml_tensor *> v_stream;
+
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+ v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+ }
+
  map_layer_ids[il] = layers.size();
- layers.push_back({ il, k, v });
+
+ layers.push_back({ il, k, v, k_stream, v_stream, });
  }

  // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
@@ -148,20 +183,33 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  const size_t memory_size_k = size_k_bytes();
  const size_t memory_size_v = size_v_bytes();

- LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
+ LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  }

  const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
  debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+ const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+ supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0;
+
+ if (!supports_set_rows) {
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+ GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
+ }
+
+ if (!supports_set_rows) {
+ LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+ }
  }

  void llama_kv_cache_unified::clear(bool data) {
- cells.reset();
-
- head = 0;
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ v_cells[s].reset();
+ v_heads[s] = 0;
+ }

  if (data) {
  for (auto & buf : bufs) {
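Note on the constructor hunks above: the new unified flag decides how many KV streams are allocated, and seq_to_stream maps each sequence id to a stream. Below is a minimal standalone C++ sketch of that mapping, written for illustration only; the constants and names are assumptions, not the library API.

// Illustrative sketch only (not the library API): how a `unified` flag
// determines the number of KV streams and the default sequence-to-stream map,
// following the constructor hunks above. Sizes are assumed example values.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_seq_max     = 4;  // assumed maximum number of sequences
    const uint32_t LLAMA_MAX_SEQ = 64; // assumed upper bound, as used in the diff

    for (bool unified : { true, false }) {
        const uint32_t n_stream = unified ? 1 : n_seq_max;

        // by default, every sequence id maps to stream 0 ...
        std::vector<uint32_t> seq_to_stream(LLAMA_MAX_SEQ, 0);

        // ... and with one stream per sequence, sequence s owns stream s
        if (n_stream > 1) {
            for (uint32_t s = 0; s < n_stream; ++s) {
                seq_to_stream[s] = s;
            }
        }

        assert(n_stream == 1 || n_stream == n_seq_max);
        printf("unified=%d -> n_stream=%u, seq 3 -> stream %u\n",
               unified, n_stream, seq_to_stream[3]);
    }
    return 0;
}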
@@ -171,6 +219,11 @@ void llama_kv_cache_unified::clear(bool data) {
  }

  bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+ auto & cells = v_cells[seq_to_stream[seq_id]];
+ auto & head = v_heads[seq_to_stream[seq_id]];
+
  uint32_t new_head = cells.size();

  if (p0 < 0) {
@@ -217,30 +270,94 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
  }

  void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
- if (seq_id_src == seq_id_dst) {
+ GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
+ GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
+
+ const auto s0 = seq_to_stream[seq_id_src];
+ const auto s1 = seq_to_stream[seq_id_dst];
+
+ if (s0 == s1) {
+ // since both sequences are in the same stream, no data copy is necessary
+ // we just have to update the cells meta data
+
+ auto & cells = v_cells[s0];
+
+ if (seq_id_src == seq_id_dst) {
+ return;
+ }
+
+ if (p0 < 0) {
+ p0 = 0;
+ }
+
+ if (p1 < 0) {
+ p1 = std::numeric_limits<llama_pos>::max();
+ }
+
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (!cells.pos_in(i, p0, p1)) {
+ continue;
+ }
+
+ if (cells.seq_has(i, seq_id_src)) {
+ cells.seq_add(i, seq_id_dst);
+ }
+ }
+
  return;
  }

- if (p0 < 0) {
- p0 = 0;
+ // cross-stream sequence copies require to copy the actual buffer data
+
+ bool is_full = true;
+
+ if (p0 > 0 && p0 + 1 < (int) get_size()) {
+ is_full = false;
  }

- if (p1 < 0) {
- p1 = std::numeric_limits<llama_pos>::max();
+ if (p1 > 0 && p1 + 1 < (int) get_size()) {
+ is_full = false;
  }

- for (uint32_t i = 0; i < cells.size(); ++i) {
- if (!cells.pos_in(i, p0, p1)) {
- continue;
- }
+ GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");

- if (cells.seq_has(i, seq_id_src)) {
- cells.seq_add(i, seq_id_dst);
+ // enqueue the copy operation - the buffer copy will be performed during the next update
+ sc_info.ssrc.push_back(s0);
+ sc_info.sdst.push_back(s1);
+
+ v_cells[s1].reset();
+ for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
+ if (v_cells[s0].seq_has(i, seq_id_src)) {
+ llama_pos pos = v_cells[s0].pos_get(i);
+ llama_pos shift = v_cells[s0].get_shift(i);
+
+ if (shift != 0) {
+ pos -= shift;
+ assert(pos >= 0);
+ }
+
+ v_cells[s1].pos_set(i, pos);
+ v_cells[s1].seq_add(i, seq_id_dst);
+
+ if (shift != 0) {
+ v_cells[s1].pos_add(i, shift);
+ }
  }
  }
+
+ v_heads[s1] = v_heads[s0];
+
+ //for (uint32_t s = 0; s < n_stream; ++s) {
+ // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
+ //}
  }

  void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+ auto & cells = v_cells[seq_to_stream[seq_id]];
+ auto & head = v_heads[seq_to_stream[seq_id]];
+
  uint32_t new_head = cells.size();

  for (uint32_t i = 0; i < cells.size(); ++i) {
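Note on seq_cp() above: copying a sequence across streams does not move tensor data immediately; the (ssrc, sdst) pair is queued in sc_info and the buffer copy runs during the next update(). A toy sketch of that deferred-copy pattern follows; the names and container types are assumed simplifications, not the real llama.cpp types.

// Minimal standalone sketch (assumptions only): enqueue cheap copy requests at
// seq_cp() time, apply the expensive buffer copies later, as update() does.
#include <cstdint>
#include <cstdio>
#include <vector>

struct stream_copy_queue {
    std::vector<uint32_t> ssrc;
    std::vector<uint32_t> sdst;

    bool empty() const { return ssrc.empty(); }
};

// enqueue: cheap, done at seq_cp() time
static void enqueue_copy(stream_copy_queue & q, uint32_t s0, uint32_t s1) {
    q.ssrc.push_back(s0);
    q.sdst.push_back(s1);
}

// apply: done once per update() - the toy buffer stands in for the real K/V tensors
static void apply_copies(stream_copy_queue & q, std::vector<std::vector<int>> & streams) {
    for (size_t i = 0; i < q.ssrc.size(); ++i) {
        streams[q.sdst[i]] = streams[q.ssrc[i]]; // stands in for ggml_backend_tensor_copy()
    }
    q.ssrc.clear();
    q.sdst.clear();
}

int main() {
    std::vector<std::vector<int>> streams = { {1, 2, 3}, {0, 0, 0} };
    stream_copy_queue q;

    enqueue_copy(q, /*src stream*/ 0, /*dst stream*/ 1);
    apply_copies(q, streams);

    printf("stream 1 now holds: %d %d %d\n", streams[1][0], streams[1][1], streams[1][2]);
    return 0;
}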
@@ -258,6 +375,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
  }

  void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+ auto & cells = v_cells[seq_to_stream[seq_id]];
+ auto & head = v_heads[seq_to_stream[seq_id]];
+
  if (shift == 0) {
  return;
  }
@@ -297,6 +419,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
  }

  void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+ auto & cells = v_cells[seq_to_stream[seq_id]];
+
  if (d == 1) {
  return;
  }
@@ -326,10 +452,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
  }

  llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
+
  return cells.seq_pos_min(seq_id);
  }

  llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
+
  return cells.seq_pos_max(seq_id);
  }

@@ -344,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(

  std::vector<llama_ubatch> ubatches;
  while (true) {
- auto ubatch = balloc.split_simple(n_ubatch);
+ auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);

  if (ubatch.n_tokens == 0) {
  break;
@@ -353,13 +487,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
  ubatches.push_back(std::move(ubatch)); // NOLINT
  }

- auto heads = prepare(ubatches);
- if (heads.empty()) {
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
+ break;
+ }
+
+ auto sinfos = prepare(ubatches);
+ if (sinfos.empty()) {
  break;
  }

  return std::make_unique<llama_kv_cache_unified_context>(
- this, std::move(heads), std::move(ubatches));
+ this, std::move(sinfos), std::move(ubatches));
  } while (false);

  return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
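Note on init_batch() above: with one stream per sequence the batch is split into equal-sized per-sequence chunks (split_equal) instead of split_simple, so every stream receives the same number of tokens per ubatch. A hedged toy sketch of that invariant follows; the example data and helper names are assumptions, not the real batch allocator.

// Toy sketch (assumptions only): an "equal" split is valid only when every
// unique sequence contributes the same number of tokens to the ubatch.
#include <cassert>
#include <cstdio>
#include <map>
#include <vector>

int main() {
    // token -> sequence id for a mixed batch (assumed example data)
    const std::vector<int> batch_seq = { 0, 0, 0, 1, 1, 1, 2, 2, 2 };

    std::map<int, int> per_seq;
    for (int s : batch_seq) {
        per_seq[s]++;
    }

    const int n_seqs   = (int) per_seq.size();
    const int per_each = per_seq.begin()->second;
    for (const auto & kv : per_seq) {
        assert(kv.second == per_each && "equal split requires equal-length per-sequence chunks");
    }

    printf("%d seqs x %d tokens each = %zu tokens total\n",
           n_seqs, per_each, batch_seq.size());
    return 0;
}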
@@ -375,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
  defrag_info dinfo;

  // see if we need to defrag
- {
+ if (n_stream == 1) {
+ // note : for now do not consider defrag for n_stream > 1
+ const auto & cells = v_cells[seq_to_stream[0]];
+
  bool do_defrag = optimize;

  const auto thold = lctx->get_cparams().defrag_thold;
@@ -399,46 +541,69 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
  }
  }

- return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
+ return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
  }

- llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
- llama_kv_cache_unified::ubatch_heads res;
+ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+ llama_kv_cache_unified::slot_info_vec_t res;

- struct state {
- uint32_t head_old; // old position of the head, before placing the ubatch
- uint32_t head_new; // new position of the head, after placing the ubatch
+ struct state_t {
+ slot_info sinfo; // slot info for the ubatch

- llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
+ std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+ std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
  };

  // remember the old state of the cells so we can restore it in the end
- std::vector<state> states;
+ std::vector<state_t> states;

  bool success = true;

  for (const auto & ubatch : ubatches) {
+ // non-continuous slots require support for ggml_set_rows()
+ const bool cont = supports_set_rows ? false : true;
+
  // only find a suitable slot for the ubatch. don't modify the cells yet
- const int32_t head_new = find_slot(ubatch);
- if (head_new < 0) {
+ const auto sinfo_new = find_slot(ubatch, cont);
+ if (sinfo_new.empty()) {
  success = false;
  break;
  }

  // remeber the position that we found
- res.push_back(head_new);
+ res.push_back(sinfo_new);

  // store the old state of the cells in the recovery stack
- states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+ {
+ state_t state = { sinfo_new, v_heads, {} };
+
+ for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
+ auto & cells = v_cells[sinfo_new.strm[s]];
+
+ state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+ }
+
+ states.push_back(std::move(state));
+ }

  // now emplace the ubatch
- apply_ubatch(head_new, ubatch);
+ apply_ubatch(sinfo_new, ubatch);
  }

+ GGML_ASSERT(!states.empty() || !success);
+
  // iterate backwards and restore the cells to their original state
  for (auto it = states.rbegin(); it != states.rend(); ++it) {
- cells.set(it->head_new, it->cells);
- head = it->head_old;
+ const auto & sinfo = it->sinfo;
+
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+ auto & cells = v_cells[sinfo.strm[s]];
+ auto & head = v_heads[sinfo.strm[s]];
+
+ cells.set(sinfo.idxs[s], it->v_cells[s]);
+ head = it->v_heads_old[s];
+ }
  }

  if (!success) {
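Note on prepare() above: slots are found and applied speculatively, with the previous heads and cells snapshotted so everything can be restored afterwards. A minimal sketch of this snapshot-and-rollback pattern, under simplified assumptions (toy state, not the real cells type):

// Standalone sketch (assumptions only): remember the old state before each
// trial placement, then restore the snapshots in reverse order.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> cells = { 0, 0, 0, 0 };

    struct snapshot {
        size_t pos;
        int    old_value;
    };
    std::vector<snapshot> history;

    // "apply" two placements, remembering the old state first
    for (size_t pos : { (size_t) 1, (size_t) 2 }) {
        history.push_back({ pos, cells[pos] });
        cells[pos] = 7; // stands in for apply_ubatch()
    }

    // roll back in reverse order, as prepare() does with its recovery stack
    for (auto it = history.rbegin(); it != history.rend(); ++it) {
        cells[it->pos] = it->old_value;
    }

    printf("cells after rollback: %d %d %d %d\n", cells[0], cells[1], cells[2], cells[3]);
    return 0;
}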
@@ -448,11 +613,38 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
  return res;
  }

- bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
+ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
  bool updated = false;

  auto * sched = lctx->get_sched();

+ if (!sc_info.empty()) {
+ assert(n_stream > 1 && "stream copy should never happen with a single stream");
+
+ llama_synchronize(lctx);
+
+ const size_t n_copy = sc_info.ssrc.size();
+
+ for (size_t i = 0; i < n_copy; ++i) {
+ const auto ssrc = sc_info.ssrc[i];
+ const auto sdst = sc_info.sdst[i];
+
+ assert(ssrc < n_stream);
+ assert(sdst < n_stream);
+
+ LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
+
+ assert(ssrc != sdst);
+
+ for (uint32_t il = 0; il < layers.size(); ++il) {
+ const auto & layer = layers[il];
+
+ ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
+ ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
+ }
+ }
+ }
+
  if (do_shift) {
  if (!get_can_shift()) {
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -464,14 +656,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
  if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
  ggml_backend_sched_reset(sched);

- auto * gf = lctx->graph_init();
+ auto * res = lctx->get_gf_res_reserve();

- auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
- if (!res) {
- LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
- return updated;
- }
+ res->reset();

+ auto * gf = build_graph_shift(res, lctx);
  if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
  return updated;
@@ -487,12 +676,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
  updated = true;
  }

- cells.reset_shift();
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ auto & cells = v_cells[s];
+
+ cells.reset_shift();
+ }
  }

  if (!dinfo.empty()) {
  LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

+ // note: for now do not consider defrag for n_stream > 1
+ auto & cells = v_cells[seq_to_stream[0]];
+ auto & head = v_heads[seq_to_stream[0]];
+
  // apply moves:
  {
  const auto n_kv = dinfo.ids.size();
@@ -513,14 +710,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d

  ggml_backend_sched_reset(sched);

- auto * gf = lctx->graph_init();
+ auto * res = lctx->get_gf_res_reserve();

- auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
- if (!res) {
- LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
- return updated;
- }
+ res->reset();

+ auto * gf = build_graph_defrag(res, lctx, dinfo);
  if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
  return updated;
@@ -539,24 +733,14 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
  return updated;
  }

- int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
- const uint32_t n_tokens = ubatch.n_tokens;
-
- uint32_t head_cur = this->head;
+ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
+ if (debug > 0) {
+ const auto & cells = v_cells[seq_to_stream[1]];

- // if we have enough unused cells before the current head ->
- // better to start searching from the beginning of the cache, hoping to fill it
- if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
- head_cur = 0;
- }
+ const uint32_t head_cur = v_heads[1];

- if (n_tokens > cells.size()) {
- LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
- return -1;
- }
-
- if (debug > 0) {
- LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
+ LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+ __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);

  if ((debug == 2 && n_swa > 0) || debug > 2) {
  std::string ss;
@@ -613,103 +797,186 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
  }
  }

- uint32_t n_tested = 0;
+ uint32_t n_tokens = ubatch.n_tokens;
+ uint32_t n_seqs = 1;

- while (true) {
- if (head_cur + n_tokens > cells.size()) {
- n_tested += cells.size() - head_cur;
+ if (n_stream > 1) {
+ GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
+
+ n_seqs = ubatch.n_seqs_unq;
+ n_tokens = n_tokens / n_seqs;
+ }
+
+ slot_info res = {
+ /*.s0 =*/ LLAMA_MAX_SEQ,
+ /*.s1 =*/ 0,
+ /*.strm =*/ { },
+ /*.idxs =*/ { },
+ };
+
+ res.resize(n_seqs);
+
+ for (uint32_t s = 0; s < n_seqs; ++s) {
+ const auto seq_id = ubatch.seq_id_unq[s];
+
+ if (n_stream > 1) {
+ GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
+ GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
+ }
+
+ res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
+ res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+
+ res.strm[s] = seq_to_stream[seq_id];
+ res.idxs[s].reserve(n_tokens);
+
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+ uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
+
+ // if we have enough unused cells before the current head ->
+ // better to start searching from the beginning of the cache, hoping to fill it
+ if (head_cur > cells.get_used() + 2*n_tokens) {
  head_cur = 0;
- continue;
  }

- bool found = true;
- for (uint32_t i = 0; i < n_tokens; i++) {
- //const llama_pos pos = ubatch.pos[i];
- //const llama_seq_id seq_id = ubatch.seq_id[i][0];
+ if (n_tokens > cells.size()) {
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+ return { };
+ }
+
+ uint32_t n_tested = 0;
+
+ // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+ // for non-continuous slots, we test the tokens one by one
+ const uint32_t n_test = cont ? n_tokens : 1;
+
+ while (true) {
+ if (head_cur + n_test > cells.size()) {
+ n_tested += cells.size() - head_cur;
+ head_cur = 0;
+ continue;
+ }
+
+ for (uint32_t i = 0; i < n_test; i++) {
+ const auto idx = head_cur;

- // can we use this cell? either:
- // - the cell is empty
- // - the cell is occupied only by one sequence:
- // - (disabled) mask causally, if the sequence is the same as the one we are inserting
- // - mask SWA, using current max pos for that sequence in the cache
- // always insert in the cell with minimum pos
- bool can_use = cells.is_empty(head_cur + i);
+ head_cur++;
+ n_tested++;

- if (!can_use && cells.seq_count(head_cur + i) == 1) {
- const llama_pos pos_cell = cells.pos_get(head_cur + i);
+ //const llama_pos pos = ubatch.pos[i];
+ //const llama_seq_id seq_id = ubatch.seq_id[i][0];

- // (disabled) causal mask
- // note: it's better to purge any "future" tokens beforehand
- //if (cells.seq_has(head_cur + i, seq_id)) {
- // can_use = pos_cell >= pos;
- //}
+ // can we use this cell? either:
+ // - the cell is empty
+ // - the cell is occupied only by one sequence:
+ // - (disabled) mask causally, if the sequence is the same as the one we are inserting
+ // - mask SWA, using current max pos for that sequence in the cache
+ // always insert in the cell with minimum pos
+ bool can_use = cells.is_empty(idx);

- if (!can_use) {
- const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+ if (!can_use && cells.seq_count(idx) == 1) {
+ const llama_pos pos_cell = cells.pos_get(idx);

- // SWA mask
- if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
- can_use = true;
+ // (disabled) causal mask
+ // note: it's better to purge any "future" tokens beforehand
+ //if (cells.seq_has(idx, seq_id)) {
+ // can_use = pos_cell >= pos;
+ //}
+
+ if (!can_use) {
+ const llama_seq_id seq_id_cell = cells.seq_get(idx);
+
+ // SWA mask
+ if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+ can_use = true;
+ }
+ }
+ }
+
+ if (can_use) {
+ res.idxs[s].push_back(idx);
+ } else {
+ if (cont) {
+ break;
  }
  }
  }

- if (!can_use) {
- found = false;
- head_cur += i + 1;
- n_tested += i + 1;
+ if (res.idxs[s].size() == n_tokens) {
  break;
  }
- }

- if (found) {
- break;
+ if (cont) {
+ res.idxs[s].clear();
+ }
+
+ if (n_tested >= cells.size()) {
+ //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+ return { };
+ }
  }

- if (n_tested >= cells.size()) {
- //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
- return -1;
+ // we didn't find a suitable slot - return empty result
+ if (res.idxs[s].size() < n_tokens) {
+ return { };
  }
  }

- return head_cur;
+ assert(res.s1 >= res.s0);
+
+ return res;
  }

- void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
  // keep track of the max sequence position that we would overwrite with this ubatch
  // for non-SWA cache, this would be always empty
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  seq_pos_max_rm[s] = -1;
  }

- for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
- if (!cells.is_empty(head_cur + i)) {
- assert(cells.seq_count(head_cur + i) == 1);
+ assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());

- const llama_seq_id seq_id = cells.seq_get(head_cur + i);
- const llama_pos pos = cells.pos_get(head_cur + i);
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+ for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+ const uint32_t i = s*sinfo.size() + ii;

- seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+ auto & cells = v_cells[sinfo.strm[s]];

- cells.rm(head_cur + i);
- }
+ const auto idx = sinfo.idxs[s][ii];

- cells.pos_set(head_cur + i, ubatch.pos[i]);
+ if (!cells.is_empty(idx)) {
+ assert(cells.seq_count(idx) == 1);

- for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
- cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+ const llama_seq_id seq_id = cells.seq_get(idx);
+ const llama_pos pos = cells.pos_get(idx);
+
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+ cells.rm(idx);
+ }
+
+ cells.pos_set(idx, ubatch.pos[i]);
+
+ for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+ cells.seq_add(idx, ubatch.seq_id[i][s]);
+ }
  }
  }

  // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
  // will be present in the cache. so we have to purge any position which is less than those we would overwrite
  // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  if (seq_pos_max_rm[s] == -1) {
  continue;
  }

+ GGML_ASSERT(s < seq_to_stream.size());
+
+ auto & cells = v_cells[seq_to_stream[s]];
+
  if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
  LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
  __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
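Note on find_slot() above: when ggml_set_rows() support is available the slot may be a scattered set of cell indices, otherwise it must be a contiguous run starting at the head. An illustrative standalone sketch (not the library code) of the two search modes:

// Toy model of the two slot-search modes; "usable" stands in for the real
// per-cell eligibility test in find_slot().
#include <cstdio>
#include <vector>

static std::vector<int> find_slot_toy(const std::vector<bool> & usable, int n_tokens, bool cont) {
    const int n = (int) usable.size();

    if (cont) {
        // look for n_tokens consecutive usable cells
        for (int head = 0; head + n_tokens <= n; ++head) {
            bool ok = true;
            for (int i = 0; i < n_tokens; ++i) {
                ok = ok && usable[head + i];
            }
            if (ok) {
                std::vector<int> idxs(n_tokens);
                for (int i = 0; i < n_tokens; ++i) {
                    idxs[i] = head + i;
                }
                return idxs;
            }
        }
        return {};
    }

    // non-contiguous: gather any n_tokens usable cells, one by one
    std::vector<int> idxs;
    for (int i = 0; i < n && (int) idxs.size() < n_tokens; ++i) {
        if (usable[i]) {
            idxs.push_back(i);
        }
    }
    return (int) idxs.size() == n_tokens ? idxs : std::vector<int>{};
}

int main() {
    const std::vector<bool> usable = { true, false, true, true, false, true };

    const auto a = find_slot_toy(usable, 3, /*cont =*/ true);
    const auto b = find_slot_toy(usable, 3, /*cont =*/ false);

    printf("contiguous slot found: %s, scattered slot found: %s\n",
           a.empty() ? "no" : "yes", b.empty() ? "no" : "yes");
    return 0;
}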
@@ -719,7 +986,11 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
  }

  // move the head at the end of the slot
- head = head_cur + ubatch.n_tokens;
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+ auto & head = v_heads[sinfo.strm[s]];
+
+ head = sinfo.idxs[s].back() + 1;
+ }
  }

  bool llama_kv_cache_unified::get_can_shift() const {
@@ -727,99 +998,290 @@ bool llama_kv_cache_unified::get_can_shift() const {
  }

  uint32_t llama_kv_cache_unified::get_size() const {
+ const auto & cells = v_cells[seq_to_stream[0]];
+
  return cells.size();
  }

+ uint32_t llama_kv_cache_unified::get_n_stream() const {
+ return n_stream;
+ }
+
  bool llama_kv_cache_unified::get_has_shift() const {
- return cells.get_has_shift();
+ bool result = false;
+
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ result |= v_cells[s].get_has_shift();
+ }
+
+ return result;
  }

  uint32_t llama_kv_cache_unified::get_n_kv() const {
- return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
+ uint32_t result = 0;
+
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ const auto & cells = v_cells[s];
+
+ result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+ }
+
+ return result;
+ }
+
+ bool llama_kv_cache_unified::get_supports_set_rows() const {
+ return supports_set_rows;
  }

- ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
  const int32_t ikv = map_layer_ids.at(il);

  auto * k = layers[ikv].k;

- return ggml_view_3d(ctx, k,
- hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+ const uint64_t kv_size = get_size();
+ const uint64_t n_embd_k_gqa = k->ne[0];
+
+ assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+ return ggml_view_4d(ctx, k,
+ hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
  ggml_row_size(k->type, hparams.n_embd_head_k),
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
- 0);
+ ggml_row_size(k->type, n_embd_k_gqa),
+ ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+ ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
  }

- ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
  const int32_t ikv = map_layer_ids.at(il);

  auto * v = layers[ikv].v;

+ const uint64_t kv_size = get_size();
+ const uint64_t n_embd_v_gqa = v->ne[0];
+
+ // [TAG_V_CACHE_VARIABLE]
+ assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
  if (!v_trans) {
  // note: v->nb[1] <= v->nb[2]
- return ggml_view_3d(ctx, v,
- hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
- ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
- 0);
+ return ggml_view_4d(ctx, v,
+ hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
+ ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
+ ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
  }

  // note: v->nb[1] > v->nb[2]
- return ggml_view_3d(ctx, v,
- n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
- ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
- ggml_row_size(v->type, v->ne[1]), // v->nb[2]
- 0);
+ return ggml_view_4d(ctx, v,
+ n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
+ ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+ ggml_row_size(v->type, kv_size), // v->nb[2]
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
  }

- ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
  const int32_t ikv = map_layer_ids.at(il);

  auto * k = layers[ikv].k;

+ const int64_t n_embd_k_gqa = k->ne[0];
  const int64_t n_tokens = k_cur->ne[2];

+ k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+ if (k_idxs && supports_set_rows) {
+ if (k->ne[2] > 1) {
+ k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+ }
+
+ return ggml_set_rows(ctx, k, k_cur, k_idxs);
+ }
+
+ // TODO: fallback to old ggml_cpy() method for backwards compatibility
+ // will be removed when ggml_set_rows() is adopted by all backends
+
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
  ggml_tensor * k_view = ggml_view_1d(ctx, k,
- n_tokens*hparams.n_embd_k_gqa(il),
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+ n_tokens*n_embd_k_gqa,
+ ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());

  return ggml_cpy(ctx, k_cur, k_view);
  }

- ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
  const int32_t ikv = map_layer_ids.at(il);

  auto * v = layers[ikv].v;

- const int64_t n_tokens = v_cur->ne[2];
+ const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+ const int64_t n_tokens = v_cur->ne[2];

- v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+ v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+
+ if (v_idxs && supports_set_rows) {
+ if (!v_trans) {
+ if (v->ne[2] > 1) {
+ v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+ }
+
+ return ggml_set_rows(ctx, v, v_cur, v_idxs);
+ }
+
+ // [TAG_V_CACHE_VARIABLE]
+ if (n_embd_v_gqa < v->ne[0]) {
+ v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+ }
+
+ // the row becomes a single element
+ ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+
+ v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+
+ return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+ }
+
+ // TODO: fallback to old ggml_cpy() method for backwards compatibility
+ // will be removed when ggml_set_rows() is adopted by all backends
+
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");

  ggml_tensor * v_view = nullptr;

  if (!v_trans) {
  v_view = ggml_view_1d(ctx, v,
- n_tokens*hparams.n_embd_v_gqa(il),
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
+ n_tokens*n_embd_v_gqa,
+ ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
  } else {
- // note: the V cache is transposed when not using flash attention
- v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
- (v->ne[1])*ggml_element_size(v),
- (head_cur)*ggml_element_size(v));
-
  v_cur = ggml_transpose(ctx, v_cur);
+
+ v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+ (v->ne[1] )*ggml_element_size(v),
+ (sinfo.head())*ggml_element_size(v));
  }

  return ggml_cpy(ctx, v_cur, v_view);
  }

+ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ const uint32_t n_tokens = ubatch.n_tokens;
+
+ ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+ ggml_set_input(k_idxs);
+
+ return k_idxs;
+ }
+
+ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ const uint32_t n_tokens = ubatch.n_tokens;
+
+ ggml_tensor * v_idxs;
+
+ if (!v_trans) {
+ v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+ } else {
+ v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+ }
+
+ ggml_set_input(v_idxs);
+
+ return v_idxs;
+ }
+
+ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ if (!supports_set_rows) {
+ return;
+ }
+
+ const uint32_t n_tokens = ubatch->n_tokens;
+ GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+ int64_t * data = (int64_t *) dst->data;
+
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+ const int64_t offs = sinfo.strm[s]*get_size();
+
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
+ data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+ }
+ }
+ }
+
+ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ if (!supports_set_rows) {
+ return;
+ }
+
+ const uint32_t n_tokens = ubatch->n_tokens;
+ GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+ int64_t * data = (int64_t *) dst->data;
+
+ if (!v_trans) {
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+ const int64_t offs = sinfo.strm[s]*get_size();
+
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
+ data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+ }
+ }
+ } else {
+ // note: the V cache is transposed when not using flash attention
+ const int64_t kv_size = get_size();
+
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+ const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
+
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+ data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
+ }
+ }
+ }
+ }
+ }
+
+ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+ int32_t * data = (int32_t *) dst->data;
+
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ const auto & cells = v_cells[s];
+
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+ }
+ }
+ }
+
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  const uint32_t n_tokens = ubatch->n_tokens;

  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  float * data = (float *) dst->data;

- const int64_t n_kv = dst->ne[0];
+ const int64_t n_kv = dst->ne[0];
+ const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
+
+ GGML_ASSERT(n_tokens%n_stream == 0);
+
+ // n_tps == n_tokens_per_stream
+ const int64_t n_tps = n_tokens/n_stream;
+ const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
+
+ std::fill(data, data + ggml_nelements(dst), -INFINITY);

  // Use only the previous KV cells of the correct sequence for each token of the ubatch.
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
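Note on set_input_k_idxs()/set_input_v_idxs() above: the destination row indices fold the stream offset into a flat index, and for the transposed V cache each token expands into one element index per embedding row. A small sketch of that index arithmetic, using assumed example sizes:

// Simplified model of the index math (not the library code): K rows are
// indexed as stream*kv_size + cell, while the transposed V cache addresses one
// element per (embedding row j, cell) pair: offs + j*kv_size + cell.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t kv_size  = 8; // assumed cells per stream
    const int64_t n_embd_v = 4; // assumed (padded) V embedding size
    const int64_t strm     = 1; // stream selected for this sequence
    const int64_t cell     = 5; // cell index chosen by find_slot()

    // K (and non-transposed V): one row per cell, streams stacked back to back
    const int64_t k_row = strm*kv_size + cell;

    // transposed V: the "row" is a single element, laid out [kv_size][n_embd_v] per stream
    std::vector<int64_t> v_elems(n_embd_v);
    const int64_t offs = strm*kv_size*n_embd_v;
    for (int64_t j = 0; j < n_embd_v; ++j) {
        v_elems[j] = offs + j*kv_size + cell;
    }

    printf("k row index: %lld\n", (long long) k_row);
    printf("v element indices:");
    for (int64_t e : v_elems) {
        printf(" %lld", (long long) e);
    }
    printf("\n");
    return 0;
}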
@@ -833,70 +1295,57 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
833
1295
  // xxxxx-----
834
1296
  // xxxxx-----
835
1297
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
1298
+ // TODO: optimize this section
836
1299
  for (uint32_t h = 0; h < 1; ++h) {
837
- for (uint32_t i = 0; i < n_tokens; ++i) {
838
- const llama_seq_id seq_id = ubatch->seq_id[i][0];
1300
+ for (uint32_t s = 0; s < n_stream; ++s) {
1301
+ for (uint32_t ii = 0; ii < n_tps; ++ii) {
1302
+ const uint32_t i = s*n_tps + ii;
839
1303
 
840
- const llama_pos p1 = ubatch->pos[i];
1304
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
841
1305
 
842
- for (uint32_t j = 0; j < n_kv; ++j) {
843
- float f = 0.0f;
1306
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
844
1307
 
845
- bool masked = false;
1308
+ const llama_pos p1 = ubatch->pos[i];
846
1309
 
847
- if (cells.is_empty(j)) {
848
- masked = true;
849
- } else {
850
- const llama_pos p0 = cells.pos_get(j);
1310
+ const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
1311
+
1312
+ for (uint32_t j = 0; j < n_kv; ++j) {
1313
+ if (cells.is_empty(j)) {
1314
+ continue;
1315
+ }
851
1316
 
852
1317
  // mask the token if not the same sequence
853
- masked = masked || (!cells.seq_has(j, seq_id));
1318
+ if (!cells.seq_has(j, seq_id)) {
1319
+ continue;
1320
+ }
1321
+
1322
+ const llama_pos p0 = cells.pos_get(j);
854
1323
 
855
1324
  // mask future tokens
856
- masked = masked || (causal_attn && p0 > p1);
1325
+ if (causal_attn && p0 > p1) {
1326
+ continue;
1327
+ }
857
1328
 
858
1329
  // apply SWA if any
859
- masked = masked || (is_masked_swa(p0, p1));
860
-
861
- if (!masked && hparams.use_alibi) {
862
- f = -std::abs(p0 - p1);
1330
+ if (is_masked_swa(p0, p1)) {
1331
+ continue;
863
1332
  }
864
- }
865
1333
 
866
- if (masked) {
867
- f = -INFINITY;
868
- }
869
-
870
- data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
871
- }
872
- }
873
-
874
- // mask padded tokens
875
- if (data) {
876
- for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
877
- for (uint32_t j = 0; j < n_kv; ++j) {
878
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
1334
+ data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
879
1335
  }
880
1336
  }
881
1337
  }
882
1338
  }
883
1339
  }
884
1340
 
885
- void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
886
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
887
-
888
- int32_t * data = (int32_t *) dst->data;
889
-
890
- for (uint32_t i = 0; i < cells.size(); ++i) {
891
- data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
892
- }
893
- }
894
-
895
1341
  void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
896
1342
  const int64_t n_tokens = ubatch->n_tokens;
897
1343
 
1344
+ GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
1345
+ const auto & cells = v_cells[0];
1346
+
898
1347
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
899
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
1348
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
900
1349
 
901
1350
  int32_t * data = (int32_t *) dst->data;
902
1351
 
@@ -1001,7 +1450,7 @@ public:
1001
1450
 
1002
1451
  void set_input(const llama_ubatch * ubatch) override;
1003
1452
 
1004
- ggml_tensor * k_shift; // I32 [kv_size]
1453
+ ggml_tensor * k_shift; // I32 [kv_size*n_stream]
1005
1454
 
1006
1455
  const llama_kv_cache_unified * kv_self;
1007
1456
  };
@@ -1014,20 +1463,20 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
1014
1463
  }
1015
1464
  }
1016
1465
 
1017
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1018
- const llama_cparams & cparams,
1019
- ggml_context * ctx,
1020
- ggml_cgraph * gf) const {
1021
- auto res = std::make_unique<llm_graph_result>();
1466
+ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
1467
+ auto * ctx = res->get_ctx();
1468
+ auto * gf = res->get_gf();
1022
1469
 
1023
1470
  const auto & n_embd_head_k = hparams.n_embd_head_k;
1024
1471
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
1025
1472
 
1026
1473
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
1027
1474
 
1028
- inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
1475
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
1029
1476
  ggml_set_input(inp->k_shift);
1030
1477
 
1478
+ const auto & cparams = lctx->get_cparams();
1479
+
1031
1480
  for (const auto & layer : layers) {
1032
1481
  const uint32_t il = layer.il;
1033
1482
 
@@ -1041,7 +1490,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1041
1490
 
1042
1491
  ggml_tensor * k =
1043
1492
  ggml_view_3d(ctx, layer.k,
1044
- n_embd_head_k, n_head_kv, cells.size(),
1493
+ n_embd_head_k, n_head_kv, get_size()*n_stream,
1045
1494
  ggml_row_size(layer.k->type, n_embd_head_k),
1046
1495
  ggml_row_size(layer.k->type, n_embd_k_gqa),
1047
1496
  0);
@@ -1053,18 +1502,24 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1053
1502
 
1054
1503
  res->add_input(std::move(inp));
1055
1504
 
1056
- return res;
1505
+ return gf;
1057
1506
  }
1058
1507
 
1059
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1060
- const llama_cparams & cparams,
1061
- ggml_context * ctx,
1062
- ggml_cgraph * gf,
1063
- const defrag_info & dinfo) const {
1064
- auto res = std::make_unique<llm_graph_result>();
1508
+ ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
1509
+ llm_graph_result * res,
1510
+ llama_context * lctx,
1511
+ const defrag_info & dinfo) const {
1512
+ auto * ctx = res->get_ctx();
1513
+ auto * gf = res->get_gf();
1514
+
1515
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
1516
+
1517
+ const auto & cells = v_cells[0];
1065
1518
 
1066
1519
  const auto & ids = dinfo.ids;
1067
1520
 
1521
+ const auto & cparams = lctx->get_cparams();
1522
+
1068
1523
  #if 0
1069
1524
  // CPU defrag
1070
1525
  //
@@ -1201,10 +1656,14 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1201
1656
  //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
1202
1657
  #endif
1203
1658
 
1204
- return res;
1659
+ return gf;
1205
1660
  }
1206
1661
 
1207
1662
  llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
1663
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
1664
+
1665
+ const auto & cells = v_cells[0];
1666
+
1208
1667
  const uint32_t n_layer = layers.size();
1209
1668
 
1210
1669
  const uint32_t n_kv = cells.used_max_p1();
@@ -1350,64 +1809,94 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
1350
1809
  }
1351
1810
 
1352
1811
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1353
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1354
- uint32_t cell_count = 0;
1812
+ io.write(&n_stream, sizeof(n_stream));
1355
1813
 
1356
- // Count the number of cells with the specified seq_id
1357
- // Find all the ranges of cells with this seq id (or all, when -1)
1358
- uint32_t cell_range_begin = cells.size();
1814
+ for (uint32_t s = 0; s < n_stream; ++s) {
1815
+ cell_ranges_t cr { s, {} };
1359
1816
 
1360
- for (uint32_t i = 0; i < cells.size(); ++i) {
1361
- if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1362
- ++cell_count;
1363
- if (cell_range_begin == cells.size()) {
1364
- cell_range_begin = i;
1365
- }
1366
- } else {
1367
- if (cell_range_begin != cells.size()) {
1368
- cell_ranges.emplace_back(cell_range_begin, i);
1369
- cell_range_begin = cells.size();
1817
+ uint32_t cell_count = 0;
1818
+
1819
+ const auto & cells = v_cells[s];
1820
+
1821
+ // Count the number of cells with the specified seq_id
1822
+ // Find all the ranges of cells with this seq id (or all, when -1)
1823
+ uint32_t cell_range_begin = cells.size();
1824
+
1825
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1826
+ if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1827
+ ++cell_count;
1828
+ if (cell_range_begin == cells.size()) {
1829
+ cell_range_begin = i;
1830
+ }
1831
+ } else {
1832
+ if (cell_range_begin != cells.size()) {
1833
+ cr.data.emplace_back(cell_range_begin, i);
1834
+ cell_range_begin = cells.size();
1835
+ }
1370
1836
  }
1371
1837
  }
1372
- }
1373
1838
 
1374
- if (cell_range_begin != cells.size()) {
1375
- cell_ranges.emplace_back(cell_range_begin, cells.size());
1376
- }
1839
+ if (cell_range_begin != cells.size()) {
1840
+ cr.data.emplace_back(cell_range_begin, cells.size());
1841
+ }
1377
1842
 
1378
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1379
- uint32_t cell_count_check = 0;
1380
- for (const auto & range : cell_ranges) {
1381
- cell_count_check += range.second - range.first;
1382
- }
1383
- GGML_ASSERT(cell_count == cell_count_check);
1843
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1844
+ uint32_t cell_count_check = 0;
1845
+ for (const auto & range : cr.data) {
1846
+ cell_count_check += range.second - range.first;
1847
+ }
1848
+ GGML_ASSERT(cell_count == cell_count_check);
1849
+
1850
+ io.write(&cell_count, sizeof(cell_count));
1384
1851
 
1385
- io.write(&cell_count, sizeof(cell_count));
1852
+ // skip empty streams
1853
+ if (cell_count == 0) {
1854
+ continue;
1855
+ }
1386
1856
 
1387
- state_write_meta(io, cell_ranges, seq_id);
1388
- state_write_data(io, cell_ranges);
1857
+ state_write_meta(io, cr, seq_id);
1858
+ state_write_data(io, cr);
1859
+ }
1389
1860
  }

 void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
- uint32_t cell_count;
- io.read_to(&cell_count, sizeof(cell_count));
+ GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
+
+ uint32_t n_stream_cur;
+ io.read_to(&n_stream_cur, sizeof(n_stream_cur));
+ if (n_stream_cur != n_stream) {
+ throw std::runtime_error("n_stream mismatch");
+ }
+
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ uint32_t cell_count;
+ io.read_to(&cell_count, sizeof(cell_count));
+
+ if (cell_count == 0) {
+ continue;
+ }
+
+ const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];

- bool res = true;
- res = res && state_read_meta(io, cell_count, seq_id);
- res = res && state_read_data(io, cell_count);
+ bool res = true;
+ res = res && state_read_meta(io, strm, cell_count, seq_id);
+ res = res && state_read_data(io, strm, cell_count);

- if (!res) {
- if (seq_id == -1) {
- clear(true);
- } else {
- seq_rm(seq_id, -1, -1);
+ if (!res) {
+ if (seq_id == -1) {
+ clear(true);
+ } else {
+ seq_rm(seq_id, -1, -1);
+ }
+ throw std::runtime_error("failed to restore kv cache");
 }
- throw std::runtime_error("failed to restore kv cache");
 }
 }

- void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
- for (const auto & range : cell_ranges) {
+ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
+ const auto & cells = v_cells[cr.strm];
+
+ for (const auto & range : cr.data) {
 for (uint32_t i = range.first; i < range.second; ++i) {
 std::vector<llama_seq_id> seq_ids;

@@ -1432,7 +1921,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
 }
 }

- void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+ const auto & cells = v_cells[cr.strm];
+
 const uint32_t v_trans = this->v_trans ? 1 : 0;
 const uint32_t n_layer = layers.size();

@@ -1448,19 +1939,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::

 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

+ auto * k = layer.k_stream[cr.strm];
+
 // Write key type
- const int32_t k_type_i = (int32_t)layer.k->type;
+ const int32_t k_type_i = (int32_t) k->type;
 io.write(&k_type_i, sizeof(k_type_i));

 // Write row size of key
- const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
+ const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
 io.write(&k_size_row, sizeof(k_size_row));

 // Read each range of cells of k_size length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
+ for (const auto & range : cr.data) {
 const size_t range_size = range.second - range.first;
 const size_t buf_size = range_size * k_size_row;
- io.write_tensor(layer.k, range.first * k_size_row, buf_size);
+ io.write_tensor(k, range.first * k_size_row, buf_size);
 }
 }

@@ -1470,19 +1963,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::

 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+ auto * v = layer.v_stream[cr.strm];
+
 // Write value type
- const int32_t v_type_i = (int32_t)layer.v->type;
+ const int32_t v_type_i = (int32_t) v->type;
 io.write(&v_type_i, sizeof(v_type_i));

 // Write row size of value
- const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
+ const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
 io.write(&v_size_row, sizeof(v_size_row));

 // Read each range of cells of v_size length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
+ for (const auto & range : cr.data) {
 const size_t range_size = range.second - range.first;
 const size_t buf_size = range_size * v_size_row;
- io.write_tensor(layer.v, range.first * v_size_row, buf_size);
+ io.write_tensor(v, range.first * v_size_row, buf_size);
 }
 }
 } else {
@@ -1494,12 +1989,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::

 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+ auto * v = layer.v_stream[cr.strm];
+
 // Write value type
- const int32_t v_type_i = (int32_t)layer.v->type;
+ const int32_t v_type_i = (int32_t) v->type;
 io.write(&v_type_i, sizeof(v_type_i));

 // Write element size
- const uint32_t v_size_el = ggml_type_size(layer.v->type);
+ const uint32_t v_size_el = ggml_type_size(v->type);
 io.write(&v_size_el, sizeof(v_size_el));

 // Write GQA embedding size
@@ -1508,27 +2005,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 // For each row, we get the element values of each cell
 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
 // Read each range of cells of v_size_el length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
+ for (const auto & range : cr.data) {
 const size_t range_size = range.second - range.first;
 const size_t src_offset = (range.first + j * kv_size) * v_size_el;
 const size_t buf_size = range_size * v_size_el;
- io.write_tensor(layer.v, src_offset, buf_size);
+ io.write_tensor(v, src_offset, buf_size);
 }
 }
 }
 }
 }

- bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+ auto & cells = v_cells[strm];
+ auto & head = v_heads[strm];
+
 if (dest_seq_id != -1) {
 // single sequence
-
 seq_rm(dest_seq_id, -1, -1);

 llama_batch_allocr balloc(hparams.n_pos_per_embd());

 llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

+ ubatch.seq_id_unq[0] = dest_seq_id;
+
 for (uint32_t i = 0; i < cell_count; ++i) {
 llama_pos pos;
 uint32_t n_seq_id;
@@ -1552,17 +2053,21 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 ubatch.seq_id[i] = &dest_seq_id;
 }

- const auto head_cur = find_slot(ubatch);
- if (head_cur < 0) {
+ const auto sinfo = find_slot(ubatch, true);
+ if (sinfo.empty()) {
 LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
 return false;
 }

- apply_ubatch(head_cur, ubatch);
+ apply_ubatch(sinfo, ubatch);
+
+ const auto head_cur = sinfo.head();

 // keep the head at the old position because we will read the KV data into it in state_read_data()
 head = head_cur;

+ LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+
 // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
 // Assume that this is one contiguous block of cells
 GGML_ASSERT(head_cur + cell_count <= cells.size());
@@ -1608,7 +2113,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 return true;
 }

- bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+ auto & cells = v_cells[strm];
+ auto & head = v_heads[strm];
+
 uint32_t v_trans;
 uint32_t n_layer;

@@ -1636,10 +2144,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

+ auto * k = layer.k_stream[strm];
+
 // Read type of key
 int32_t k_type_i_ref;
 io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
- const int32_t k_type_i = (int32_t) layer.k->type;
+ const int32_t k_type_i = (int32_t) k->type;
 if (k_type_i != k_type_i_ref) {
 LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
 return false;
@@ -1648,7 +2158,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 // Read row size of key
 uint64_t k_size_row_ref;
 io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
- const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
+ const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
 if (k_size_row != k_size_row_ref) {
 LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
 return false;
@@ -1656,7 +2166,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

 if (cell_count) {
 // Read and set the keys for the whole cell range
- ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+ ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
 }
 }

@@ -1666,10 +2176,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+ auto * v = layer.v_stream[strm];
+
 // Read type of value
 int32_t v_type_i_ref;
 io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
- const int32_t v_type_i = (int32_t)layer.v->type;
+ const int32_t v_type_i = (int32_t) v->type;
 if (v_type_i != v_type_i_ref) {
 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
 return false;
@@ -1678,7 +2190,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 // Read row size of value
 uint64_t v_size_row_ref;
 io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
- const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
+ const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
 if (v_size_row != v_size_row_ref) {
 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
 return false;
@@ -1686,7 +2198,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

 if (cell_count) {
 // Read and set the values for the whole cell range
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
 }
 }
 } else {
@@ -1696,10 +2208,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+ auto * v = layer.v_stream[strm];
+
 // Read type of value
 int32_t v_type_i_ref;
 io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
- const int32_t v_type_i = (int32_t)layer.v->type;
+ const int32_t v_type_i = (int32_t) v->type;
 if (v_type_i != v_type_i_ref) {
 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
 return false;
@@ -1708,7 +2222,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 // Read element size of value
 uint32_t v_size_el_ref;
 io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
- const size_t v_size_el = ggml_type_size(layer.v->type);
+ const size_t v_size_el = ggml_type_size(v->type);
 if (v_size_el != v_size_el_ref) {
 LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
 return false;
@@ -1726,7 +2240,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 // For each row in the transposed matrix, read the values for the whole cell range
 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
 const size_t dst_offset = (head + j * cells.size()) * v_size_el;
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
 }
 }
 }
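The stream-aware restore driven by state_read() above starts from a small header: a u32 stream count that must match the cache's n_stream, followed by one u32 cell count per stream, with empty streams skipped (the stream count is presumably written by the matching state_write() path earlier in this file, outside the hunks shown here). The following standalone sketch parses just that header and omits the per-cell metadata and tensor payload; it is an illustration of the layout implied by this diff, not code from the package.

// Standalone sketch of the per-stream header that state_read() expects:
// u32 n_stream, then one u32 cell count per stream; streams with a zero
// count carry no further data. Everything past the header is omitted here.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <stdexcept>
#include <vector>

static uint32_t read_u32(const std::vector<uint8_t> & buf, size_t & off) {
    uint32_t v;
    std::memcpy(&v, buf.data() + off, sizeof(v));
    off += sizeof(v);
    return v;
}

int main() {
    // buffer a hypothetical 2-stream cache might have written: n_stream = 2, cell counts {3, 0}
    const uint32_t vals[3] = {2, 3, 0};
    std::vector<uint8_t> buf(sizeof(vals));
    std::memcpy(buf.data(), vals, sizeof(vals));

    size_t off = 0;
    const uint32_t n_stream_expected = 2;

    const uint32_t n_stream_cur = read_u32(buf, off);
    if (n_stream_cur != n_stream_expected) {
        throw std::runtime_error("n_stream mismatch");
    }
    for (uint32_t s = 0; s < n_stream_cur; ++s) {
        const uint32_t cell_count = read_u32(buf, off);
        if (cell_count == 0) {
            continue; // empty stream - nothing else was serialized for it
        }
        std::printf("stream %u: %u cells to restore\n", (unsigned) s, (unsigned) cell_count);
    }
    return 0;
}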
@@ -1744,23 +2258,35 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
 n_kv = kv->get_size();
- head = 0;
+
+ const uint32_t n_stream = kv->get_n_stream();
+
+ // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+ sinfos.resize(1);
+ sinfos[0].s0 = 0;
+ sinfos[0].s1 = n_stream - 1;
+ sinfos[0].idxs.resize(n_stream);
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ sinfos[0].strm.push_back(s);
+ sinfos[0].idxs[s].resize(1, 0);
+ }
 }

 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified * kv,
 llama_context * lctx,
 bool do_shift,
- defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
- if (!do_shift && this->dinfo.empty()) {
+ defrag_info dinfo,
+ stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
+ if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
 status = LLAMA_MEMORY_STATUS_NO_UPDATE;
 }
 }

 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified * kv,
- llama_kv_cache_unified::ubatch_heads heads,
- std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+ llama_kv_cache_unified::slot_info_vec_t sinfos,
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }

 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
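The first constructor above builds a placeholder slot info so that a worst-case graph can be reserved without real batch data: a single slot spanning streams 0..n_stream-1 with one dummy cell index per stream. The sketch below reproduces just that shape with a stand-in struct; the field names mirror the diff, but the struct is not the real llama_kv_cache_unified::slot_info type.

// Standalone sketch of the "dummy" slot info shape used for graph reservation.
#include <cstdint>
#include <vector>

struct slot_info_sketch {
    uint32_t s0;                            // first stream covered by the slot
    uint32_t s1;                            // last stream covered by the slot
    std::vector<uint32_t>              strm; // stream id per entry
    std::vector<std::vector<uint32_t>> idxs; // cell indices per stream
};

int main() {
    const uint32_t n_stream = 4;

    std::vector<slot_info_sketch> sinfos(1);
    sinfos[0].s0 = 0;
    sinfos[0].s1 = n_stream - 1;
    sinfos[0].idxs.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        sinfos[0].strm.push_back(s);
        sinfos[0].idxs[s].resize(1, 0); // one placeholder cell index per stream
    }
    return 0;
}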
@@ -1768,7 +2294,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 bool llama_kv_cache_unified_context::next() {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

- if (++i_next >= ubatches.size()) {
+ if (++i_cur >= ubatches.size()) {
 return false;
 }

@@ -1776,19 +2302,18 @@ bool llama_kv_cache_unified_context::next() {
 }

 bool llama_kv_cache_unified_context::apply() {
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+ assert(!llama_memory_status_is_fail(status));

 // no ubatches -> this is a KV cache update
 if (ubatches.empty()) {
- kv->update(lctx, do_shift, dinfo);
+ kv->update(lctx, do_shift, dinfo, sc_info);

 return true;
 }

- kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+ kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);

 n_kv = kv->get_n_kv();
- head = heads[i_next];

 return true;
 }
@@ -1800,33 +2325,53 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

- return ubatches[i_next];
+ return ubatches[i_cur];
 }

 uint32_t llama_kv_cache_unified_context::get_n_kv() const {
 return n_kv;
 }

+ bool llama_kv_cache_unified_context::get_supports_set_rows() const {
+ return kv->get_supports_set_rows();
+ }
+
 ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
- return kv->get_k(ctx, il, n_kv);
+ return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }

 ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
- return kv->get_v(ctx, il, n_kv);
+ return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
 }

- ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
- return kv->cpy_k(ctx, k_cur, il, head);
+ ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+ return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
 }

- ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
- return kv->cpy_v(ctx, v_cur, il, head);
+ ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+ return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+ }
+
+ ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ return kv->build_input_k_idxs(ctx, ubatch);
+ }
+
+ ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ return kv->build_input_v_idxs(ctx, ubatch);
 }

 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
 kv->set_input_k_shift(dst);
 }

+ void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+ kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+ }
+
+ void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+ kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+ }
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
 kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
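In the context methods above, next() advances the i_cur cursor and apply() acts on the ubatch (and slot info) at that cursor. The standalone sketch below shows only that cursor pattern with a stand-in struct and integer payloads; how llama_context actually drives these calls is not shown in this diff, so the loop is illustrative.

// Standalone sketch of the apply()/next() cursor pattern over ubatches.
#include <cstdio>
#include <vector>

struct ubatch_cursor {
    std::vector<int> ubatches; // stand-in for std::vector<llama_ubatch>
    size_t i_cur = 0;

    bool apply() {
        // act on the ubatch at the current cursor position
        std::printf("applying ubatch %zu (payload %d)\n", i_cur, ubatches[i_cur]);
        return true;
    }
    bool next() {
        // advance the cursor; false once every ubatch has been consumed
        return ++i_cur < ubatches.size();
    }
};

int main() {
    ubatch_cursor ctx{{10, 20, 30}};
    do {
        ctx.apply();
    } while (ctx.next());
    return 0;
}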