@fugood/llama.node 1.0.3 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  3. package/src/llama.cpp/common/arg.cpp +37 -0
  4. package/src/llama.cpp/common/common.cpp +22 -6
  5. package/src/llama.cpp/common/common.h +14 -1
  6. package/src/llama.cpp/ggml/CMakeLists.txt +3 -0
  7. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  8. package/src/llama.cpp/ggml/include/ggml.h +13 -0
  9. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  10. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  11. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +23 -8
  12. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  13. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +39 -0
  14. package/src/llama.cpp/include/llama.h +13 -48
  15. package/src/llama.cpp/src/llama-arch.cpp +222 -15
  16. package/src/llama.cpp/src/llama-arch.h +16 -1
  17. package/src/llama.cpp/src/llama-batch.cpp +76 -70
  18. package/src/llama.cpp/src/llama-batch.h +24 -18
  19. package/src/llama.cpp/src/llama-chat.cpp +44 -1
  20. package/src/llama.cpp/src/llama-chat.h +2 -0
  21. package/src/llama.cpp/src/llama-context.cpp +134 -95
  22. package/src/llama.cpp/src/llama-context.h +13 -16
  23. package/src/llama.cpp/src/llama-cparams.h +3 -2
  24. package/src/llama.cpp/src/llama-graph.cpp +239 -154
  25. package/src/llama.cpp/src/llama-graph.h +162 -126
  26. package/src/llama.cpp/src/llama-hparams.cpp +45 -0
  27. package/src/llama.cpp/src/llama-hparams.h +11 -1
  28. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  29. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  30. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  31. package/src/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  32. package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  33. package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -9
  34. package/src/llama.cpp/src/llama-model.cpp +2309 -665
  35. package/src/llama.cpp/src/llama-model.h +18 -4
  36. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  37. package/src/llama.cpp/src/llama-vocab.cpp +368 -9
  38. package/src/llama.cpp/src/llama-vocab.h +43 -0
  39. package/src/llama.cpp/src/unicode.cpp +207 -0
  40. package/src/llama.cpp/src/unicode.h +2 -0
@@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
23
23
  ggml_type type_v,
24
24
  bool v_trans,
25
25
  bool offload,
26
+ bool unified,
26
27
  uint32_t kv_size,
27
28
  uint32_t n_seq_max,
28
29
  uint32_t n_pad,
29
30
  uint32_t n_swa,
30
31
  llama_swa_type swa_type) :
31
32
  model(model), hparams(model.hparams), v_trans(v_trans),
32
- n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
33
+ n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
33
34
 
34
35
  GGML_ASSERT(kv_size % n_pad == 0);
35
36
 
@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
45
46
  auto it = ctx_map.find(buft);
46
47
  if (it == ctx_map.end()) {
47
48
  ggml_init_params params = {
48
- /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
49
+ /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
49
50
  /*.mem_buffer =*/ NULL,
50
51
  /*.no_alloc =*/ true,
51
52
  };
@@ -64,9 +65,33 @@ llama_kv_cache_unified::llama_kv_cache_unified(
64
65
  return it->second;
65
66
  };
66
67
 
67
- head = 0;
68
+ GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
68
69
 
69
- cells.resize(kv_size);
70
+ v_heads.resize(n_stream);
71
+ for (uint32_t s = 0; s < n_stream; ++s) {
72
+ v_heads[s] = 0;
73
+ }
74
+
75
+ v_cells.resize(n_stream);
76
+ for (uint32_t s = 0; s < n_stream; ++s) {
77
+ v_cells[s].resize(kv_size);
78
+ }
79
+
80
+ // by default, all sequence ids are mapped to the 0th stream
81
+ seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
82
+
83
+ if (n_stream > 1) {
84
+ seq_to_stream.resize(n_stream, 0);
85
+ for (uint32_t s = 0; s < n_stream; ++s) {
86
+ seq_to_stream[s] = s;
87
+ }
88
+ }
89
+
90
+ // [TAG_V_CACHE_VARIABLE]
91
+ if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
92
+ LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
93
+ __func__, hparams.n_embd_v_gqa_max());
94
+ }
70
95
 
71
96
  for (uint32_t il = 0; il < n_layer_cache; il++) {
72
97
  if (filter && !filter(il)) {
@@ -74,8 +99,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
74
99
  continue;
75
100
  }
76
101
 
77
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
78
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
102
+ // [TAG_V_CACHE_VARIABLE]
103
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
104
+ const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
79
105
 
80
106
  const char * dev_name = "CPU";
81
107
 
@@ -98,14 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
98
124
  ggml_tensor * k;
99
125
  ggml_tensor * v;
100
126
 
101
- k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
102
- v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
127
+ k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
128
+ v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
103
129
 
104
130
  ggml_format_name(k, "cache_k_l%d", il);
105
131
  ggml_format_name(v, "cache_v_l%d", il);
106
132
 
133
+ std::vector<ggml_tensor *> k_stream;
134
+ std::vector<ggml_tensor *> v_stream;
135
+
136
+ for (uint32_t s = 0; s < n_stream; ++s) {
137
+ k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
138
+ v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
139
+ }
140
+
107
141
  map_layer_ids[il] = layers.size();
108
- layers.push_back({ il, k, v });
142
+
143
+ layers.push_back({ il, k, v, k_stream, v_stream, });
109
144
  }
110
145
 
111
146
  // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
@@ -148,8 +183,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
148
183
  const size_t memory_size_k = size_k_bytes();
149
184
  const size_t memory_size_v = size_v_bytes();
150
185
 
151
- LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
152
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
186
+ LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
187
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
153
188
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
154
189
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
155
190
  }
@@ -158,7 +193,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
158
193
  debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
159
194
 
160
195
  const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
161
- supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
196
+ supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0;
197
+
198
+ if (!supports_set_rows) {
199
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14363
200
+ GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
201
+ }
162
202
 
163
203
  if (!supports_set_rows) {
164
204
  LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
@@ -166,9 +206,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
166
206
  }
167
207
 
168
208
  void llama_kv_cache_unified::clear(bool data) {
169
- cells.reset();
170
-
171
- head = 0;
209
+ for (uint32_t s = 0; s < n_stream; ++s) {
210
+ v_cells[s].reset();
211
+ v_heads[s] = 0;
212
+ }
172
213
 
173
214
  if (data) {
174
215
  for (auto & buf : bufs) {
@@ -178,6 +219,11 @@ void llama_kv_cache_unified::clear(bool data) {
178
219
  }
179
220
 
180
221
  bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
222
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
223
+
224
+ auto & cells = v_cells[seq_to_stream[seq_id]];
225
+ auto & head = v_heads[seq_to_stream[seq_id]];
226
+
181
227
  uint32_t new_head = cells.size();
182
228
 
183
229
  if (p0 < 0) {
@@ -224,30 +270,94 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
224
270
  }
225
271
 
226
272
  void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
227
- if (seq_id_src == seq_id_dst) {
273
+ GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
274
+ GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
275
+
276
+ const auto s0 = seq_to_stream[seq_id_src];
277
+ const auto s1 = seq_to_stream[seq_id_dst];
278
+
279
+ if (s0 == s1) {
280
+ // since both sequences are in the same stream, no data copy is necessary
281
+ // we just have to update the cells meta data
282
+
283
+ auto & cells = v_cells[s0];
284
+
285
+ if (seq_id_src == seq_id_dst) {
286
+ return;
287
+ }
288
+
289
+ if (p0 < 0) {
290
+ p0 = 0;
291
+ }
292
+
293
+ if (p1 < 0) {
294
+ p1 = std::numeric_limits<llama_pos>::max();
295
+ }
296
+
297
+ for (uint32_t i = 0; i < cells.size(); ++i) {
298
+ if (!cells.pos_in(i, p0, p1)) {
299
+ continue;
300
+ }
301
+
302
+ if (cells.seq_has(i, seq_id_src)) {
303
+ cells.seq_add(i, seq_id_dst);
304
+ }
305
+ }
306
+
228
307
  return;
229
308
  }
230
309
 
231
- if (p0 < 0) {
232
- p0 = 0;
310
+ // cross-stream sequence copies require to copy the actual buffer data
311
+
312
+ bool is_full = true;
313
+
314
+ if (p0 > 0 && p0 + 1 < (int) get_size()) {
315
+ is_full = false;
233
316
  }
234
317
 
235
- if (p1 < 0) {
236
- p1 = std::numeric_limits<llama_pos>::max();
318
+ if (p1 > 0 && p1 + 1 < (int) get_size()) {
319
+ is_full = false;
237
320
  }
238
321
 
239
- for (uint32_t i = 0; i < cells.size(); ++i) {
240
- if (!cells.pos_in(i, p0, p1)) {
241
- continue;
242
- }
322
+ GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
323
+
324
+ // enqueue the copy operation - the buffer copy will be performed during the next update
325
+ sc_info.ssrc.push_back(s0);
326
+ sc_info.sdst.push_back(s1);
243
327
 
244
- if (cells.seq_has(i, seq_id_src)) {
245
- cells.seq_add(i, seq_id_dst);
328
+ v_cells[s1].reset();
329
+ for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
330
+ if (v_cells[s0].seq_has(i, seq_id_src)) {
331
+ llama_pos pos = v_cells[s0].pos_get(i);
332
+ llama_pos shift = v_cells[s0].get_shift(i);
333
+
334
+ if (shift != 0) {
335
+ pos -= shift;
336
+ assert(pos >= 0);
337
+ }
338
+
339
+ v_cells[s1].pos_set(i, pos);
340
+ v_cells[s1].seq_add(i, seq_id_dst);
341
+
342
+ if (shift != 0) {
343
+ v_cells[s1].pos_add(i, shift);
344
+ }
246
345
  }
247
346
  }
347
+
348
+ v_heads[s1] = v_heads[s0];
349
+
350
+ //for (uint32_t s = 0; s < n_stream; ++s) {
351
+ // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
352
+ //}
248
353
  }
249
354
 
250
355
  void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
356
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
357
+
358
+ auto & cells = v_cells[seq_to_stream[seq_id]];
359
+ auto & head = v_heads[seq_to_stream[seq_id]];
360
+
251
361
  uint32_t new_head = cells.size();
252
362
 
253
363
  for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -265,6 +375,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
265
375
  }
266
376
 
267
377
  void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
378
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
379
+
380
+ auto & cells = v_cells[seq_to_stream[seq_id]];
381
+ auto & head = v_heads[seq_to_stream[seq_id]];
382
+
268
383
  if (shift == 0) {
269
384
  return;
270
385
  }
@@ -304,6 +419,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
304
419
  }
305
420
 
306
421
  void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
422
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
423
+
424
+ auto & cells = v_cells[seq_to_stream[seq_id]];
425
+
307
426
  if (d == 1) {
308
427
  return;
309
428
  }
@@ -333,10 +452,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
333
452
  }
334
453
 
335
454
  llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
455
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
456
+
457
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
458
+
336
459
  return cells.seq_pos_min(seq_id);
337
460
  }
338
461
 
339
462
  llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
463
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
464
+
465
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
466
+
340
467
  return cells.seq_pos_max(seq_id);
341
468
  }
342
469
 
@@ -351,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
351
478
 
352
479
  std::vector<llama_ubatch> ubatches;
353
480
  while (true) {
354
- auto ubatch = balloc.split_simple(n_ubatch);
481
+ auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
355
482
 
356
483
  if (ubatch.n_tokens == 0) {
357
484
  break;
@@ -387,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
387
514
  defrag_info dinfo;
388
515
 
389
516
  // see if we need to defrag
390
- {
517
+ if (n_stream == 1) {
518
+ // note : for now do not consider defrag for n_stream > 1
519
+ const auto & cells = v_cells[seq_to_stream[0]];
520
+
391
521
  bool do_defrag = optimize;
392
522
 
393
523
  const auto thold = lctx->get_cparams().defrag_thold;
@@ -411,22 +541,22 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
411
541
  }
412
542
  }
413
543
 
414
- return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
544
+ return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
415
545
  }
416
546
 
417
547
  llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
418
548
  llama_kv_cache_unified::slot_info_vec_t res;
419
549
 
420
- struct state {
421
- uint32_t head_old; // old position of the head, before placing the ubatch
422
-
550
+ struct state_t {
423
551
  slot_info sinfo; // slot info for the ubatch
424
552
 
425
- llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
553
+ std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
554
+
555
+ std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
426
556
  };
427
557
 
428
558
  // remember the old state of the cells so we can restore it in the end
429
- std::vector<state> states;
559
+ std::vector<state_t> states;
430
560
 
431
561
  bool success = true;
432
562
 
@@ -445,16 +575,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
445
575
  res.push_back(sinfo_new);
446
576
 
447
577
  // store the old state of the cells in the recovery stack
448
- states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
578
+ {
579
+ state_t state = { sinfo_new, v_heads, {} };
580
+
581
+ for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
582
+ auto & cells = v_cells[sinfo_new.strm[s]];
583
+
584
+ state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
585
+ }
586
+
587
+ states.push_back(std::move(state));
588
+ }
449
589
 
450
590
  // now emplace the ubatch
451
591
  apply_ubatch(sinfo_new, ubatch);
452
592
  }
453
593
 
594
+ GGML_ASSERT(!states.empty() || !success);
595
+
454
596
  // iterate backwards and restore the cells to their original state
455
597
  for (auto it = states.rbegin(); it != states.rend(); ++it) {
456
- cells.set(it->sinfo.idxs, it->cells);
457
- head = it->head_old;
598
+ const auto & sinfo = it->sinfo;
599
+
600
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
601
+ auto & cells = v_cells[sinfo.strm[s]];
602
+ auto & head = v_heads[sinfo.strm[s]];
603
+
604
+ cells.set(sinfo.idxs[s], it->v_cells[s]);
605
+ head = it->v_heads_old[s];
606
+ }
458
607
  }
459
608
 
460
609
  if (!success) {
@@ -464,11 +613,38 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
464
613
  return res;
465
614
  }
466
615
 
467
- bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
616
+ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
468
617
  bool updated = false;
469
618
 
470
619
  auto * sched = lctx->get_sched();
471
620
 
621
+ if (!sc_info.empty()) {
622
+ assert(n_stream > 1 && "stream copy should never happen with a single stream");
623
+
624
+ llama_synchronize(lctx);
625
+
626
+ const size_t n_copy = sc_info.ssrc.size();
627
+
628
+ for (size_t i = 0; i < n_copy; ++i) {
629
+ const auto ssrc = sc_info.ssrc[i];
630
+ const auto sdst = sc_info.sdst[i];
631
+
632
+ assert(ssrc < n_stream);
633
+ assert(sdst < n_stream);
634
+
635
+ LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
636
+
637
+ assert(ssrc != sdst);
638
+
639
+ for (uint32_t il = 0; il < layers.size(); ++il) {
640
+ const auto & layer = layers[il];
641
+
642
+ ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
643
+ ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
644
+ }
645
+ }
646
+ }
647
+
472
648
  if (do_shift) {
473
649
  if (!get_can_shift()) {
474
650
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -480,14 +656,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
480
656
  if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
481
657
  ggml_backend_sched_reset(sched);
482
658
 
483
- auto * gf = lctx->graph_init();
659
+ auto * res = lctx->get_gf_res_reserve();
484
660
 
485
- auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
486
- if (!res) {
487
- LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
488
- return updated;
489
- }
661
+ res->reset();
490
662
 
663
+ auto * gf = build_graph_shift(res, lctx);
491
664
  if (!ggml_backend_sched_alloc_graph(sched, gf)) {
492
665
  LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
493
666
  return updated;
@@ -503,12 +676,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
503
676
  updated = true;
504
677
  }
505
678
 
506
- cells.reset_shift();
679
+ for (uint32_t s = 0; s < n_stream; ++s) {
680
+ auto & cells = v_cells[s];
681
+
682
+ cells.reset_shift();
683
+ }
507
684
  }
508
685
 
509
686
  if (!dinfo.empty()) {
510
687
  LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
511
688
 
689
+ // note: for now do not consider defrag for n_stream > 1
690
+ auto & cells = v_cells[seq_to_stream[0]];
691
+ auto & head = v_heads[seq_to_stream[0]];
692
+
512
693
  // apply moves:
513
694
  {
514
695
  const auto n_kv = dinfo.ids.size();
@@ -529,14 +710,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
529
710
 
530
711
  ggml_backend_sched_reset(sched);
531
712
 
532
- auto * gf = lctx->graph_init();
713
+ auto * res = lctx->get_gf_res_reserve();
533
714
 
534
- auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
535
- if (!res) {
536
- LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
537
- return updated;
538
- }
715
+ res->reset();
539
716
 
717
+ auto * gf = build_graph_defrag(res, lctx, dinfo);
540
718
  if (!ggml_backend_sched_alloc_graph(sched, gf)) {
541
719
  LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
542
720
  return updated;
@@ -556,23 +734,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
556
734
  }
557
735
 
558
736
  llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
559
- const uint32_t n_tokens = ubatch.n_tokens;
737
+ if (debug > 0) {
738
+ const auto & cells = v_cells[seq_to_stream[1]];
560
739
 
561
- uint32_t head_cur = this->head;
740
+ const uint32_t head_cur = v_heads[1];
562
741
 
563
- // if we have enough unused cells before the current head ->
564
- // better to start searching from the beginning of the cache, hoping to fill it
565
- if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
566
- head_cur = 0;
567
- }
568
-
569
- if (n_tokens > cells.size()) {
570
- LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
571
- return { };
572
- }
573
-
574
- if (debug > 0) {
575
- LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
742
+ LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
743
+ __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
576
744
 
577
745
  if ((debug == 2 && n_swa > 0) || debug > 2) {
578
746
  std::string ss;
@@ -629,86 +797,133 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
629
797
  }
630
798
  }
631
799
 
632
- uint32_t n_tested = 0;
800
+ uint32_t n_tokens = ubatch.n_tokens;
801
+ uint32_t n_seqs = 1;
802
+
803
+ if (n_stream > 1) {
804
+ GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
633
805
 
634
- // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
635
- // for non-continuous slots, we test the tokens one by one
636
- const uint32_t n_test = cont ? n_tokens : 1;
806
+ n_seqs = ubatch.n_seqs_unq;
807
+ n_tokens = n_tokens / n_seqs;
808
+ }
637
809
 
638
- slot_info res;
810
+ slot_info res = {
811
+ /*.s0 =*/ LLAMA_MAX_SEQ,
812
+ /*.s1 =*/ 0,
813
+ /*.strm =*/ { },
814
+ /*.idxs =*/ { },
815
+ };
639
816
 
640
- auto & idxs = res.idxs;
817
+ res.resize(n_seqs);
641
818
 
642
- idxs.reserve(n_tokens);
819
+ for (uint32_t s = 0; s < n_seqs; ++s) {
820
+ const auto seq_id = ubatch.seq_id_unq[s];
643
821
 
644
- while (true) {
645
- if (head_cur + n_test > cells.size()) {
646
- n_tested += cells.size() - head_cur;
822
+ if (n_stream > 1) {
823
+ GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
824
+ GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
825
+ }
826
+
827
+ res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
828
+ res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
829
+
830
+ res.strm[s] = seq_to_stream[seq_id];
831
+ res.idxs[s].reserve(n_tokens);
832
+
833
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
834
+
835
+ uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
836
+
837
+ // if we have enough unused cells before the current head ->
838
+ // better to start searching from the beginning of the cache, hoping to fill it
839
+ if (head_cur > cells.get_used() + 2*n_tokens) {
647
840
  head_cur = 0;
648
- continue;
649
841
  }
650
842
 
651
- for (uint32_t i = 0; i < n_test; i++) {
652
- const auto idx = head_cur;
843
+ if (n_tokens > cells.size()) {
844
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
845
+ return { };
846
+ }
847
+
848
+ uint32_t n_tested = 0;
849
+
850
+ // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
851
+ // for non-continuous slots, we test the tokens one by one
852
+ const uint32_t n_test = cont ? n_tokens : 1;
653
853
 
654
- //const llama_pos pos = ubatch.pos[i];
655
- //const llama_seq_id seq_id = ubatch.seq_id[i][0];
854
+ while (true) {
855
+ if (head_cur + n_test > cells.size()) {
856
+ n_tested += cells.size() - head_cur;
857
+ head_cur = 0;
858
+ continue;
859
+ }
656
860
 
657
- // can we use this cell? either:
658
- // - the cell is empty
659
- // - the cell is occupied only by one sequence:
660
- // - (disabled) mask causally, if the sequence is the same as the one we are inserting
661
- // - mask SWA, using current max pos for that sequence in the cache
662
- // always insert in the cell with minimum pos
663
- bool can_use = cells.is_empty(idx);
861
+ for (uint32_t i = 0; i < n_test; i++) {
862
+ const auto idx = head_cur;
664
863
 
665
- if (!can_use && cells.seq_count(idx) == 1) {
666
- const llama_pos pos_cell = cells.pos_get(idx);
864
+ head_cur++;
865
+ n_tested++;
667
866
 
668
- // (disabled) causal mask
669
- // note: it's better to purge any "future" tokens beforehand
670
- //if (cells.seq_has(idx, seq_id)) {
671
- // can_use = pos_cell >= pos;
672
- //}
867
+ //const llama_pos pos = ubatch.pos[i];
868
+ //const llama_seq_id seq_id = ubatch.seq_id[i][0];
673
869
 
674
- if (!can_use) {
675
- const llama_seq_id seq_id_cell = cells.seq_get(idx);
870
+ // can we use this cell? either:
871
+ // - the cell is empty
872
+ // - the cell is occupied only by one sequence:
873
+ // - (disabled) mask causally, if the sequence is the same as the one we are inserting
874
+ // - mask SWA, using current max pos for that sequence in the cache
875
+ // always insert in the cell with minimum pos
876
+ bool can_use = cells.is_empty(idx);
676
877
 
677
- // SWA mask
678
- if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
679
- can_use = true;
878
+ if (!can_use && cells.seq_count(idx) == 1) {
879
+ const llama_pos pos_cell = cells.pos_get(idx);
880
+
881
+ // (disabled) causal mask
882
+ // note: it's better to purge any "future" tokens beforehand
883
+ //if (cells.seq_has(idx, seq_id)) {
884
+ // can_use = pos_cell >= pos;
885
+ //}
886
+
887
+ if (!can_use) {
888
+ const llama_seq_id seq_id_cell = cells.seq_get(idx);
889
+
890
+ // SWA mask
891
+ if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
892
+ can_use = true;
893
+ }
680
894
  }
681
895
  }
682
- }
683
896
 
684
- head_cur++;
685
- n_tested++;
897
+ if (can_use) {
898
+ res.idxs[s].push_back(idx);
899
+ } else {
900
+ if (cont) {
901
+ break;
902
+ }
903
+ }
904
+ }
686
905
 
687
- if (can_use) {
688
- idxs.push_back(idx);
689
- } else {
906
+ if (res.idxs[s].size() == n_tokens) {
690
907
  break;
691
908
  }
692
- }
693
909
 
694
- if (idxs.size() == n_tokens) {
695
- break;
696
- }
910
+ if (cont) {
911
+ res.idxs[s].clear();
912
+ }
697
913
 
698
- if (cont) {
699
- idxs.clear();
914
+ if (n_tested >= cells.size()) {
915
+ //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
916
+ return { };
917
+ }
700
918
  }
701
919
 
702
- if (n_tested >= cells.size()) {
703
- //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
920
+ // we didn't find a suitable slot - return empty result
921
+ if (res.idxs[s].size() < n_tokens) {
704
922
  return { };
705
923
  }
706
924
  }
707
925
 
708
- // we didn't find a suitable slot - return empty result
709
- if (idxs.size() < n_tokens) {
710
- res.clear();
711
- }
926
+ assert(res.s1 >= res.s0);
712
927
 
713
928
  return res;
714
929
  }
@@ -717,41 +932,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
717
932
  // keep track of the max sequence position that we would overwrite with this ubatch
718
933
  // for non-SWA cache, this would be always empty
719
934
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
720
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
935
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
721
936
  seq_pos_max_rm[s] = -1;
722
937
  }
723
938
 
724
- assert(ubatch.n_tokens == sinfo.idxs.size());
939
+ assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
725
940
 
726
- for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
727
- const auto idx = sinfo.idxs.at(i);
941
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
942
+ for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
943
+ const uint32_t i = s*sinfo.size() + ii;
728
944
 
729
- if (!cells.is_empty(idx)) {
730
- assert(cells.seq_count(idx) == 1);
945
+ auto & cells = v_cells[sinfo.strm[s]];
731
946
 
732
- const llama_seq_id seq_id = cells.seq_get(idx);
733
- const llama_pos pos = cells.pos_get(idx);
947
+ const auto idx = sinfo.idxs[s][ii];
734
948
 
735
- seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
949
+ if (!cells.is_empty(idx)) {
950
+ assert(cells.seq_count(idx) == 1);
736
951
 
737
- cells.rm(idx);
738
- }
952
+ const llama_seq_id seq_id = cells.seq_get(idx);
953
+ const llama_pos pos = cells.pos_get(idx);
739
954
 
740
- cells.pos_set(idx, ubatch.pos[i]);
955
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
956
+
957
+ cells.rm(idx);
958
+ }
741
959
 
742
- for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
743
- cells.seq_add(idx, ubatch.seq_id[i][s]);
960
+ cells.pos_set(idx, ubatch.pos[i]);
961
+
962
+ for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
963
+ cells.seq_add(idx, ubatch.seq_id[i][s]);
964
+ }
744
965
  }
745
966
  }
746
967
 
747
968
  // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
748
969
  // will be present in the cache. so we have to purge any position which is less than those we would overwrite
749
970
  // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
750
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
971
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
751
972
  if (seq_pos_max_rm[s] == -1) {
752
973
  continue;
753
974
  }
754
975
 
976
+ GGML_ASSERT(s < seq_to_stream.size());
977
+
978
+ auto & cells = v_cells[seq_to_stream[s]];
979
+
755
980
  if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
756
981
  LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
757
982
  __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@@ -761,7 +986,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
761
986
  }
762
987
 
763
988
  // move the head at the end of the slot
764
- head = sinfo.idxs.back() + 1;
989
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
990
+ auto & head = v_heads[sinfo.strm[s]];
991
+
992
+ head = sinfo.idxs[s].back() + 1;
993
+ }
765
994
  }
766
995
 
767
996
  bool llama_kv_cache_unified::get_can_shift() const {
@@ -769,49 +998,91 @@ bool llama_kv_cache_unified::get_can_shift() const {
769
998
  }
770
999
 
771
1000
  uint32_t llama_kv_cache_unified::get_size() const {
1001
+ const auto & cells = v_cells[seq_to_stream[0]];
1002
+
772
1003
  return cells.size();
773
1004
  }
774
1005
 
1006
+ uint32_t llama_kv_cache_unified::get_n_stream() const {
1007
+ return n_stream;
1008
+ }
1009
+
775
1010
  bool llama_kv_cache_unified::get_has_shift() const {
776
- return cells.get_has_shift();
1011
+ bool result = false;
1012
+
1013
+ for (uint32_t s = 0; s < n_stream; ++s) {
1014
+ result |= v_cells[s].get_has_shift();
1015
+ }
1016
+
1017
+ return result;
777
1018
  }
778
1019
 
779
1020
  uint32_t llama_kv_cache_unified::get_n_kv() const {
780
- return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
1021
+ uint32_t result = 0;
1022
+
1023
+ for (uint32_t s = 0; s < n_stream; ++s) {
1024
+ const auto & cells = v_cells[s];
1025
+
1026
+ result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
1027
+ }
1028
+
1029
+ return result;
781
1030
  }
782
1031
 
783
- ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
1032
+ bool llama_kv_cache_unified::get_supports_set_rows() const {
1033
+ return supports_set_rows;
1034
+ }
1035
+
1036
+ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
784
1037
  const int32_t ikv = map_layer_ids.at(il);
785
1038
 
786
1039
  auto * k = layers[ikv].k;
787
1040
 
788
- return ggml_view_3d(ctx, k,
789
- hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
1041
+ const uint64_t kv_size = get_size();
1042
+ const uint64_t n_embd_k_gqa = k->ne[0];
1043
+
1044
+ assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
1045
+
1046
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1047
+
1048
+ return ggml_view_4d(ctx, k,
1049
+ hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
790
1050
  ggml_row_size(k->type, hparams.n_embd_head_k),
791
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
792
- 0);
1051
+ ggml_row_size(k->type, n_embd_k_gqa),
1052
+ ggml_row_size(k->type, n_embd_k_gqa*kv_size),
1053
+ ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
793
1054
  }
794
1055
 
795
- ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
1056
+ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
796
1057
  const int32_t ikv = map_layer_ids.at(il);
797
1058
 
798
1059
  auto * v = layers[ikv].v;
799
1060
 
1061
+ const uint64_t kv_size = get_size();
1062
+ const uint64_t n_embd_v_gqa = v->ne[0];
1063
+
1064
+ // [TAG_V_CACHE_VARIABLE]
1065
+ assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
1066
+
1067
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1068
+
800
1069
  if (!v_trans) {
801
1070
  // note: v->nb[1] <= v->nb[2]
802
- return ggml_view_3d(ctx, v,
803
- hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
804
- ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
805
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
806
- 0);
1071
+ return ggml_view_4d(ctx, v,
1072
+ hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
1073
+ ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
1074
+ ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
1075
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
1076
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
807
1077
  }
808
1078
 
809
1079
  // note: v->nb[1] > v->nb[2]
810
- return ggml_view_3d(ctx, v,
811
- n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
812
- ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
813
- ggml_row_size(v->type, v->ne[1]), // v->nb[2]
814
- 0);
1080
+ return ggml_view_4d(ctx, v,
1081
+ n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
1082
+ ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
1083
+ ggml_row_size(v->type, kv_size), // v->nb[2]
1084
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
1085
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
815
1086
  }
816
1087
 
817
1088
  ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
@@ -825,12 +1096,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
825
1096
  k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
826
1097
 
827
1098
  if (k_idxs && supports_set_rows) {
1099
+ if (k->ne[2] > 1) {
1100
+ k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
1101
+ }
1102
+
828
1103
  return ggml_set_rows(ctx, k, k_cur, k_idxs);
829
1104
  }
830
1105
 
831
1106
  // TODO: fallback to old ggml_cpy() method for backwards compatibility
832
1107
  // will be removed when ggml_set_rows() is adopted by all backends
833
1108
 
1109
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
1110
+
834
1111
  ggml_tensor * k_view = ggml_view_1d(ctx, k,
835
1112
  n_tokens*n_embd_k_gqa,
836
1113
  ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
@@ -843,37 +1120,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
843
1120
 
844
1121
  auto * v = layers[ikv].v;
845
1122
 
846
- const int64_t n_embd_v_gqa = v->ne[0];
847
- const int64_t n_tokens = v_cur->ne[2];
1123
+ const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
1124
+ const int64_t n_tokens = v_cur->ne[2];
848
1125
 
849
1126
  v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
850
1127
 
851
1128
  if (v_idxs && supports_set_rows) {
852
1129
  if (!v_trans) {
1130
+ if (v->ne[2] > 1) {
1131
+ v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
1132
+ }
1133
+
853
1134
  return ggml_set_rows(ctx, v, v_cur, v_idxs);
854
1135
  }
855
1136
 
856
- // the row becomes a single element
857
- ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
1137
+ // [TAG_V_CACHE_VARIABLE]
1138
+ if (n_embd_v_gqa < v->ne[0]) {
1139
+ v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
1140
+ }
858
1141
 
859
- // note: the V cache is transposed when not using flash attention
860
- v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
1142
+ // the row becomes a single element
1143
+ ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
861
1144
 
862
- // note: we can be more explicit here at the cost of extra cont
863
- // however, above we take advantage that a row of single element is always continuous regardless of the row stride
864
- //v_cur = ggml_transpose(ctx, v_cur);
865
- //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
1145
+ v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
866
1146
 
867
- // we broadcast the KV indices n_embd_v_gqa times
868
- // v [1, n_kv, n_embd_v_gqa]
869
- // v_cur [1, n_tokens, n_embd_v_gqa]
870
- // v_idxs [n_tokens, 1, 1]
871
1147
  return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
872
1148
  }
873
1149
 
874
1150
  // TODO: fallback to old ggml_cpy() method for backwards compatibility
875
1151
  // will be removed when ggml_set_rows() is adopted by all backends
876
1152
 
1153
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
1154
+
877
1155
  ggml_tensor * v_view = nullptr;
878
1156
 
879
1157
  if (!v_trans) {
@@ -904,7 +1182,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
904
1182
  ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
905
1183
  const uint32_t n_tokens = ubatch.n_tokens;
906
1184
 
907
- ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
1185
+ ggml_tensor * v_idxs;
1186
+
1187
+ if (!v_trans) {
1188
+ v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
1189
+ } else {
1190
+ v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
1191
+ }
908
1192
 
909
1193
  ggml_set_input(v_idxs);
910
1194
 
@@ -917,12 +1201,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
917
1201
  }
918
1202
 
919
1203
  const uint32_t n_tokens = ubatch->n_tokens;
1204
+ GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
920
1205
 
921
1206
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
922
1207
  int64_t * data = (int64_t *) dst->data;
923
1208
 
924
- for (int64_t i = 0; i < n_tokens; ++i) {
925
- data[i] = sinfo.idxs.at(i);
1209
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1210
+ const int64_t offs = sinfo.strm[s]*get_size();
1211
+
1212
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1213
+ data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1214
+ }
926
1215
  }
927
1216
  }
928
1217
 
@@ -932,12 +1221,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
932
1221
  }
933
1222
 
934
1223
  const uint32_t n_tokens = ubatch->n_tokens;
1224
+ GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
935
1225
 
936
1226
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
937
1227
  int64_t * data = (int64_t *) dst->data;
938
1228
 
939
- for (int64_t i = 0; i < n_tokens; ++i) {
940
- data[i] = sinfo.idxs.at(i);
1229
+ if (!v_trans) {
1230
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1231
+ const int64_t offs = sinfo.strm[s]*get_size();
1232
+
1233
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1234
+ data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1235
+ }
1236
+ }
1237
+ } else {
1238
+ // note: the V cache is transposed when not using flash attention
1239
+ const int64_t kv_size = get_size();
1240
+
1241
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
1242
+
1243
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1244
+ const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
1245
+
1246
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1247
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1248
+ data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
1249
+ }
1250
+ }
1251
+ }
1252
+ }
1253
+ }
1254
+
1255
+ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
1256
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1257
+
1258
+ int32_t * data = (int32_t *) dst->data;
1259
+
1260
+ for (uint32_t s = 0; s < n_stream; ++s) {
1261
+ const auto & cells = v_cells[s];
1262
+
1263
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1264
+ data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
1265
+ }
941
1266
  }
942
1267
  }
943
1268
 
@@ -947,7 +1272,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
947
1272
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
948
1273
  float * data = (float *) dst->data;
949
1274
 
950
- const int64_t n_kv = dst->ne[0];
1275
+ const int64_t n_kv = dst->ne[0];
1276
+ const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
1277
+
1278
+ GGML_ASSERT(n_tokens%n_stream == 0);
1279
+
1280
+ // n_tps == n_tokens_per_stream
1281
+ const int64_t n_tps = n_tokens/n_stream;
1282
+ const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
1283
+
1284
+ std::fill(data, data + ggml_nelements(dst), -INFINITY);
951
1285
 
952
1286
  // Use only the previous KV cells of the correct sequence for each token of the ubatch.
953
1287
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -961,70 +1295,57 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
961
1295
  // xxxxx-----
962
1296
  // xxxxx-----
963
1297
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
1298
+ // TODO: optimize this section
964
1299
  for (uint32_t h = 0; h < 1; ++h) {
965
- for (uint32_t i = 0; i < n_tokens; ++i) {
966
- const llama_seq_id seq_id = ubatch->seq_id[i][0];
1300
+ for (uint32_t s = 0; s < n_stream; ++s) {
1301
+ for (uint32_t ii = 0; ii < n_tps; ++ii) {
1302
+ const uint32_t i = s*n_tps + ii;
967
1303
 
968
- const llama_pos p1 = ubatch->pos[i];
1304
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
969
1305
 
970
- for (uint32_t j = 0; j < n_kv; ++j) {
971
- float f = 0.0f;
1306
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
972
1307
 
973
- bool masked = false;
1308
+ const llama_pos p1 = ubatch->pos[i];
974
1309
 
975
- if (cells.is_empty(j)) {
976
- masked = true;
977
- } else {
978
- const llama_pos p0 = cells.pos_get(j);
1310
+ const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
1311
+
1312
+ for (uint32_t j = 0; j < n_kv; ++j) {
1313
+ if (cells.is_empty(j)) {
1314
+ continue;
1315
+ }
979
1316
 
980
1317
  // mask the token if not the same sequence
981
- masked = masked || (!cells.seq_has(j, seq_id));
1318
+ if (!cells.seq_has(j, seq_id)) {
1319
+ continue;
1320
+ }
1321
+
1322
+ const llama_pos p0 = cells.pos_get(j);
982
1323
 
983
1324
  // mask future tokens
984
- masked = masked || (causal_attn && p0 > p1);
1325
+ if (causal_attn && p0 > p1) {
1326
+ continue;
1327
+ }
985
1328
 
986
1329
  // apply SWA if any
987
- masked = masked || (is_masked_swa(p0, p1));
988
-
989
- if (!masked && hparams.use_alibi) {
990
- f = -std::abs(p0 - p1);
1330
+ if (is_masked_swa(p0, p1)) {
1331
+ continue;
991
1332
  }
992
- }
993
-
994
- if (masked) {
995
- f = -INFINITY;
996
- }
997
-
998
- data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
999
- }
1000
- }
1001
1333
 
1002
- // mask padded tokens
1003
- if (data) {
1004
- for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
1005
- for (uint32_t j = 0; j < n_kv; ++j) {
1006
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
1334
+ data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
1007
1335
  }
1008
1336
  }
1009
1337
  }
1010
1338
  }
1011
1339
  }
1012
1340
 
1013
- void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
1014
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1015
-
1016
- int32_t * data = (int32_t *) dst->data;
1017
-
1018
- for (uint32_t i = 0; i < cells.size(); ++i) {
1019
- data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
1020
- }
1021
- }
1022
-
1023
1341
  void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1024
1342
  const int64_t n_tokens = ubatch->n_tokens;
1025
1343
 
1344
+ GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
1345
+ const auto & cells = v_cells[0];
1346
+
1026
1347
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1027
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
1348
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
1028
1349
 
1029
1350
  int32_t * data = (int32_t *) dst->data;
1030
1351
 
@@ -1129,7 +1450,7 @@ public:
1129
1450
 
1130
1451
  void set_input(const llama_ubatch * ubatch) override;
1131
1452
 
1132
- ggml_tensor * k_shift; // I32 [kv_size]
1453
+ ggml_tensor * k_shift; // I32 [kv_size*n_stream]
1133
1454
 
1134
1455
  const llama_kv_cache_unified * kv_self;
1135
1456
  };
@@ -1142,20 +1463,20 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
1142
1463
  }
1143
1464
  }
1144
1465
 
1145
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1146
- const llama_cparams & cparams,
1147
- ggml_context * ctx,
1148
- ggml_cgraph * gf) const {
1149
- auto res = std::make_unique<llm_graph_result>();
1466
+ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
1467
+ auto * ctx = res->get_ctx();
1468
+ auto * gf = res->get_gf();
1150
1469
 
1151
1470
  const auto & n_embd_head_k = hparams.n_embd_head_k;
1152
1471
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
1153
1472
 
1154
1473
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
1155
1474
 
1156
- inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
1475
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
1157
1476
  ggml_set_input(inp->k_shift);
1158
1477
 
1478
+ const auto & cparams = lctx->get_cparams();
1479
+
1159
1480
  for (const auto & layer : layers) {
1160
1481
  const uint32_t il = layer.il;
1161
1482
 
@@ -1169,7 +1490,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1169
1490
 
1170
1491
  ggml_tensor * k =
1171
1492
  ggml_view_3d(ctx, layer.k,
1172
- n_embd_head_k, n_head_kv, cells.size(),
1493
+ n_embd_head_k, n_head_kv, get_size()*n_stream,
1173
1494
  ggml_row_size(layer.k->type, n_embd_head_k),
1174
1495
  ggml_row_size(layer.k->type, n_embd_k_gqa),
1175
1496
  0);
@@ -1181,18 +1502,24 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1181
1502
 
1182
1503
  res->add_input(std::move(inp));
1183
1504
 
1184
- return res;
1505
+ return gf;
1185
1506
  }
1186
1507
 
1187
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1188
- const llama_cparams & cparams,
1189
- ggml_context * ctx,
1190
- ggml_cgraph * gf,
1191
- const defrag_info & dinfo) const {
1192
- auto res = std::make_unique<llm_graph_result>();
1508
+ ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
1509
+ llm_graph_result * res,
1510
+ llama_context * lctx,
1511
+ const defrag_info & dinfo) const {
1512
+ auto * ctx = res->get_ctx();
1513
+ auto * gf = res->get_gf();
1514
+
1515
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
1516
+
1517
+ const auto & cells = v_cells[0];
1193
1518
 
1194
1519
  const auto & ids = dinfo.ids;
1195
1520
 
1521
+ const auto & cparams = lctx->get_cparams();
1522
+
1196
1523
  #if 0
1197
1524
  // CPU defrag
1198
1525
  //
@@ -1329,10 +1656,14 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1329
1656
  //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
1330
1657
  #endif
1331
1658
 
1332
- return res;
1659
+ return gf;
1333
1660
  }
1334
1661
 
1335
1662
  llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
1663
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
1664
+
1665
+ const auto & cells = v_cells[0];
1666
+
1336
1667
  const uint32_t n_layer = layers.size();
1337
1668
 
1338
1669
  const uint32_t n_kv = cells.used_max_p1();
@@ -1478,64 +1809,94 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
1478
1809
  }
1479
1810
 
1480
1811
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1481
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1482
- uint32_t cell_count = 0;
1812
+ io.write(&n_stream, sizeof(n_stream));
1483
1813
 
1484
- // Count the number of cells with the specified seq_id
1485
- // Find all the ranges of cells with this seq id (or all, when -1)
1486
- uint32_t cell_range_begin = cells.size();
1814
+ for (uint32_t s = 0; s < n_stream; ++s) {
1815
+ cell_ranges_t cr { s, {} };
1487
1816
 
1488
- for (uint32_t i = 0; i < cells.size(); ++i) {
1489
- if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1490
- ++cell_count;
1491
- if (cell_range_begin == cells.size()) {
1492
- cell_range_begin = i;
1493
- }
1494
- } else {
1495
- if (cell_range_begin != cells.size()) {
1496
- cell_ranges.emplace_back(cell_range_begin, i);
1497
- cell_range_begin = cells.size();
1817
+ uint32_t cell_count = 0;
1818
+
1819
+ const auto & cells = v_cells[s];
1820
+
1821
+ // Count the number of cells with the specified seq_id
1822
+ // Find all the ranges of cells with this seq id (or all, when -1)
1823
+ uint32_t cell_range_begin = cells.size();
1824
+
1825
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1826
+ if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1827
+ ++cell_count;
1828
+ if (cell_range_begin == cells.size()) {
1829
+ cell_range_begin = i;
1830
+ }
1831
+ } else {
1832
+ if (cell_range_begin != cells.size()) {
1833
+ cr.data.emplace_back(cell_range_begin, i);
1834
+ cell_range_begin = cells.size();
1835
+ }
1498
1836
  }
1499
1837
  }
1500
- }
1501
1838
 
1502
- if (cell_range_begin != cells.size()) {
1503
- cell_ranges.emplace_back(cell_range_begin, cells.size());
1504
- }
1839
+ if (cell_range_begin != cells.size()) {
1840
+ cr.data.emplace_back(cell_range_begin, cells.size());
1841
+ }
1505
1842
 
1506
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1507
- uint32_t cell_count_check = 0;
1508
- for (const auto & range : cell_ranges) {
1509
- cell_count_check += range.second - range.first;
1510
- }
1511
- GGML_ASSERT(cell_count == cell_count_check);
1843
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1844
+ uint32_t cell_count_check = 0;
1845
+ for (const auto & range : cr.data) {
1846
+ cell_count_check += range.second - range.first;
1847
+ }
1848
+ GGML_ASSERT(cell_count == cell_count_check);
1512
1849
 
1513
- io.write(&cell_count, sizeof(cell_count));
1850
+ io.write(&cell_count, sizeof(cell_count));
1514
1851
 
1515
- state_write_meta(io, cell_ranges, seq_id);
1516
- state_write_data(io, cell_ranges);
1852
+ // skip empty streams
1853
+ if (cell_count == 0) {
1854
+ continue;
1855
+ }
1856
+
1857
+ state_write_meta(io, cr, seq_id);
1858
+ state_write_data(io, cr);
1859
+ }
1517
1860
  }
1518
1861
 
1519
1862
  void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1520
- uint32_t cell_count;
1521
- io.read_to(&cell_count, sizeof(cell_count));
1863
+ GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
1522
1864
 
1523
- bool res = true;
1524
- res = res && state_read_meta(io, cell_count, seq_id);
1525
- res = res && state_read_data(io, cell_count);
1865
+ uint32_t n_stream_cur;
1866
+ io.read_to(&n_stream_cur, sizeof(n_stream_cur));
1867
+ if (n_stream_cur != n_stream) {
1868
+ throw std::runtime_error("n_stream mismatch");
1869
+ }
1870
+
1871
+ for (uint32_t s = 0; s < n_stream; ++s) {
1872
+ uint32_t cell_count;
1873
+ io.read_to(&cell_count, sizeof(cell_count));
1874
+
1875
+ if (cell_count == 0) {
1876
+ continue;
1877
+ }
1526
1878
 
1527
- if (!res) {
1528
- if (seq_id == -1) {
1529
- clear(true);
1530
- } else {
1531
- seq_rm(seq_id, -1, -1);
1879
+ const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
1880
+
1881
+ bool res = true;
1882
+ res = res && state_read_meta(io, strm, cell_count, seq_id);
1883
+ res = res && state_read_data(io, strm, cell_count);
1884
+
1885
+ if (!res) {
1886
+ if (seq_id == -1) {
1887
+ clear(true);
1888
+ } else {
1889
+ seq_rm(seq_id, -1, -1);
1890
+ }
1891
+ throw std::runtime_error("failed to restore kv cache");
1532
1892
  }
1533
- throw std::runtime_error("failed to restore kv cache");
1534
1893
  }
1535
1894
  }
1536
1895
 
1537
- void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
1538
- for (const auto & range : cell_ranges) {
1896
+ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
1897
+ const auto & cells = v_cells[cr.strm];
1898
+
1899
+ for (const auto & range : cr.data) {
1539
1900
  for (uint32_t i = range.first; i < range.second; ++i) {
1540
1901
  std::vector<llama_seq_id> seq_ids;
1541
1902
 
@@ -1560,7 +1921,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
1560
1921
  }
1561
1922
  }
1562
1923
 
1563
- void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
1924
+ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
1925
+ const auto & cells = v_cells[cr.strm];
1926
+
1564
1927
  const uint32_t v_trans = this->v_trans ? 1 : 0;
1565
1928
  const uint32_t n_layer = layers.size();
1566
1929
 
@@ -1576,19 +1939,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1576
1939
 
1577
1940
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1578
1941
 
1942
+ auto * k = layer.k_stream[cr.strm];
1943
+
1579
1944
  // Write key type
1580
- const int32_t k_type_i = (int32_t)layer.k->type;
1945
+ const int32_t k_type_i = (int32_t) k->type;
1581
1946
  io.write(&k_type_i, sizeof(k_type_i));
1582
1947
 
1583
1948
  // Write row size of key
1584
- const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1949
+ const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
1585
1950
  io.write(&k_size_row, sizeof(k_size_row));
1586
1951
 
1587
1952
  // Read each range of cells of k_size length each into tmp_buf and write out
1588
- for (const auto & range : cell_ranges) {
1953
+ for (const auto & range : cr.data) {
1589
1954
  const size_t range_size = range.second - range.first;
1590
1955
  const size_t buf_size = range_size * k_size_row;
1591
- io.write_tensor(layer.k, range.first * k_size_row, buf_size);
1956
+ io.write_tensor(k, range.first * k_size_row, buf_size);
1592
1957
  }
1593
1958
  }
1594
1959
 
@@ -1598,19 +1963,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::

  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+ auto * v = layer.v_stream[cr.strm];
+
  // Write value type
- const int32_t v_type_i = (int32_t)layer.v->type;
+ const int32_t v_type_i = (int32_t) v->type;
  io.write(&v_type_i, sizeof(v_type_i));

  // Write row size of value
- const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
+ const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  io.write(&v_size_row, sizeof(v_size_row));

  // Read each range of cells of v_size length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
+ for (const auto & range : cr.data) {
  const size_t range_size = range.second - range.first;
  const size_t buf_size = range_size * v_size_row;
- io.write_tensor(layer.v, range.first * v_size_row, buf_size);
+ io.write_tensor(v, range.first * v_size_row, buf_size);
  }
  }
  } else {
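The write helpers above take the cell ranges bundled with the stream they belong to (cr.strm, cr.data) and stream each contiguous range of cells out as whole rows, so the byte offset is simply range.first times the row size. Below is a minimal, self-contained sketch of that bookkeeping: the struct mirrors the fields used above, while row_size, the f32 element size, and the printed numbers are made-up stand-ins (ggml_row_size and io.write_tensor are not used here).

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Presumed shape of the bundle passed to state_write_meta/state_write_data:
// the stream index plus half-open [first, second) cell ranges within that stream.
struct cell_ranges_t {
    uint32_t strm;
    std::vector<std::pair<uint32_t, uint32_t>> data;
};

int main() {
    const size_t row_size = 128 * sizeof(float); // stand-in for ggml_row_size(k->type, n_embd_k_gqa)

    cell_ranges_t cr;
    cr.strm = 0;
    cr.data = { {4, 10}, {32, 40} }; // two contiguous blocks of cells in stream 0

    for (const auto & range : cr.data) {
        const size_t range_size = range.second - range.first;
        const size_t offset     = range.first * row_size; // where the first row of the range lives
        const size_t buf_size   = range_size * row_size;  // bytes streamed out for this range
        std::printf("strm %u: offset=%zu size=%zu\n", cr.strm, offset, buf_size);
    }
    return 0;
}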
@@ -1622,12 +1989,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::

  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+ auto * v = layer.v_stream[cr.strm];
+
  // Write value type
- const int32_t v_type_i = (int32_t)layer.v->type;
+ const int32_t v_type_i = (int32_t) v->type;
  io.write(&v_type_i, sizeof(v_type_i));

  // Write element size
- const uint32_t v_size_el = ggml_type_size(layer.v->type);
+ const uint32_t v_size_el = ggml_type_size(v->type);
  io.write(&v_size_el, sizeof(v_size_el));

  // Write GQA embedding size
@@ -1636,27 +2005,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
  // For each row, we get the element values of each cell
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  // Read each range of cells of v_size_el length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
+ for (const auto & range : cr.data) {
  const size_t range_size = range.second - range.first;
  const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  const size_t buf_size = range_size * v_size_el;
- io.write_tensor(layer.v, src_offset, buf_size);
+ io.write_tensor(v, src_offset, buf_size);
  }
  }
  }
  }
  }

- bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+ auto & cells = v_cells[strm];
+ auto & head = v_heads[strm];
+
  if (dest_seq_id != -1) {
  // single sequence
-
  seq_rm(dest_seq_id, -1, -1);

  llama_batch_allocr balloc(hparams.n_pos_per_embd());

  llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

+ ubatch.seq_id_unq[0] = dest_seq_id;
+
  for (uint32_t i = 0; i < cell_count; ++i) {
  llama_pos pos;
  uint32_t n_seq_id;
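In the transposed branch above, row j of the value matrix holds element j of every cell, so a cell range has to be written once per row at src_offset = (range.first + j * kv_size) * v_size_el. A toy recomputation of those offsets, assuming f32 elements and small made-up sizes:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t kv_size      = 8; // cells per stream (toy value)
    const uint32_t n_embd_v_gqa = 4; // rows of the transposed value matrix (toy value)
    const size_t   v_size_el    = sizeof(float);

    const uint32_t first = 2, second = 5; // one half-open cell range

    // the transposed cache is laid out row by row: element j of cell c sits at index (c + j*kv_size)
    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
        const size_t src_offset = (first + j * kv_size) * v_size_el;
        const size_t buf_size   = (second - first) * v_size_el;
        std::printf("row %u: offset=%zu size=%zu\n", j, src_offset, buf_size);
    }
    return 0;
}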
@@ -1693,6 +2066,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
  // keep the head at the old position because we will read the KV data into it in state_read_data()
  head = head_cur;

+ LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+
  // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
  // Assume that this is one contiguous block of cells
  GGML_ASSERT(head_cur + cell_count <= cells.size());
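The assertion above encodes the restore invariant: the cells read back form one contiguous block [head_cur, head_cur + cell_count) inside the stream's cell array. A minimal sketch of that check as a plain helper (illustrative, not a llama.cpp API):

#include <cstdint>

// Restored cells must fit as one contiguous block starting at head.
static bool fits_contiguously(uint32_t head, uint32_t cell_count, uint32_t n_cells) {
    return head + cell_count <= n_cells; // e.g. fits_contiguously(3, 2, 8) == true
}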
@@ -1738,7 +2113,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
  return true;
  }

- bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+ auto & cells = v_cells[strm];
+ auto & head = v_heads[strm];
+
  uint32_t v_trans;
  uint32_t n_layer;

@@ -1766,10 +2144,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

+ auto * k = layer.k_stream[strm];
+
  // Read type of key
  int32_t k_type_i_ref;
  io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
- const int32_t k_type_i = (int32_t) layer.k->type;
+ const int32_t k_type_i = (int32_t) k->type;
  if (k_type_i != k_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  return false;
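The read path above picks the layer's K tensor for the requested stream via layer.k_stream[strm] before validating the serialized type tag. A reduced, hypothetical sketch of that per-layer layout; the real kv_layer holds ggml_tensor pointers per stream, so plain byte buffers are used here only to keep the example self-contained:

#include <cstdint>
#include <vector>

// Hypothetical reduction of a KV layer: one K buffer and one V buffer per stream.
struct kv_layer_sketch {
    std::vector<std::vector<uint8_t>> k_stream; // indexed by stream id
    std::vector<std::vector<uint8_t>> v_stream; // indexed by stream id
};

// Mirrors `auto * k = layer.k_stream[strm];` for the reduced layout above.
static std::vector<uint8_t> & select_k(kv_layer_sketch & layer, uint32_t strm) {
    return layer.k_stream[strm];
}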
@@ -1778,7 +2158,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  // Read row size of key
  uint64_t k_size_row_ref;
  io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
- const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
+ const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  if (k_size_row != k_size_row_ref) {
  LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  return false;
@@ -1786,7 +2166,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

  if (cell_count) {
  // Read and set the keys for the whole cell range
- ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+ ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  }
  }

@@ -1796,10 +2176,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+ auto * v = layer.v_stream[strm];
+
  // Read type of value
  int32_t v_type_i_ref;
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
- const int32_t v_type_i = (int32_t)layer.v->type;
+ const int32_t v_type_i = (int32_t) v->type;
  if (v_type_i != v_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  return false;
@@ -1808,7 +2190,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  // Read row size of value
  uint64_t v_size_row_ref;
  io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
- const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
+ const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  if (v_size_row != v_size_row_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  return false;
@@ -1816,7 +2198,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

  if (cell_count) {
  // Read and set the values for the whole cell range
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  }
  }
  } else {
@@ -1826,10 +2208,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell

  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+ auto * v = layer.v_stream[strm];
+
  // Read type of value
  int32_t v_type_i_ref;
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
- const int32_t v_type_i = (int32_t)layer.v->type;
+ const int32_t v_type_i = (int32_t) v->type;
  if (v_type_i != v_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  return false;
@@ -1838,7 +2222,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  // Read element size of value
  uint32_t v_size_el_ref;
  io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
- const size_t v_size_el = ggml_type_size(layer.v->type);
+ const size_t v_size_el = ggml_type_size(v->type);
  if (v_size_el != v_size_el_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  return false;
@@ -1856,7 +2240,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  // For each row in the transposed matrix, read the values for the whole cell range
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  const size_t dst_offset = (head + j * cells.size()) * v_size_el;
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  }
  }
  }
@@ -1875,18 +2259,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
  n_kv = kv->get_size();

+ const uint32_t n_stream = kv->get_n_stream();
+
  // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
  sinfos.resize(1);
- sinfos[0].idxs.resize(1);
- sinfos[0].idxs[0] = 0;
+ sinfos[0].s0 = 0;
+ sinfos[0].s1 = n_stream - 1;
+ sinfos[0].idxs.resize(n_stream);
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ sinfos[0].strm.push_back(s);
+ sinfos[0].idxs[s].resize(1, 0);
+ }
  }

  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
  llama_kv_cache_unified * kv,
  llama_context * lctx,
  bool do_shift,
- defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
- if (!do_shift && this->dinfo.empty()) {
+ defrag_info dinfo,
+ stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
+ if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
  status = LLAMA_MEMORY_STATUS_NO_UPDATE;
  }
  }
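The dummy slot info built above now spans every stream: one entry per stream, each pointing at cell 0, with s0/s1 covering the full stream range. A compact sketch of the presumed slot-info fields touched by that constructor (the struct itself is illustrative, not the real llama.cpp type):

#include <cstdint>
#include <vector>

// Only the fields the constructor above assigns.
struct slot_info_sketch {
    uint32_t s0 = 0;                          // first stream covered
    uint32_t s1 = 0;                          // last stream covered
    std::vector<uint32_t> strm;               // stream id per entry
    std::vector<std::vector<uint32_t>> idxs;  // cell indices, one list per stream
};

// Build the dummy slot info the same way the unified-context constructor does.
static slot_info_sketch make_dummy_sinfo(uint32_t n_stream) {
    slot_info_sketch si;
    si.s0 = 0;
    si.s1 = n_stream - 1;
    si.idxs.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        si.strm.push_back(s);
        si.idxs[s].resize(1, 0); // a single dummy cell index per stream
    }
    return si;
}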
@@ -1914,7 +2306,7 @@ bool llama_kv_cache_unified_context::apply() {

  // no ubatches -> this is a KV cache update
  if (ubatches.empty()) {
- kv->update(lctx, do_shift, dinfo);
+ kv->update(lctx, do_shift, dinfo, sc_info);

  return true;
  }
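apply() with no ubatches is a pure maintenance pass, and the constructor above only schedules one when there is shift, defrag, or stream-copy work pending, which is what the empty() checks express. A tiny sketch of that decision, with empty() stand-ins for the real defrag_info/stream_copy_info:

#include <vector>

// Stand-ins: both real structs expose empty() over their pending work, as used above.
struct defrag_info_sketch      { std::vector<int> moves;  bool empty() const { return moves.empty(); } };
struct stream_copy_info_sketch { std::vector<int> copies; bool empty() const { return copies.empty(); } };

// Mirrors the constructor logic: no shift, no defrag moves, no stream copies -> nothing to update.
static bool needs_update(bool do_shift, const defrag_info_sketch & dinfo, const stream_copy_info_sketch & sc_info) {
    return do_shift || !dinfo.empty() || !sc_info.empty();
}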
@@ -1940,12 +2332,16 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
  return n_kv;
  }

+ bool llama_kv_cache_unified_context::get_supports_set_rows() const {
+ return kv->get_supports_set_rows();
+ }
+
  ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
- return kv->get_k(ctx, il, n_kv);
+ return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
  }

  ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
- return kv->get_v(ctx, il, n_kv);
+ return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
  }

  ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {