@fugood/llama.node 1.0.1 → 1.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/package.json +14 -14
  2. package/scripts/llama.cpp.patch +12 -12
  3. package/src/llama.cpp/CMakeLists.txt +0 -1
  4. package/src/llama.cpp/common/arg.cpp +17 -0
  5. package/src/llama.cpp/common/chat.cpp +37 -20
  6. package/src/llama.cpp/common/chat.h +2 -0
  7. package/src/llama.cpp/common/common.h +4 -0
  8. package/src/llama.cpp/ggml/CMakeLists.txt +7 -2
  9. package/src/llama.cpp/ggml/include/ggml-backend.h +1 -1
  10. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -0
  11. package/src/llama.cpp/ggml/include/ggml.h +181 -10
  12. package/src/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +38 -1
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +1 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1297 -211
  17. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +7 -0
  18. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  19. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
  20. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +103 -9
  21. package/src/llama.cpp/include/llama.h +1 -0
  22. package/src/llama.cpp/src/llama-arch.cpp +108 -2
  23. package/src/llama.cpp/src/llama-arch.h +7 -0
  24. package/src/llama.cpp/src/llama-batch.cpp +27 -1
  25. package/src/llama.cpp/src/llama-batch.h +8 -1
  26. package/src/llama.cpp/src/llama-chat.cpp +15 -0
  27. package/src/llama.cpp/src/llama-chat.h +1 -0
  28. package/src/llama.cpp/src/llama-graph.cpp +95 -81
  29. package/src/llama.cpp/src/llama-graph.h +43 -16
  30. package/src/llama.cpp/src/llama-hparams.cpp +2 -1
  31. package/src/llama.cpp/src/llama-hparams.h +1 -0
  32. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  34. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  35. package/src/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  36. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  37. package/src/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  38. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  39. package/src/llama.cpp/src/llama-memory-recurrent.cpp +34 -16
  40. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  41. package/src/llama.cpp/src/llama-memory.h +3 -0
  42. package/src/llama.cpp/src/llama-model.cpp +1374 -210
  43. package/src/llama.cpp/src/llama-model.h +3 -0
  44. package/src/llama.cpp/src/llama-vocab.cpp +8 -1
  45. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
package/src/llama.cpp/src/llama-kv-cache-unified.cpp

@@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }
 
 void llama_kv_cache_unified::clear(bool data) {
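Note: the new write path is opt-in in this release. A minimal standalone sketch of the same environment-variable gate (illustrative only; just the variable name and default are taken from the diff):

```cpp
#include <cstdio>
#include <cstdlib>

int main() {
    // LLAMA_SET_ROWS=1 opts in to the new ggml_set_rows() write path;
    // unset or 0 keeps the old ggml_cpy() path (the default in this release)
    const char * env = getenv("LLAMA_SET_ROWS");
    const int supports_set_rows = env ? atoi(env) : 0;

    if (!supports_set_rows) {
        fprintf(stderr, "LLAMA_SET_ROWS=0, using old ggml_cpy() method\n");
    }
    return 0;
}
```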
@@ -353,13 +360,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads = prepare(ubatches);
-        if (heads.empty()) {
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
             break;
         }
 
         return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+                this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -402,12 +414,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
 
     struct state {
        uint32_t head_old; // old position of the head, before placing the ubatch
-       uint32_t head_new; // new position of the head, after placing the ubatch
+
+       slot_info sinfo; // slot info for the ubatch
 
        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
    };
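The state struct above is one entry of the recovery stack used by prepare(): each ubatch is applied speculatively so that subsequent ubatches see the cells it occupies, and afterwards everything is rolled back in reverse order. A simplified sketch of that save/apply/rollback pattern, with a plain vector standing in for llama_kv_cells_unified:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

int main() {
    std::vector<int32_t> cells(16, -1);                    // -1 = empty cell
    std::vector<std::vector<uint32_t>> ubatch_idxs = {{0, 1}, {2, 3}};

    // recovery stack: (cell indices, their previous contents)
    std::vector<std::pair<std::vector<uint32_t>, std::vector<int32_t>>> states;

    for (const auto & idxs : ubatch_idxs) {
        std::vector<int32_t> old;
        for (auto i : idxs) old.push_back(cells[i]);       // save the old state
        states.emplace_back(idxs, std::move(old));
        for (auto i : idxs) cells[i] = 0;                  // now emplace the ubatch
    }

    // iterate backwards and restore the cells to their original state
    for (auto it = states.rbegin(); it != states.rend(); ++it) {
        for (size_t j = 0; j < it->first.size(); ++j) {
            cells[it->first[j]] = it->second[j];
        }
    }
    return 0;
}
```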
@@ -418,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
     bool success = true;
 
     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch, cont);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }
 
         // remember the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
 
         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }
 
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->head_new, it->cells);
+        cells.set(it->sinfo.idxs, it->cells);
         head = it->head_old;
     }
 
@@ -539,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -552,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return { };
     }
 
     if (debug > 0) {
@@ -615,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     uint32_t n_tested = 0;
 
+    // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+    // for non-continuous slots, we test the tokens one by one
+    const uint32_t n_test = cont ? n_tokens : 1;
+
+    slot_info res;
+
+    auto & idxs = res.idxs;
+
+    idxs.reserve(n_tokens);
+
     while (true) {
-        if (head_cur + n_tokens > cells.size()) {
+        if (head_cur + n_test > cells.size()) {
             n_tested += cells.size() - head_cur;
             head_cur = 0;
             continue;
         }
 
-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_test; i++) {
+            const auto idx = head_cur;
+
             //const llama_pos pos = ubatch.pos[i];
             //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
633
660
  // - (disabled) mask causally, if the sequence is the same as the one we are inserting
634
661
  // - mask SWA, using current max pos for that sequence in the cache
635
662
  // always insert in the cell with minimum pos
636
- bool can_use = cells.is_empty(head_cur + i);
663
+ bool can_use = cells.is_empty(idx);
637
664
 
638
- if (!can_use && cells.seq_count(head_cur + i) == 1) {
639
- const llama_pos pos_cell = cells.pos_get(head_cur + i);
665
+ if (!can_use && cells.seq_count(idx) == 1) {
666
+ const llama_pos pos_cell = cells.pos_get(idx);
640
667
 
641
668
  // (disabled) causal mask
642
669
  // note: it's better to purge any "future" tokens beforehand
643
- //if (cells.seq_has(head_cur + i, seq_id)) {
670
+ //if (cells.seq_has(idx, seq_id)) {
644
671
  // can_use = pos_cell >= pos;
645
672
  //}
646
673
 
647
674
  if (!can_use) {
648
- const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
675
+ const llama_seq_id seq_id_cell = cells.seq_get(idx);
649
676
 
650
677
  // SWA mask
651
678
  if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -654,28 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 }
             }
 
-            if (!can_use) {
-                found = false;
-                head_cur += i + 1;
-                n_tested += i + 1;
+            head_cur++;
+            n_tested++;
+
+            if (can_use) {
+                idxs.push_back(idx);
+            } else {
                 break;
             }
         }
 
-        if (found) {
+        if (idxs.size() == n_tokens) {
             break;
         }
 
+        if (cont) {
+            idxs.clear();
+        }
+
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return { };
         }
     }
 
-    return head_cur;
+    // we didn't find a suitable slot - return empty result
+    if (idxs.size() < n_tokens) {
+        res.clear();
+    }
+
+    return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA cache, this would be always empty
    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
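To summarize the find_slot() rewrite above: with cont == true the collected run is discarded whenever an unusable cell is hit, so only a contiguous block can succeed; with cont == false every usable cell is kept individually. A toy version over a plain occupancy vector (it ignores the real ring-buffer wrap-around and the SWA/sequence checks):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// collect n_tokens free cells; if cont, they must form one contiguous run
static std::optional<std::vector<uint32_t>> find_slot(
        const std::vector<bool> & used, uint32_t n_tokens, bool cont) {
    std::vector<uint32_t> idxs;
    for (uint32_t i = 0; i < used.size() && idxs.size() < n_tokens; ++i) {
        if (!used[i]) {
            idxs.push_back(i);
        } else if (cont) {
            idxs.clear(); // an occupied cell breaks the run - start over
        }
    }
    if (idxs.size() < n_tokens) {
        return std::nullopt; // no suitable slot - empty result
    }
    return idxs;
}

int main() {
    const std::vector<bool> used = {false, true, false, false, true, false};
    const auto a = find_slot(used, 3, /*cont =*/ false); // ok: cells {0, 2, 3}
    const auto b = find_slot(used, 3, /*cont =*/ true);  // fails: longest free run is 2
    return (a && !b) ? 0 : 1;
}
```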
@@ -683,22 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
+    assert(ubatch.n_tokens == sinfo.idxs.size());
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+        const auto idx = sinfo.idxs.at(i);
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos   pos    = cells.pos_get(head_cur + i);
+        if (!cells.is_empty(idx)) {
+            assert(cells.seq_count(idx) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(idx);
+            const llama_pos   pos    = cells.pos_get(idx);
 
             seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-            cells.rm(head_cur + i);
+            cells.rm(idx);
         }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+        cells.pos_set(idx, ubatch.pos[i]);
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+            cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
     }
 
@@ -719,7 +761,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     }
 
     // move the head at the end of the slot
-    head = head_cur + ubatch.n_tokens;
+    head = sinfo.idxs.back() + 1;
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -772,47 +814,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];
 
+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const int64_t n_embd_v_gqa = v->ne[0];
     const int64_t n_tokens = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+
+    if (v_idxs && supports_set_rows) {
+        if (!v_trans) {
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
+        }
+
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+        // note: the V cache is transposed when not using flash attention
+        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+
+        // note: we can be more explicit here at the cost of extra cont
+        //       however, above we take advantage that a row of single element is always continuous regardless of the row stride
+        //v_cur = ggml_transpose(ctx, v_cur);
+        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+        // we broadcast the KV indices n_embd_v_gqa times
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
 
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
-                n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
     } else {
-        // note: the V cache is transposed when not using flash attention
-        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
-                (v->ne[1])*ggml_element_size(v),
-                (head_cur)*ggml_element_size(v));
-
         v_cur = ggml_transpose(ctx, v_cur);
+
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
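The primitive that enables the non-contiguous writes above is ggml_set_rows(ctx, dst, src, idxs): it scatters the rows of src into dst at the row indices supplied by the I64 tensor idxs, replacing one ggml_cpy() into a contiguous view per slot. A plain-loop model of its effect on row-major buffers (illustrative only, not the ggml kernel):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// dst has n_cols columns per row; row i of src is written to row idxs[i] of dst
static void set_rows_ref(std::vector<float> & dst, const std::vector<float> & src,
                         const std::vector<int64_t> & idxs, size_t n_cols) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        for (size_t c = 0; c < n_cols; ++c) {
            dst[idxs[i]*n_cols + c] = src[i*n_cols + c];
        }
    }
}

int main() {
    const size_t n_cols = 2;
    std::vector<float>   k(8*n_cols, 0.0f);     // K buffer: 8 cells
    std::vector<float>   k_cur  = {1, 1, 2, 2}; // rows for 2 new tokens
    std::vector<int64_t> k_idxs = {5, 2};       // non-contiguous cell indices

    set_rows_ref(k, k_cur, k_idxs, n_cols);     // token 0 -> cell 5, token 1 -> cell 2

    printf("cell 2: %g %g | cell 5: %g %g\n",
           k[2*n_cols], k[2*n_cols + 1], k[5*n_cols], k[5*n_cols + 1]);
    return 0;
}
```

For the transposed V cache, the diff reshapes the destination so that each row is a single element; the same row scatter then addresses individual elements, with the indices broadcast across n_embd_v_gqa.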
 
@@ -1552,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         ubatch.seq_id[i] = &dest_seq_id;
     }
 
-    const auto head_cur = find_slot(ubatch);
-    if (head_cur < 0) {
+    const auto sinfo = find_slot(ubatch, true);
+    if (sinfo.empty()) {
         LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
         return false;
     }
 
-    apply_ubatch(head_cur, ubatch);
+    apply_ubatch(sinfo, ubatch);
+
+    const auto head_cur = sinfo.head();
 
     // keep the head at the old position because we will read the KV data into it in state_read_data()
     head = head_cur;
@@ -1744,7 +1874,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
-    head = 0;
+
+    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1759,8 +1893,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
@@ -1768,7 +1902,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
         return false;
     }
 
@@ -1776,7 +1910,7 @@ bool llama_kv_cache_unified_context::next() {
 }
 
 bool llama_kv_cache_unified_context::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    assert(!llama_memory_status_is_fail(status));
 
     // no ubatches -> this is a KV cache update
     if (ubatches.empty()) {
@@ -1785,10 +1919,9 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }
 
-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
 
     n_kv = kv->get_n_kv();
-    head = heads[i_next];
 
     return true;
 }
@@ -1800,7 +1933,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return ubatches[i_next];
+    return ubatches[i_cur];
 }
 
 uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1815,18 +1948,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
package/src/llama.cpp/src/llama-kv-cache-unified.h

@@ -24,8 +24,6 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;
 
-    using ubatch_heads = std::vector<uint32_t>;
-
     struct defrag_info {
         bool empty() const {
             return ids.empty();
@@ -37,6 +35,32 @@ public:
         std::vector<uint32_t> ids;
     };
 
+    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+    // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        idx_vec_t idxs;
+
+        uint32_t head() const {
+            return idxs.at(0);
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        void clear() {
+            idxs.clear();
+        }
+
+        // TODO: implement
+        //std::vector<idx_vec_t> seq_idxs;
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
+
     llama_kv_cache_unified(
             const llama_model & model,
             layer_filter_cb  && filter,
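In effect, slot_info generalizes the old single head offset: token i of a ubatch is written to cell idxs[i], and head() reduces to the old behavior when the indices happen to be contiguous. A small usage sketch against a standalone copy of the struct:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// standalone copy of the struct from the diff, for illustration
struct slot_info {
    using idx_vec_t = std::vector<uint32_t>;

    idx_vec_t idxs;

    uint32_t head() const { return idxs.at(0); }
    bool empty() const { return idxs.empty(); }
    void clear() { idxs.clear(); }
};

int main() {
    slot_info sinfo;
    sinfo.idxs = {7, 9, 10}; // token[i] -> goes to cells[idxs[i]]

    for (size_t i = 0; i < sinfo.idxs.size(); ++i) {
        printf("token %zu -> cell %u\n", i, sinfo.idxs[i]);
    }

    // the legacy ggml_cpy() fallback only uses the first index as the slot head
    printf("head = %u\n", sinfo.head());
    return 0;
}
```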
@@ -102,30 +126,37 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
 
     //
     // preparation API
     //
 
-    // find places for the provided ubatches in the cache, returns the head locations
+    // find places for the provided ubatches in the cache, returns the slot infos
     // return empty vector on failure
-    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
     bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
 
-    // return the cell position where we can insert the ubatch
-    // return -1 on failure to find a contiguous slot of kv cells
-    int32_t find_slot(const llama_ubatch & ubatch) const;
+    // find a slot of kv cells that can hold the ubatch
+    // if cont == true, then the slot must be continuous
+    // return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
 
-    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
-    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
+    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
 
     //
-    // set_input API
+    // input API
     //
 
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
     void set_input_kq_mask  (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift  (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
157
188
  // SWA
158
189
  const uint32_t n_swa = 0;
159
190
 
191
+ // env: LLAMA_KV_CACHE_DEBUG
160
192
  int debug = 0;
161
193
 
194
+ // env: LLAMA_SET_ROWS (temporary)
195
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14285
196
+ int supports_set_rows = false;
197
+
162
198
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
163
199
 
164
200
  std::vector<ggml_context_ptr> ctxs;
@@ -211,8 +247,8 @@ private:
 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
-    using defrag_info  = llama_kv_cache_unified::defrag_info;
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using defrag_info     = llama_kv_cache_unified::defrag_info;
 
     // used for errors
     llama_kv_cache_unified_context(llama_memory_status status);
@@ -231,7 +267,7 @@ public:
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
-            ubatch_heads heads,
+            slot_info_vec_t sinfos,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_context();
@@ -257,11 +293,16 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
+
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
 
-    void set_input_k_shift(ggml_tensor * dst) const;
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
@@ -283,10 +324,10 @@ private:
     // batch processing context
     //
 
-    // the index of the next ubatch to process
-    size_t i_next = 0;
+    // the index of the cur ubatch to process
+    size_t i_cur = 0;
 
-    ubatch_heads heads;
+    slot_info_vec_t sinfos;
 
     std::vector<llama_ubatch> ubatches;
 
@@ -297,7 +338,4 @@ private:
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // as the cache gets filled, the benefit from this heuristic disappears
     int32_t n_kv;
-
-    // the beginning of the current slot in which the ubatch will be inserted
-    int32_t head;
 };