cui-llama.rn 1.1.0 → 1.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/llama.cpp CHANGED
@@ -2527,10 +2527,29 @@ struct llama_layer {
2527
2527
  struct lm_ggml_tensor * ffn_down_scale;
2528
2528
  };
2529
2529
 
2530
+ // very similar to llama_batch,
2531
+ // but has more metadata about sequences
2532
+ struct llama_ubatch {
2533
+ bool equal_seqs;
2534
+ // TODO: whole_seqs for embeddings?
2535
+
2536
+ uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
2537
+ uint32_t n_seq_tokens; // tokens per sequence
2538
+ uint32_t n_seqs;
2539
+
2540
+ llama_token * token; // [n_tokens]
2541
+ float * embd; // [n_embd, n_tokens]
2542
+ llama_pos * pos; // [n_tokens]
2543
+ int32_t * n_seq_id; // [n_seqs]
2544
+ llama_seq_id ** seq_id; // [n_seqs]
2545
+ int8_t * output; // [n_tokens]
2546
+ };
2547
+
2530
2548
  struct llama_kv_cell {
2531
2549
  llama_pos pos = -1;
2532
2550
  llama_pos delta = 0;
2533
- int32_t src = 0; // used by recurrent state models to copy states
2551
+ int32_t src = -1; // used by recurrent state models to copy states
2552
+ int32_t tail = -1;
2534
2553
 
2535
2554
  std::set<llama_seq_id> seq_id;
2536
2555
 
@@ -2551,7 +2570,6 @@ struct llama_kv_cell {
2551
2570
  struct llama_kv_cache {
2552
2571
  bool has_shift = false;
2553
2572
  bool do_defrag = false;
2554
- bool do_copy = false;
2555
2573
  bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
2556
2574
  bool v_trans = true; // the value tensor is transposed
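
The new llama_ubatch above groups tokens so that every sequence in it contributes the same number of tokens, which is what the annotated invariant n_tokens = n_seq_tokens * n_seqs expresses; the new tail field on llama_kv_cell lets a sequence id point at the cell holding its latest recurrent state. A rough, standalone illustration of the ubatch indexing only (toy types, not the real structs):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // toy stand-in for the ubatch layout described in the hunk above
    struct toy_ubatch {
        uint32_t n_seq_tokens;   // tokens per sequence
        uint32_t n_seqs;         // number of sequences
        std::vector<int> token;  // flattened [n_seq_tokens * n_seqs]
    };

    int main() {
        toy_ubatch ub;
        ub.n_seq_tokens = 3;
        ub.n_seqs       = 2;
        ub.token = {10, 11, 12,   20, 21, 22}; // sequence 0, then sequence 1

        const uint32_t n_tokens = ub.n_seq_tokens * ub.n_seqs; // the invariant
        for (uint32_t s = 0; s < ub.n_seqs; ++s) {
            for (uint32_t i = 0; i < ub.n_seq_tokens; ++i) {
                // token i of sequence s lives at s*n_seq_tokens + i
                std::printf("seq %u, tok %u -> %d\n", s, i, ub.token[s*ub.n_seq_tokens + i]);
            }
        }
        return n_tokens == ub.token.size() ? 0 : 1;
    }
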
2557
2575
 
@@ -2714,6 +2732,340 @@ struct llama_model {
2714
2732
  }
2715
2733
  };
2716
2734
 
2735
+ struct llama_sbatch_seq {
2736
+ int32_t n_seq_id;
2737
+ llama_seq_id * seq_id;
2738
+ size_t offset;
2739
+ size_t length;
2740
+
2741
+ // helper for smoother batch API transition -- can be deprecated in the future
2742
+ llama_seq_id all_seq_id; // used if seq_id == NULL
2743
+ };
2744
+
2745
+ // sequence-length-aware batch splitting
2746
+ struct llama_sbatch {
2747
+ // tokens left in this batch
2748
+ size_t n_tokens;
2749
+
2750
+ size_t n_embd;
2751
+
2752
+ bool logits_all; // TODO: remove once lctx.logits_all is removed too
2753
+
2754
+ // sorted indices into the batch
2755
+ std::vector<size_t> ids;
2756
+ // batch indices of the output
2757
+ std::vector<size_t> out_ids;
2758
+ std::vector<llama_sbatch_seq> seq;
2759
+ const llama_batch * batch = nullptr;
2760
+
2761
+ // buffers for the ubatch
2762
+ std::vector<llama_token> ubatch_token;
2763
+ std::vector<float> ubatch_embd;
2764
+ std::vector<llama_pos> ubatch_pos;
2765
+ std::vector<int32_t> ubatch_n_seq_id;
2766
+ std::vector<llama_seq_id *> ubatch_seq_id;
2767
+ std::vector<int8_t> ubatch_output;
2768
+
2769
+ llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
2770
+ // clear empty sequences
2771
+ // the previous ubatch is assumed to be gone,
2772
+ // so nothing should refer to values in these sequences anymore.
2773
+ for (size_t i = seq.size(); i-- > 0;) {
2774
+ if (seq[i].length == 0) {
2775
+ seq.pop_back();
2776
+ } else {
2777
+ break;
2778
+ }
2779
+ }
2780
+ ubatch_token.resize(!has_embd ? n_ubatch : 0);
2781
+ ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
2782
+ ubatch_pos.resize(n_ubatch);
2783
+ ubatch_n_seq_id.resize(n_ubatch);
2784
+ ubatch_seq_id.resize(n_ubatch);
2785
+ ubatch_output.resize(n_ubatch);
2786
+ llama_ubatch ubatch = {
2787
+ /*equal_seqs =*/ true,
2788
+ /*n_tokens =*/ 0,
2789
+ /*n_seq_tokens =*/ 0,
2790
+ /*n_seqs =*/ 0,
2791
+ /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
2792
+ /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
2793
+ /*pos =*/ ubatch_pos.data(),
2794
+ /*n_seq_id =*/ ubatch_n_seq_id.data(),
2795
+ /*seq_id =*/ ubatch_seq_id.data(),
2796
+ /*output =*/ ubatch_output.data(),
2797
+ };
2798
+ return ubatch;
2799
+ }
2800
+
2801
+ void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
2802
+ LM_GGML_ASSERT(batch != nullptr);
2803
+ LM_GGML_ASSERT(length <= seq.length);
2804
+ // Can only add sequences of equal lengths to a batch,
2805
+ // otherwise it isn't clear to which sequence a token belongs
2806
+ LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
2807
+ LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
2808
+ // NOTE: loops are separated for cache-friendliness
2809
+ if (batch->token) {
2810
+ if (ubatch.equal_seqs) {
2811
+ for (size_t i = 0; i < length; ++i) {
2812
+ ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
2813
+ }
2814
+ } else {
2815
+ // simple split
2816
+ ubatch.token = batch->token + seq.offset;
2817
+ }
2818
+ } else {
2819
+ ubatch.token = nullptr;
2820
+ }
2821
+ if (batch->embd) {
2822
+ if (ubatch.equal_seqs) {
2823
+ for (size_t i = 0; i < length; ++i) {
2824
+ memcpy(
2825
+ ubatch.embd + n_embd * (ubatch.n_tokens + i),
2826
+ batch->embd + n_embd * ids[seq.offset + i],
2827
+ n_embd * sizeof(float)
2828
+ );
2829
+ }
2830
+ } else {
2831
+ // simple split
2832
+ ubatch.embd = batch->embd + (n_embd * seq.offset);
2833
+ }
2834
+ } else {
2835
+ ubatch.embd = nullptr;
2836
+ }
2837
+ // from here on, the else branches are deprecated;
2838
+ // they are helpers for smoother batch API transition
2839
+ if (batch->pos) {
2840
+ if (ubatch.equal_seqs) {
2841
+ for (size_t i = 0; i < length; ++i) {
2842
+ ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
2843
+ }
2844
+ } else {
2845
+ // simple split
2846
+ ubatch.pos = batch->pos + seq.offset;
2847
+ }
2848
+ } else {
2849
+ for (size_t i = 0; i < length; ++i) {
2850
+ llama_pos bi = ids[seq.offset + i];
2851
+ ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
2852
+ }
2853
+ }
2854
+ if (ubatch.equal_seqs) {
2855
+ ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
2856
+ if (seq.seq_id) {
2857
+ ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
2858
+ } else {
2859
+ LM_GGML_ASSERT(seq.n_seq_id == 1);
2860
+ ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
2861
+ }
2862
+ } else {
2863
+ // simple split
2864
+ if (batch->n_seq_id) {
2865
+ for (size_t i = 0; i < length; ++i) {
2866
+ ubatch.n_seq_id = batch->n_seq_id + seq.offset;
2867
+ }
2868
+ } else {
2869
+ for (size_t i = 0; i < length; ++i) {
2870
+ ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
2871
+ }
2872
+ }
2873
+ if (batch->seq_id) {
2874
+ for (size_t i = 0; i < length; ++i) {
2875
+ ubatch.seq_id = batch->seq_id + seq.offset;
2876
+ }
2877
+ } else {
2878
+ for (size_t i = 0; i < length; ++i) {
2879
+ ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
2880
+ }
2881
+ }
2882
+ }
2883
+ if (logits_all) {
2884
+ for (size_t i = 0; i < length; ++i) {
2885
+ ubatch.output[ubatch.n_tokens + i] = 1;
2886
+ out_ids.push_back(ids[seq.offset + i]);
2887
+ }
2888
+ } else if (batch->logits) {
2889
+ if (ubatch.equal_seqs) {
2890
+ for (size_t i = 0; i < length; ++i) {
2891
+ size_t id = ids[seq.offset + i];
2892
+ int8_t is_output = batch->logits[id];
2893
+ ubatch.output[ubatch.n_tokens + i] = is_output;
2894
+ if (is_output) { out_ids.push_back(id); }
2895
+ }
2896
+ } else {
2897
+ // simple split
2898
+ ubatch.output = batch->logits + seq.offset;
2899
+ for (size_t i = 0; i < length; ++i) {
2900
+ if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
2901
+ }
2902
+ }
2903
+ } else {
2904
+ // only get last output
2905
+ for (size_t i = 0; i < length; ++i) {
2906
+ size_t id = ids[seq.offset + i];
2907
+ int8_t is_last = id == ids.size() - 1;
2908
+ ubatch.output[ubatch.n_tokens + i] = is_last;
2909
+ if (is_last) { out_ids.push_back(id); }
2910
+ }
2911
+ }
2912
+ if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
2913
+ ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
2914
+ }
2915
+ ubatch.n_tokens += length;
2916
+ ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
2917
+ seq.offset += length;
2918
+ seq.length -= length;
2919
+ n_tokens -= length;
2920
+ LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
2921
+ }
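
add_seq_to_ubatch fills the ubatch either by copying tokens through the sorted ids indirection (when equal_seqs is set) or, for a simple split, by pointing directly into the original batch at an offset. A minimal sketch of the gather path, using made-up data rather than a real llama_batch:

    #include <cstdio>
    #include <vector>

    // toy illustration of the "gather through sorted ids" pattern used above
    int main() {
        std::vector<int>    batch_token = {100, 101, 102, 103, 104};
        std::vector<size_t> ids         = {4, 2, 0, 1, 3}; // a made-up sorted order

        size_t offset = 1, length = 3; // a slice of one toy sequence
        std::vector<int> ubatch_token(length);
        for (size_t i = 0; i < length; ++i) {
            // equal_seqs path: copy batch tokens in the sorted order
            ubatch_token[i] = batch_token[ids[offset + i]];
        }
        for (int t : ubatch_token) std::printf("%d ", t); // prints: 102 100 101
        std::printf("\n");
    }
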
2922
+
2923
+ // simple split, unknown number of sequences of unequal lengths
2924
+ llama_ubatch split_simple(size_t n_ubatch) {
2925
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
2926
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
2927
+ ubatch.equal_seqs = false;
2928
+ if (!seq.empty()) {
2929
+ llama_sbatch_seq & s = seq[0];
2930
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
2931
+ LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
2932
+ add_seq_to_ubatch(ubatch, s, length);
2933
+ }
2934
+ return ubatch;
2935
+ }
2936
+
2937
+ // make batches of equal-length sequences
2938
+ llama_ubatch split_equal(size_t n_ubatch) {
2939
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
2940
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
2941
+ if (!seq.empty()) {
2942
+ size_t length = 0;
2943
+ size_t n_tokens_in_ubatch = 0;
2944
+ LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
2945
+ // smallest first, because it's easier to split this way;
2946
+ // starting from the end to pop in constant time.
2947
+ for (size_t i = seq.size(); i-- > 0;) {
2948
+ llama_sbatch_seq & s = seq[i];
2949
+ LM_GGML_ASSERT(s.length > 0);
2950
+ if (length == 0) {
2951
+ length = s.length < n_ubatch ? s.length : n_ubatch;
2952
+ }
2953
+ add_seq_to_ubatch(ubatch, s, length);
2954
+ n_tokens_in_ubatch += length;
2955
+ // shared prompts can't be mixed with any of their sequences,
2956
+ // so it's safer to compute them in their own ubatch
2957
+ if (s.n_seq_id > 1) { break; }
2958
+ // stop when there isn't enough space for another sequence
2959
+ if (length + n_tokens_in_ubatch > n_ubatch) { break; }
2960
+ }
2961
+ }
2962
+ return ubatch;
2963
+ }
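
split_equal pulls sequences from the back of the list (the shortest ones, given how from_batch sorts below), fixes the per-sequence length from the first one it takes, and stops early for shared prompts or when the token budget would be exceeded. A simplified, self-contained sketch of that loop with toy sequences:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct toy_seq { int id; size_t length; };

    int main() {
        // sequences sorted by length descending, as from_batch() does,
        // so the shortest ones sit at the back and can be popped cheaply
        std::vector<toy_seq> seqs = { {0, 8}, {1, 4}, {2, 4}, {3, 2} };

        const size_t n_ubatch = 8; // token budget per ubatch
        size_t used = 0, length = 0;

        for (size_t i = seqs.size(); i-- > 0;) {
            toy_seq & s = seqs[i];
            if (length == 0) {
                length = std::min(s.length, n_ubatch); // every seq in this ubatch uses this length
            }
            std::printf("add %zu tokens of seq %d\n", length, s.id);
            used += length;
            if (length + used > n_ubatch) break; // no room for another sequence
        }
        std::printf("ubatch holds %zu tokens\n", used);
    }
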
2964
+
2965
+ // sequence-wise split
2966
+ llama_ubatch split_seq(size_t n_ubatch) {
2967
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
2968
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
2969
+ if (!seq.empty()) {
2970
+ llama_sbatch_seq & s = seq[seq.size() - 1];
2971
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
2972
+ LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
2973
+ add_seq_to_ubatch(ubatch, s, length);
2974
+ }
2975
+ return ubatch;
2976
+ }
2977
+
2978
+ void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
2979
+ LM_GGML_ASSERT(batch.n_tokens >= 0);
2980
+ this->batch = &batch;
2981
+ this->n_embd = n_embd;
2982
+ this->logits_all = logits_all;
2983
+
2984
+ n_tokens = batch.n_tokens;
2985
+ ids.resize(n_tokens);
2986
+ out_ids.clear();
2987
+ // TODO: reserve out_ids and seq
2988
+
2989
+ for (size_t i = 0; i < n_tokens; ++i) {
2990
+ ids[i] = i;
2991
+ }
2992
+ if (simple_split) {
2993
+ seq.resize(1);
2994
+ llama_sbatch_seq & s = seq[0];
2995
+ s.n_seq_id = 0;
2996
+ s.seq_id = nullptr;
2997
+ s.offset = 0;
2998
+ s.length = n_tokens;
2999
+ s.all_seq_id = batch.all_seq_id;
3000
+ return;
3001
+ }
3002
+ std::sort(ids.begin(), ids.end(),
3003
+ [&batch](size_t a, size_t b) {
3004
+ int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
3005
+ int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
3006
+ // sort by seq_id, then by pos
3007
+ if (n_seq_a == n_seq_b) {
3008
+ if (batch.seq_id) {
3009
+ for (int32_t i = 0; i < n_seq_a; ++i) {
3010
+ llama_seq_id seq_id_a = batch.seq_id[a][i];
3011
+ llama_seq_id seq_id_b = batch.seq_id[b][i];
3012
+ // smaller seq_ids go first
3013
+ if (seq_id_a != seq_id_b) {
3014
+ return seq_id_a < seq_id_b;
3015
+ }
3016
+ }
3017
+ }
3018
+ // when all else is equal, sort by pos
3019
+ if (batch.pos) {
3020
+ return batch.pos[a] < batch.pos[b];
3021
+ }
3022
+ // no pos, sort by id (assuming batch.all_pos_1 is positive)
3023
+ return a < b;
3024
+ }
3025
+ // shared prompts go first
3026
+ return n_seq_a > n_seq_b;
3027
+ }
3028
+ );
3029
+ // init seq
3030
+ llama_sbatch_seq * last_seq = nullptr;
3031
+
3032
+ if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
3033
+ for (size_t i = 0; i < n_tokens; ++i) {
3034
+ const size_t bi = ids[i];
3035
+ const int32_t n_seqs = batch.n_seq_id[bi];
3036
+ llama_seq_id * seq_ids = batch.seq_id[bi];
3037
+ if (last_seq != nullptr) {
3038
+ bool same = n_seqs == last_seq->n_seq_id;
3039
+ for (int32_t j = 0; same && j < n_seqs; ++j) {
3040
+ if (seq_ids[j] != last_seq->seq_id[j]) {
3041
+ same = false;
3042
+ }
3043
+ }
3044
+ if (same) {
3045
+ last_seq->length += 1;
3046
+ continue;
3047
+ }
3048
+ }
3049
+ llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
3050
+ seq.push_back(new_seq);
3051
+ last_seq = &seq.back();
3052
+ }
3053
+ } else {
3054
+ llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
3055
+ seq.push_back(new_seq);
3056
+ }
3057
+ // keep shared prompts first at the end, then sort by length descending.
3058
+ std::sort(seq.begin(), seq.end(),
3059
+ [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
3060
+ if (a.n_seq_id == b.n_seq_id) {
3061
+ return a.length > b.length;
3062
+ }
3063
+ return a.n_seq_id < b.n_seq_id;
3064
+ }
3065
+ );
3066
+ }
3067
+ };
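
The ordering produced by from_batch drives all of the splitting above: tokens of shared prompts (larger n_seq_id) come first, then tokens are grouped by seq_id, then ordered by position. A standalone comparator on toy records showing the same ordering idea (not the real llama_batch fields):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // toy token record standing in for (n_seq_id, seq_id[0], pos) of a batch entry
    struct toy_tok { int n_seq_id; int seq_id; int pos; };

    int main() {
        std::vector<toy_tok> toks = {
            {1, 1, 0}, {2, 0, 1}, {1, 0, 0}, {2, 0, 0}, {1, 0, 1},
        };
        // same ordering idea as from_batch(): shared prompts (more seq_ids) first,
        // then smaller seq_id, then smaller pos
        std::sort(toks.begin(), toks.end(), [](const toy_tok & a, const toy_tok & b) {
            if (a.n_seq_id != b.n_seq_id) return a.n_seq_id > b.n_seq_id;
            if (a.seq_id   != b.seq_id)   return a.seq_id   < b.seq_id;
            return a.pos < b.pos;
        });
        for (const toy_tok & t : toks) {
            std::printf("n_seq_id=%d seq_id=%d pos=%d\n", t.n_seq_id, t.seq_id, t.pos);
        }
    }
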
3068
+
2717
3069
  struct llama_context {
2718
3070
  llama_context(const llama_model & model)
2719
3071
  : model(model)
@@ -2735,6 +3087,7 @@ struct llama_context {
2735
3087
 
2736
3088
  struct llama_cparams cparams;
2737
3089
  struct llama_sampling sampling;
3090
+ struct llama_sbatch sbatch;
2738
3091
  struct llama_kv_cache kv_self;
2739
3092
  struct llama_control_vector cvec;
2740
3093
 
@@ -2995,8 +3348,7 @@ static bool llama_kv_cache_init(
2995
3348
 
2996
3349
  cache.has_shift = false;
2997
3350
 
2998
- // TODO: find a nicer way to add other recurrent model architectures
2999
- cache.recurrent = model.arch == LLM_ARCH_MAMBA;
3351
+ cache.recurrent = llama_model_is_recurrent(&model);
3000
3352
  cache.v_trans = !cache.recurrent && !cparams.flash_attn;
3001
3353
 
3002
3354
  cache.head = 0;
@@ -3009,13 +3361,6 @@ static bool llama_kv_cache_init(
3009
3361
  cache.cells.clear();
3010
3362
  cache.cells.resize(kv_size);
3011
3363
 
3012
- if (cache.recurrent) {
3013
- // init state copy sources
3014
- for (uint32_t i = 0; i < cache.size; ++i) {
3015
- cache.cells[i].src = i;
3016
- }
3017
- }
3018
-
3019
3364
  // count used buffer types
3020
3365
  std::map<lm_ggml_backend_buffer_type_t, int> buft_layer_count;
3021
3366
  if (offload) {
@@ -3083,46 +3428,162 @@ static bool llama_kv_cache_init(
3083
3428
  // to the first cell of the slot.
3084
3429
  static bool llama_kv_cache_find_slot(
3085
3430
  struct llama_kv_cache & cache,
3086
- const struct llama_batch & batch) {
3431
+ const struct llama_ubatch & batch) {
3087
3432
  const uint32_t n_tokens = batch.n_tokens;
3433
+ const uint32_t n_seqs = batch.n_seqs;
3434
+ const uint32_t n_seq_tokens = batch.n_seq_tokens;
3088
3435
 
3089
3436
  if (cache.recurrent) {
3090
3437
  // For recurrent state architectures (like Mamba),
3091
- // each KV cache cell can store the state for a whole sequence.
3092
-
3093
- llama_seq_id min = cache.size - 1;
3094
- llama_seq_id max = 0;
3095
-
3096
- for (uint32_t i = 0; i < n_tokens; ++i) {
3097
- for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
3098
- llama_seq_id seq_id = batch.seq_id[i][j];
3099
- // make sure it's a valid seq_id
3100
- if ((uint32_t) seq_id < cache.size) {
3101
- if (seq_id > max) {
3102
- max = seq_id;
3103
- }
3104
- if (seq_id < min) {
3105
- min = seq_id;
3438
+ // each cache cell can store the state for a whole sequence.
3439
+ // A slot should always be contiguous.
+ // A slot should always be contiguous.
3440
+
3441
+ // can only process batches with an equal number of new tokens in each sequence
3442
+ LM_GGML_ASSERT(batch.equal_seqs);
3443
+
3444
+ int32_t min = cache.size - 1;
3445
+ int32_t max = 0;
3446
+
3447
+ // everything should fit if all seq_ids are smaller than the max
3448
+ for (uint32_t s = 0; s < n_seqs; ++s) {
3449
+ const uint32_t n_seq_id = batch.n_seq_id[s];
3450
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
3451
+ const llama_seq_id seq_id = batch.seq_id[s][j];
3452
+
3453
+ if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
3454
+ // too big seq_id
3455
+ // TODO: would it be possible to resize the cache instead?
3456
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
3457
+ return false;
3458
+ }
3459
+ if (j > 0) {
3460
+ llama_kv_cell & seq = cache.cells[seq_id];
3461
+ if (seq.tail >= 0) {
3462
+ llama_kv_cell & cell = cache.cells[seq.tail];
3463
+ // clear cells from seq_ids that become shared
3464
+ // (should not normally happen, but let's handle it anyway)
3465
+ cell.seq_id.erase(seq_id);
3466
+ seq.tail = -1;
3467
+ if (cell.seq_id.empty()) {
3468
+ cell.pos = -1;
3469
+ cell.src = -1;
3470
+ cache.used -= 1;
3471
+ }
3106
3472
  }
3107
- // Assuming the tokens are in-order
3108
- if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
3109
- // What should happen when the pos backtracks or skips a value?
3110
- // Clearing the state mid-batch would require special-casing which isn't done.
3111
- LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
3112
- __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
3473
+ }
3474
+ }
3475
+ }
3476
+
3477
+ #ifndef NDEBUG
3478
+ {
3479
+ std::vector<int32_t> tails_verif;
3480
+ tails_verif.assign(cache.size, -1);
3481
+ for (uint32_t i = 0; i < cache.size; ++i) {
3482
+ llama_kv_cell & cell = cache.cells[i];
3483
+ for (llama_seq_id seq_id : cell.seq_id) {
3484
+ if (tails_verif[seq_id] != -1) {
3485
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
3113
3486
  }
3114
- if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
3115
- cache.used += 1;
3487
+ tails_verif[seq_id] = i;
3488
+ }
3489
+ }
3490
+ for (uint32_t i = 0; i < cache.size; ++i) {
3491
+ if (tails_verif[i] != cache.cells[i].tail) {
3492
+ LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
3493
+ }
3494
+ }
3495
+ }
3496
+ #endif
3497
+
3498
+ // find next empty cell
3499
+ uint32_t next_empty_cell = cache.head;
3500
+
3501
+ for (uint32_t i = 0; i < cache.size; ++i) {
3502
+ if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
3503
+ llama_kv_cell & cell = cache.cells[next_empty_cell];
3504
+ if (cell.is_empty()) { break; }
3505
+ next_empty_cell += 1;
3506
+ }
3507
+
3508
+ // find usable cell range
3509
+ for (uint32_t s = 0; s < n_seqs; ++s) {
3510
+ const llama_seq_id seq_id = batch.seq_id[s][0];
3511
+ llama_kv_cell & seq_meta = cache.cells[seq_id];
3512
+ bool has_cell = false;
3513
+ if (seq_meta.tail >= 0) {
3514
+ llama_kv_cell & cell = cache.cells[seq_meta.tail];
3515
+ LM_GGML_ASSERT(cell.has_seq_id(seq_id));
3516
+ // does this seq_id "own" the cell?
3517
+ if (cell.seq_id.size() == 1) { has_cell = true; }
3518
+ }
3519
+ if (!has_cell) {
3520
+ llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
3521
+ LM_GGML_ASSERT(empty_cell.is_empty());
3522
+ // copy old tail into the empty cell
3523
+ if (seq_meta.tail >= 0) {
3524
+ llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
3525
+ empty_cell.pos = orig_cell.pos;
3526
+ empty_cell.src = orig_cell.src;
3527
+ orig_cell.seq_id.erase(seq_id);
3528
+ empty_cell.seq_id.insert(seq_id); // will be overwritten
3529
+ }
3530
+ seq_meta.tail = next_empty_cell;
3531
+ // find next empty cell
3532
+ if (s + 1 < n_seqs) {
3533
+ next_empty_cell += 1;
3534
+ for (uint32_t i = 0; i < cache.size; ++i) {
3535
+ if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
3536
+ llama_kv_cell & cell = cache.cells[next_empty_cell];
3537
+ if (cell.is_empty()) { break; }
3538
+ next_empty_cell += 1;
3116
3539
  }
3117
- cache.cells[seq_id].pos = batch.pos[i];
3118
- // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
3119
- } else {
3120
- // too big seq_id
3121
- // TODO: would it be possible to resize the KV cache size instead?
3122
- LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
3123
- return false;
3124
3540
  }
3125
3541
  }
3542
+ if (min > seq_meta.tail) { min = seq_meta.tail; }
3543
+ if (max < seq_meta.tail) { max = seq_meta.tail; }
3544
+ }
3545
+
3546
+ // gather and re-order
3547
+ for (uint32_t s = 0; s < n_seqs; ++s) {
3548
+ int32_t dst_id = s + min;
3549
+ int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
3550
+ if (dst_id != src_id) {
3551
+ llama_kv_cell & dst_cell = cache.cells[dst_id];
3552
+ llama_kv_cell & src_cell = cache.cells[src_id];
3553
+
3554
+ std::swap(dst_cell.pos, src_cell.pos);
3555
+ std::swap(dst_cell.src, src_cell.src);
3556
+ std::swap(dst_cell.seq_id, src_cell.seq_id);
3557
+
3558
+ // swap tails (assuming they NEVER overlap)
3559
+ for (const llama_seq_id seq_id : src_cell.seq_id) {
3560
+ cache.cells[seq_id].tail = src_id;
3561
+ }
3562
+ for (const llama_seq_id seq_id : dst_cell.seq_id) {
3563
+ cache.cells[seq_id].tail = dst_id;
3564
+ }
3565
+ }
3566
+ }
3567
+
3568
+ // update the pos of the used seqs
3569
+ for (uint32_t s = 0; s < n_seqs; ++s) {
3570
+ const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
3571
+ int32_t cell_id = s + min;
3572
+ llama_kv_cell & cell = cache.cells[cell_id];
3573
+
3574
+ if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
3575
+ // What should happen when the pos backtracks or skips a value?
3576
+ // Clearing the state mid-batch would require special-casing which isn't done.
3577
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
3578
+ __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
3579
+ }
3580
+ cell.pos = last_pos;
3581
+ cell.seq_id.clear();
3582
+ for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
3583
+ const llama_seq_id seq_id = batch.seq_id[s][j];
3584
+ cell.seq_id.insert(seq_id);
3585
+ cache.cells[seq_id].tail = cell_id;
3586
+ }
3126
3587
  }
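
For recurrent models each sequence now occupies a single cell tracked through its tail, and new sequences are placed by scanning for the next empty cell starting at cache.head, wrapping around the cache. A tiny sketch of that wrap-around search on a toy occupancy array:

    #include <cstdio>
    #include <vector>

    // toy ring search for the next empty cell, mirroring the loop above
    int main() {
        std::vector<bool> is_empty = {false, false, true, false, true}; // cell occupancy
        const size_t size = is_empty.size();

        size_t head = 3; // start searching from the current head
        size_t next_empty_cell = head;
        for (size_t i = 0; i < size; ++i) {
            if (next_empty_cell >= size) next_empty_cell -= size; // wrap around
            if (is_empty[next_empty_cell]) break;
            next_empty_cell += 1;
        }
        std::printf("next empty cell: %zu\n", next_empty_cell); // prints: 4
    }
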
3127
3588
 
3128
3589
  // allow getting the range of used cells, from head to head + n
@@ -3130,7 +3591,7 @@ static bool llama_kv_cache_find_slot(
3130
3591
  cache.n = max - min + 1;
3131
3592
 
3132
3593
  // sanity check
3133
- return max >= min;
3594
+ return cache.n >= n_seqs;
3134
3595
  }
3135
3596
  // otherwise, one cell per token.
3136
3597
 
@@ -3168,11 +3629,14 @@ static bool llama_kv_cache_find_slot(
3168
3629
  }
3169
3630
  }
3170
3631
 
3171
- for (uint32_t i = 0; i < n_tokens; i++) {
3172
- cache.cells[cache.head + i].pos = batch.pos[i];
3632
+ for (uint32_t s = 0; s < n_seqs; s++) {
3633
+ for (uint32_t i = 0; i < n_seq_tokens; ++i) {
3634
+ uint32_t k = s*n_seq_tokens + i;
3635
+ cache.cells[cache.head + k].pos = batch.pos[k];
3173
3636
 
3174
- for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
3175
- cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
3637
+ for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
3638
+ cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
3639
+ }
3176
3640
  }
3177
3641
  }
3178
3642
 
@@ -3198,6 +3662,8 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
3198
3662
  for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
3199
3663
  cache.cells[i].pos = -1;
3200
3664
  cache.cells[i].seq_id.clear();
3665
+ cache.cells[i].src = -1;
3666
+ cache.cells[i].tail = -1;
3201
3667
  }
3202
3668
  cache.head = 0;
3203
3669
  cache.used = 0;
@@ -3224,9 +3690,16 @@ static bool llama_kv_cache_seq_rm(
3224
3690
  return false;
3225
3691
  }
3226
3692
  if (0 <= seq_id) {
3227
- // partial intersection is invalid
3228
- if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
3229
- return false;
3693
+ int32_t & tail_id = cache.cells[seq_id].tail;
3694
+ if (tail_id >= 0) {
3695
+ const llama_kv_cell & cell = cache.cells[tail_id];
3696
+ // partial intersection is invalid
3697
+ if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
3698
+ return false;
3699
+ }
3700
+ if (p0 <= cell.pos && p1 < cell.pos) {
3701
+ tail_id = -1;
3702
+ }
3230
3703
  }
3231
3704
  } else {
3232
3705
  // seq_id is negative, then the range should include everything or nothing
@@ -3250,6 +3723,7 @@ static bool llama_kv_cache_seq_rm(
3250
3723
  if (cache.cells[i].pos >= 0) cache.used--;
3251
3724
 
3252
3725
  cache.cells[i].pos = -1;
3726
+ cache.cells[i].src = -1;
3253
3727
  if (new_head == cache.size) new_head = i;
3254
3728
  }
3255
3729
  }
@@ -3272,23 +3746,29 @@ static void llama_kv_cache_seq_cp(
3272
3746
 
3273
3747
  if (cache.recurrent) {
3274
3748
  if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
3275
- seq_id_src = cache.cells[seq_id_src].src;
3276
- LM_GGML_ASSERT((uint32_t) seq_id_src < cache.size);
3277
- // intent to "copy from"
3278
- // supports copy chains thanks to taking the source of the source
3279
- cache.cells[seq_id_dst].src = seq_id_src;
3280
-
3281
- // preserve the "keep or clear" status of the copied sequence
3282
- if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
3283
- cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
3284
- } else {
3285
- cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
3749
+ llama_kv_cell & tail_src = cache.cells[seq_id_src];
3750
+ llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
3751
+ if (tail_dst.tail >= 0) {
3752
+ // clear destination seq_id if it wasn't empty
3753
+ llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
3754
+
3755
+ cell_dst.seq_id.erase(seq_id_dst);
3756
+ tail_dst.tail = -1;
3757
+ if (cell_dst.seq_id.empty()) {
3758
+ cell_dst.pos = -1;
3759
+ cell_dst.delta = -1;
3760
+ cell_dst.src = -1;
3761
+ cache.used -= 1;
3762
+ }
3286
3763
  }
3764
+ if (tail_src.tail >= 0) {
3765
+ llama_kv_cell & cell_src = cache.cells[tail_src.tail];
3287
3766
 
3288
- cache.do_copy = true;
3289
-
3290
- cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
3767
+ cell_src.seq_id.insert(seq_id_dst);
3768
+ tail_dst.tail = tail_src.tail;
3769
+ }
3291
3770
  }
3771
+
3292
3772
  return;
3293
3773
  }
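
With the do_copy mechanism gone, copying a sequence in a recurrent cache just makes the destination share the source's tail cell: the destination seq_id is inserted into that cell and its own tail pointer is updated, with no state actually duplicated. A toy sketch of that bookkeeping using plain containers:

    #include <cstdio>
    #include <set>
    #include <vector>

    // toy cells: each holds the set of seq_ids sharing its state, and
    // cells[seq_id].tail stores which cell holds that sequence's latest state
    struct toy_cell { std::set<int> seq_id; int tail = -1; };

    int main() {
        std::vector<toy_cell> cells(4);
        // sequence 0 currently owns cell 2
        cells[2].seq_id.insert(0);
        cells[0].tail = 2;

        // "copy" sequence 0 into sequence 1: no state is duplicated,
        // sequence 1 just starts sharing sequence 0's tail cell
        int src = 0, dst = 1;
        if (cells[src].tail >= 0) {
            cells[cells[src].tail].seq_id.insert(dst);
            cells[dst].tail = cells[src].tail;
        }
        std::printf("seq %d tail = %d, cell shared by %zu seqs\n",
                    dst, cells[dst].tail, cells[cells[dst].tail].seq_id.size());
    }
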
3294
3774
  // otherwise, this is the KV cache of a Transformer-like model
@@ -3306,9 +3786,13 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
3306
3786
  uint32_t new_head = cache.size;
3307
3787
 
3308
3788
  for (uint32_t i = 0; i < cache.size; ++i) {
3789
+ if (cache.recurrent && (llama_seq_id) i != seq_id) {
3790
+ cache.cells[i].tail = -1;
3791
+ }
3309
3792
  if (!cache.cells[i].has_seq_id(seq_id)) {
3310
3793
  if (cache.cells[i].pos >= 0) cache.used--;
3311
3794
  cache.cells[i].pos = -1;
3795
+ cache.cells[i].src = -1;
3312
3796
  cache.cells[i].seq_id.clear();
3313
3797
  if (new_head == cache.size) new_head = i;
3314
3798
  } else {
@@ -3337,9 +3821,12 @@ static void llama_kv_cache_seq_add(
3337
3821
  if (cache.recurrent) {
3338
3822
  // for Mamba-like models, only the pos needs to be shifted
3339
3823
  if (0 <= seq_id && seq_id < (int64_t) cache.size) {
3340
- llama_kv_cell & cell = cache.cells[seq_id];
3341
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
3342
- cell.pos += delta;
3824
+ const int32_t tail_id = cache.cells[seq_id].tail;
3825
+ if (tail_id >= 0) {
3826
+ llama_kv_cell & cell = cache.cells[tail_id];
3827
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
3828
+ cell.pos += delta;
3829
+ }
3343
3830
  }
3344
3831
  }
3345
3832
  return;
@@ -3383,9 +3870,12 @@ static void llama_kv_cache_seq_div(
3383
3870
  if (cache.recurrent) {
3384
3871
  // for Mamba-like models, only the pos needs to be changed
3385
3872
  if (0 <= seq_id && seq_id < (int64_t) cache.size) {
3386
- llama_kv_cell & cell = cache.cells[seq_id];
3387
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
3388
- cell.pos /= d;
3873
+ const int32_t tail_id = cache.cells[seq_id].tail;
3874
+ if (tail_id >= 0) {
3875
+ llama_kv_cell & cell = cache.cells[tail_id];
3876
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
3877
+ cell.pos /= d;
3878
+ }
3389
3879
  }
3390
3880
  }
3391
3881
  return;
@@ -3417,7 +3907,9 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
3417
3907
  }
3418
3908
 
3419
3909
  static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
3420
- cache.do_defrag = true;
3910
+ if (!cache.recurrent) {
3911
+ cache.do_defrag = true;
3912
+ }
3421
3913
  }
3422
3914
 
3423
3915
  static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
@@ -6124,6 +6616,7 @@ static bool llm_load_tensors(
6124
6616
  const int64_t n_embd_gqa = n_embd_v_gqa;
6125
6617
  const int64_t n_vocab = hparams.n_vocab;
6126
6618
  const int64_t n_vocab_type = hparams.n_vocab_type;
6619
+ const int64_t n_rot = hparams.n_rot;
6127
6620
  const int64_t n_expert = hparams.n_expert;
6128
6621
  const int64_t n_expert_used = hparams.n_expert_used;
6129
6622
  const int64_t n_ctx_train = hparams.n_ctx_train;
@@ -6181,7 +6674,7 @@ static bool llm_load_tensors(
6181
6674
 
6182
6675
  layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
6183
6676
 
6184
- layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
6677
+ layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
6185
6678
 
6186
6679
  if (n_expert == 0) {
6187
6680
  layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
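
The rope_freqs tensor is now sized from n_rot instead of n_embd/n_head; the two only coincide when the whole head dimension is rotated, so the change matters for models with partial rotary dimensions. A quick arithmetic sketch with made-up hyperparameters:

    #include <cstdio>

    int main() {
        // toy hyperparameters; n_rot can be smaller than the head size (partial RoPE)
        const int n_embd = 4096, n_head = 32;
        const int head_dim = n_embd / n_head;  // 128
        const int n_rot = 64;                  // only part of each head is rotated

        std::printf("head_dim: %d\n", head_dim);
        std::printf("old rope_freqs rows: %d\n", n_embd / n_head / 2); // 64
        std::printf("new rope_freqs rows: %d\n", n_rot / 2);           // 32
    }
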
@@ -7634,8 +8127,8 @@ static bool llm_load_tensors(
7634
8127
 
7635
8128
  layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
7636
8129
 
7637
- layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
7638
- layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)});
8130
+ layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
8131
+ layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa});
7639
8132
 
7640
8133
  layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
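
The fused QKV projection is now sized as n_embd + 2*n_embd_gqa, i.e. one full-width Q block plus K and V blocks at the (possibly smaller) grouped-query width, instead of the previous n_embd_head_k << 2 expression. The arithmetic with made-up GQA shapes:

    #include <cstdio>

    int main() {
        // toy GQA shapes: 32 query heads, 8 KV heads, head size 128
        const int n_head = 32, n_head_kv = 8, n_embd_head = 128;
        const int n_embd     = n_head    * n_embd_head; // 4096 (Q width)
        const int n_embd_gqa = n_head_kv * n_embd_head; // 1024 (K or V width)

        // fused QKV projection output width: Q + K + V
        std::printf("wqkv cols: %d\n", n_embd + 2*n_embd_gqa); // 6144
    }
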
7641
8134
 
@@ -7712,7 +8205,7 @@ static bool llm_load_tensors(
7712
8205
  layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
7713
8206
 
7714
8207
  layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
7715
- layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
8208
+ layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
7716
8209
  layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
7717
8210
  layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
7718
8211
  layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
@@ -7959,7 +8452,7 @@ static struct lm_ggml_tensor * llm_build_inp_embd(
7959
8452
  struct lm_ggml_context * ctx,
7960
8453
  struct llama_context & lctx,
7961
8454
  const llama_hparams & hparams,
7962
- const llama_batch & batch,
8455
+ const llama_ubatch & batch,
7963
8456
  struct lm_ggml_tensor * tok_embd,
7964
8457
  const llm_build_cb & cb) {
7965
8458
  const int64_t n_embd = hparams.n_embd;
@@ -8393,9 +8886,10 @@ static struct lm_ggml_tensor * llm_build_kqv(
8393
8886
  0);
8394
8887
  cb(v, "v", il);
8395
8888
 
8396
- cur = lm_ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
8889
+ cur = lm_ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
8890
+ hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
8397
8891
 
8398
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
8892
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
8399
8893
  lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
8400
8894
  }
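
The extra argument passed to lm_ggml_flash_attn_ext is the attention-logit soft cap used by models such as Gemma-2 (0.0f disables it). Assuming the usual formulation, softcapping squashes logits as cap * tanh(x / cap), which leaves small values almost unchanged and saturates large ones; a minimal sketch of that function with illustrative values only:

    #include <cmath>
    #include <cstdio>

    // soft-capping of attention logits: cap * tanh(x / cap)
    static float softcap(float x, float cap) {
        return cap * std::tanh(x / cap);
    }

    int main() {
        const float cap = 50.0f; // made-up cap value
        for (float x : {1.0f, 30.0f, 200.0f}) {
            std::printf("x=%6.1f -> %6.2f\n", x, softcap(x, cap));
        }
    }
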
8401
8895
 
@@ -8404,7 +8898,7 @@ static struct lm_ggml_tensor * llm_build_kqv(
8404
8898
  struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx, k, q);
8405
8899
  cb(kq, "kq", il);
8406
8900
 
8407
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON) {
8901
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
8408
8902
  // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
8409
8903
  // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
8410
8904
  lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32);
@@ -8508,12 +9002,180 @@ static struct lm_ggml_tensor * llm_build_kv(
8508
9002
  return cur;
8509
9003
  }
8510
9004
 
9005
+ static struct lm_ggml_tensor * llm_build_copy_mask_state(
9006
+ struct lm_ggml_context * ctx,
9007
+ struct lm_ggml_cgraph * graph,
9008
+ struct lm_ggml_tensor * s,
9009
+ struct lm_ggml_tensor * state_copy,
9010
+ struct lm_ggml_tensor * state_mask,
9011
+ int32_t n_state,
9012
+ int32_t kv_size,
9013
+ int32_t kv_head,
9014
+ int32_t n_kv,
9015
+ int32_t n_seqs) {
9016
+ struct lm_ggml_tensor * states = lm_ggml_reshape_2d(ctx, s, n_state, kv_size);
9017
+
9018
+ // copy states
9019
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
9020
+ // this shrinks the tensor's ne[1] to n_kv
+ // this shrinks the tensor's ne[1] to n_kv
9021
+ states = lm_ggml_get_rows(ctx, states, state_copy);
9022
+
9023
+ // clear states of sequences which are starting at the beginning of this batch
9024
+ // FIXME: zero-out NANs?
9025
+ states = lm_ggml_mul(ctx, states, state_mask);
9026
+
9027
+ // copy states which won't be changed further (between n_seqs and n_rs)
9028
+ lm_ggml_build_forward_expand(graph,
9029
+ lm_ggml_cpy(ctx,
9030
+ lm_ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*lm_ggml_element_size(states)),
9031
+ lm_ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*lm_ggml_element_size(s))));
9032
+
9033
+ // the part of the states that will be used and modified
9034
+ return lm_ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
9035
+ }
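
llm_build_copy_mask_state gathers the rows selected by state_copy and multiplies them by state_mask so that sequences starting fresh in this ubatch see zeroed states. The same gather-and-mask on plain vectors, as a toy stand-in for the lm_ggml_get_rows / lm_ggml_mul pair:

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_state = 2;
        std::vector<float> states     = {1,1,  2,2,  3,3,  4,4}; // 4 cells, n_state each
        std::vector<int>   state_copy = {2, 0, 3};               // source cell for each slot
        std::vector<float> state_mask = {1, 0, 1};               // 0 = clear (new sequence)

        std::vector<float> out(state_copy.size() * n_state);
        for (size_t r = 0; r < state_copy.size(); ++r) {
            for (int c = 0; c < n_state; ++c) {
                out[r*n_state + c] = states[state_copy[r]*n_state + c] * state_mask[r];
            }
        }
        for (float v : out) std::printf("%.0f ", v); // prints: 3 3 0 0 4 4
        std::printf("\n");
    }
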
9036
+
9037
+ // TODO: split
9038
+ static struct lm_ggml_tensor * llm_build_mamba(
9039
+ struct lm_ggml_context * ctx,
9040
+ struct llama_context & lctx,
9041
+ const llama_ubatch & batch,
9042
+ struct lm_ggml_cgraph * graph,
9043
+ struct lm_ggml_tensor * cur,
9044
+ struct lm_ggml_tensor * state_copy,
9045
+ struct lm_ggml_tensor * state_mask,
9046
+ int32_t kv_head,
9047
+ int32_t n_kv,
9048
+ const llm_build_cb & cb,
9049
+ int il) {
9050
+ const llama_model & model = lctx.model;
9051
+ const llama_hparams & hparams = model.hparams;
9052
+ const llama_kv_cache & kv = lctx.kv_self;
9053
+ const int64_t d_conv = hparams.ssm_d_conv;
9054
+ const int64_t d_inner = hparams.ssm_d_inner;
9055
+ const int64_t d_state = hparams.ssm_d_state;
9056
+ const int64_t dt_rank = hparams.ssm_dt_rank;
9057
+ const int64_t n_seqs = batch.n_seqs;
9058
+ // Some variants of the Mamba arch (e.g. FalconMamba) apply layer norm on the B and Dt layers
+ // Some variants of the Mamba arch (e.g. FalconMamba) apply layer norm on the B and Dt layers
9059
+ const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
9060
+ // Use the same RMS norm as the final layer norm
9061
+ const float norm_rms_eps = hparams.f_norm_rms_eps;
9062
+
9063
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
9064
+
9065
+ LM_GGML_ASSERT(n_seqs != 0);
9066
+ LM_GGML_ASSERT(batch.equal_seqs);
9067
+ LM_GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
9068
+
9069
+ struct lm_ggml_tensor * conv_states_all = kv.k_l[il];
9070
+ struct lm_ggml_tensor * ssm_states_all = kv.v_l[il];
9071
+
9072
+ // (ab)using the KV cache to store the states
9073
+ struct lm_ggml_tensor * conv = llm_build_copy_mask_state(ctx,
9074
+ graph, conv_states_all, state_copy, state_mask,
9075
+ hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
9076
+ conv = lm_ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
9077
+ struct lm_ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
9078
+ graph, ssm_states_all, state_copy, state_mask,
9079
+ hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
9080
+ ssm = lm_ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs);
9081
+
9082
+ // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
9083
+ cur = lm_ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
9084
+
9085
+ // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
9086
+ struct lm_ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
9087
+ // split the above in two
9088
+ // => {d_inner, n_seq_tokens, n_seqs}
9089
+ struct lm_ggml_tensor * x = lm_ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
9090
+ struct lm_ggml_tensor * z = lm_ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*lm_ggml_element_size(xz));
9091
+
9092
+ // conv
9093
+ {
9094
+ // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
9095
+ struct lm_ggml_tensor * conv_x = lm_ggml_concat(ctx, conv, lm_ggml_transpose(ctx, x), 0);
9096
+
9097
+ // copy last (d_conv - 1) columns back into the state cache
9098
+ struct lm_ggml_tensor * last_conv = lm_ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
9099
+
9100
+ lm_ggml_build_forward_expand(graph,
9101
+ lm_ggml_cpy(ctx, last_conv,
9102
+ lm_ggml_view_1d(ctx, conv_states_all,
9103
+ (d_conv - 1)*(d_inner)*(n_seqs),
9104
+ kv_head*(d_conv - 1)*(d_inner)*lm_ggml_element_size(conv_states_all))));
9105
+
9106
+ // 1D convolution
9107
+ // The equivalent is to make a self-overlapping view of conv_x
9108
+ // over d_conv columns at each stride in the 3rd dimension,
9109
+ // then element-wise multiply that with the conv1d weight,
9110
+ // then sum the elements of each row,
9111
+ // (the last two steps are a dot product over rows (also doable with mul_mat))
9112
+ // then permute away the ne[0] dimension,
9113
+ // and then you're left with the resulting x tensor.
9114
+ // For simultaneous sequences, all sequences need to have the same length.
9115
+ x = lm_ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
9116
+
9117
+ // bias
9118
+ x = lm_ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
9119
+
9120
+ x = lm_ggml_silu(ctx, x);
9121
+ }
9122
+
9123
+ // ssm
9124
+ {
9125
+ // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
9126
+ struct lm_ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
9127
+ // split
9128
+ struct lm_ggml_tensor * dt = lm_ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
9129
+ struct lm_ggml_tensor * B = lm_ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], lm_ggml_element_size(x_db)*dt_rank);
9130
+ struct lm_ggml_tensor * C = lm_ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], lm_ggml_element_size(x_db)*(dt_rank+d_state));
9131
+
9132
+ // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
9133
+ if (ssm_dt_b_c_rms) {
9134
+ dt = lm_ggml_rms_norm(ctx, dt, norm_rms_eps);
9135
+ B = lm_ggml_rms_norm(ctx, B, norm_rms_eps);
9136
+ C = lm_ggml_rms_norm(ctx, C, norm_rms_eps);
9137
+ }
9138
+
9139
+ // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
9140
+ dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
9141
+ dt = lm_ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
9142
+
9143
+ // Custom operator to optimize the parallel associative scan
9144
+ // as described in the Annex D of the Mamba paper.
9145
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
9146
+ struct lm_ggml_tensor * y_ssm = lm_ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
9147
+
9148
+ // store last states
9149
+ lm_ggml_build_forward_expand(graph,
9150
+ lm_ggml_cpy(ctx,
9151
+ lm_ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
9152
+ lm_ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*lm_ggml_element_size(ssm_states_all))));
9153
+
9154
+ struct lm_ggml_tensor * y = lm_ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
9155
+
9156
+ // TODO: skip computing output earlier for unused tokens
9157
+
9158
+ // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
9159
+ y = lm_ggml_add(ctx, y, lm_ggml_mul(ctx, x, model.layers[il].ssm_d));
9160
+ y = lm_ggml_mul(ctx, y, lm_ggml_silu(ctx, lm_ggml_cont(ctx, z)));
9161
+
9162
+ // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
9163
+ cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
9164
+ }
9165
+
9166
+ // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
9167
+ cur = lm_ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
9168
+ cb(cur, "mamba_out", il);
9169
+
9170
+ return cur;
9171
+ }
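
Inside the conv block above, the stored (d_conv - 1) inputs are concatenated in front of the new tokens, the convolution slides over that widened window, and the last (d_conv - 1) inputs become the next stored state. A one-channel toy version of that sliding update, with made-up weights rather than the real ssm_conv operator:

    #include <cstdio>
    #include <vector>

    int main() {
        const int d_conv = 4;
        std::vector<float> conv_state = {0, 0, 0};               // d_conv - 1 remembered inputs
        std::vector<float> x          = {1, 2, 3, 4, 5};         // new inputs for one sequence
        std::vector<float> w          = {0.1f, 0.2f, 0.3f, 0.4f}; // toy conv1d weight

        std::vector<float> conv_x = conv_state;
        conv_x.insert(conv_x.end(), x.begin(), x.end());         // concat, like lm_ggml_concat

        for (size_t t = 0; t < x.size(); ++t) {
            float y = 0.0f;
            for (int k = 0; k < d_conv; ++k) y += w[k] * conv_x[t + k]; // self-overlapping view
            std::printf("y[%zu] = %.2f\n", t, y);
        }
        // next state = last (d_conv - 1) columns of conv_x
        conv_state.assign(conv_x.end() - (d_conv - 1), conv_x.end());
        std::printf("new state: %.0f %.0f %.0f\n", conv_state[0], conv_state[1], conv_state[2]);
    }
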
9172
+
8511
9173
  struct llm_build_context {
8512
9174
  const llama_model & model;
8513
9175
  llama_context & lctx;
8514
9176
  const llama_hparams & hparams;
8515
9177
  const llama_cparams & cparams;
8516
- const llama_batch & batch;
9178
+ const llama_ubatch & batch;
8517
9179
  const llama_kv_cache & kv_self;
8518
9180
 
8519
9181
  const int64_t n_embd;
@@ -8559,7 +9221,7 @@ struct llm_build_context {
8559
9221
  // TODO: consider making the entire interface noexcept
8560
9222
  llm_build_context(
8561
9223
  llama_context & lctx,
8562
- const llama_batch & batch,
9224
+ const llama_ubatch & batch,
8563
9225
  const llm_build_cb & cb,
8564
9226
  bool worst_case) :
8565
9227
  model (lctx.model),
@@ -8666,29 +9328,6 @@ struct llm_build_context {
8666
9328
  return gf;
8667
9329
  }
8668
9330
 
8669
- struct lm_ggml_cgraph * build_s_copy() {
8670
- struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
8671
-
8672
- LM_GGML_ASSERT(kv_self.recurrent);
8673
-
8674
- struct lm_ggml_tensor * state_copy = build_inp_s_copy();
8675
-
8676
- for (int il = 0; il < n_layer; ++il) {
8677
- struct lm_ggml_tensor * conv_states = lm_ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
8678
- struct lm_ggml_tensor * ssm_states = lm_ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
8679
-
8680
- conv_states = lm_ggml_get_rows(ctx0, conv_states, state_copy);
8681
- ssm_states = lm_ggml_get_rows(ctx0, ssm_states, state_copy);
8682
-
8683
- // TODO: name the intermediate tensors with cb()
8684
-
8685
- lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
8686
- lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
8687
- }
8688
-
8689
- return gf;
8690
- }
8691
-
8692
9331
  struct lm_ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
8693
9332
  struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
8694
9333
 
@@ -8823,7 +9462,7 @@ struct llm_build_context {
8823
9462
  }
8824
9463
 
8825
9464
  struct lm_ggml_tensor * build_inp_s_copy() {
8826
- lctx.inp_s_copy = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, kv_self.size);
9465
+ lctx.inp_s_copy = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_kv);
8827
9466
  cb(lctx.inp_s_copy, "inp_s_copy", -1);
8828
9467
  lm_ggml_set_input(lctx.inp_s_copy);
8829
9468
  return lctx.inp_s_copy;
@@ -8836,13 +9475,6 @@ struct llm_build_context {
8836
9475
  return lctx.inp_s_mask;
8837
9476
  }
8838
9477
 
8839
- struct lm_ggml_tensor * build_inp_s_seq() {
8840
- lctx.inp_s_seq = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_I32, n_kv, n_tokens);
8841
- cb(lctx.inp_s_seq, "inp_s_seq", -1);
8842
- lm_ggml_set_input(lctx.inp_s_seq);
8843
- return lctx.inp_s_seq;
8844
- }
8845
-
8846
9478
  struct lm_ggml_cgraph * append_pooling(struct lm_ggml_cgraph * gf) {
8847
9479
  // find result_norm tensor for input
8848
9480
  struct lm_ggml_tensor * inp = nullptr;
@@ -12172,136 +12804,31 @@ struct llm_build_context {
12172
12804
  struct lm_ggml_cgraph * build_mamba() {
12173
12805
  struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
12174
12806
 
12175
- const int64_t d_model = n_embd;
12176
- const int64_t d_conv = hparams.ssm_d_conv;
12177
- const int64_t d_inner = hparams.ssm_d_inner;
12178
- LM_GGML_ASSERT(2 * d_model == d_inner);
12179
- const int64_t d_state = hparams.ssm_d_state;
12180
- const int64_t dt_rank = hparams.ssm_dt_rank;
12181
- // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
12182
- const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
12183
- // Use the same RMS norm as the final layer norm
12184
- const float norm_rms_eps = hparams.f_norm_rms_eps;
12185
-
12186
12807
  struct lm_ggml_tensor * cur;
12187
12808
  struct lm_ggml_tensor * inpL;
12188
12809
 
12189
12810
  // {n_embd, n_tokens}
12190
12811
  inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
12191
12812
 
12813
+ struct lm_ggml_tensor * state_copy = build_inp_s_copy();
12192
12814
  struct lm_ggml_tensor * state_mask = build_inp_s_mask();
12193
- struct lm_ggml_tensor * state_seq = build_inp_s_seq();
12194
12815
 
12195
12816
  for (int il = 0; il < n_layer; ++il) {
12196
- // (ab)using the KV cache to store the states
12197
- struct lm_ggml_tensor * conv_states = lm_ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
12198
- struct lm_ggml_tensor * ssm_states = lm_ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
12199
-
12200
- // clear states of sequences which are starting at the beginning of this batch
12201
- {
12202
- conv_states = lm_ggml_mul(ctx0,
12203
- lm_ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
12204
- state_mask);
12205
- ssm_states = lm_ggml_mul(ctx0,
12206
- lm_ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
12207
- state_mask);
12208
- }
12209
-
12210
- conv_states = lm_ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
12211
- ssm_states = lm_ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
12212
-
12213
12817
  // norm
12214
12818
  cur = llm_build_norm(ctx0, inpL, hparams,
12215
12819
  model.layers[il].attn_norm, NULL,
12216
12820
  LLM_NORM_RMS, cb, il);
12217
12821
  cb(cur, "attn_norm", il);
12218
12822
 
12219
- // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
12220
- struct lm_ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur);
12221
- // split the above in two
12222
- // => {d_inner, n_tokens}
12223
- struct lm_ggml_tensor * x = lm_ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
12224
- struct lm_ggml_tensor * z = lm_ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], lm_ggml_element_size(xz)*d_inner);
12225
-
12226
- // conv
12227
- {
12228
- // Custom operator which is needed only to ease simultaneous sequence processing.
12229
- // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
12230
- // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
12231
- // then element-wise multiply that with the conv1d weigth,
12232
- // then sum the elements of each row,
12233
- // (the last two steps are a dot product over rows (also doable with mul_mat))
12234
- // then permute away the ne[0] dimension,
12235
- // and then you're left with the resulting x tensor.
12236
- // The new conv_states is the last (d_conv - 1) columns
12237
- // of the last 3rd dimensional "layer" of the self-overlapping view.
12238
- // For simultaneous sequences, it's more complicated.
12239
- struct lm_ggml_tensor * x_conv = lm_ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
12240
-
12241
- // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
12242
- lm_ggml_build_forward_expand(gf,
12243
- lm_ggml_cpy(ctx0,
12244
- lm_ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*lm_ggml_element_size(x_conv), (1+d_inner*n_tokens)*lm_ggml_element_size(x_conv)),
12245
- lm_ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*lm_ggml_element_size(x_conv))));
12246
-
12247
- // extract x from x_conv
12248
- x = lm_ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*lm_ggml_element_size(x_conv), 0);
12249
-
12250
- // bias
12251
- x = lm_ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
12252
-
12253
- x = lm_ggml_silu(ctx0, x);
12254
- }
12255
-
12256
- // ssm
12257
- {
12258
- // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
12259
- struct lm_ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x);
12260
- // split
12261
- struct lm_ggml_tensor * dt = lm_ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
12262
- struct lm_ggml_tensor * B = lm_ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], lm_ggml_element_size(x_db)*dt_rank);
12263
- struct lm_ggml_tensor * C = lm_ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], lm_ggml_element_size(x_db)*(dt_rank+d_state));
12264
-
12265
- // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
12266
- if (ssm_dt_b_c_rms) {
12267
- dt = lm_ggml_rms_norm(ctx0, dt, norm_rms_eps);
12268
- B = lm_ggml_rms_norm(ctx0, B, norm_rms_eps);
12269
- C = lm_ggml_rms_norm(ctx0, C, norm_rms_eps);
12270
- }
12271
-
12272
- // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
12273
- dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
12274
- dt = lm_ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
12275
-
12276
- // Custom operator to optimize the parallel associative scan
12277
- // as described in the Annex D of the Mamba paper.
12278
- // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
12279
- // because only a single tensor can be returned.
12280
- struct lm_ggml_tensor * y_ssm_states = lm_ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
12281
-
12282
- // store last states (the second part of y_ssm_states)
12283
- lm_ggml_build_forward_expand(gf,
12284
- lm_ggml_cpy(ctx0,
12285
- lm_ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*lm_ggml_element_size(y_ssm_states)),
12286
- lm_ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*lm_ggml_element_size(ssm_states))));
12287
-
12288
- struct lm_ggml_tensor * y = lm_ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*lm_ggml_element_size(y_ssm_states), 0);
12289
-
12290
- if (il == n_layer - 1) {
12291
- // skip computing output for unused tokens
12292
- struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12293
- x = lm_ggml_get_rows(ctx0, x, inp_out_ids);
12294
- y = lm_ggml_get_rows(ctx0, y, inp_out_ids);
12295
- z = lm_ggml_get_rows(ctx0, z, inp_out_ids);
12296
- inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
12297
- }
12298
-
12299
- // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
12300
- y = lm_ggml_add(ctx0, y, lm_ggml_mul(ctx0, x, model.layers[il].ssm_d));
12301
- y = lm_ggml_mul(ctx0, y, lm_ggml_silu(ctx0, z));
12823
+ cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
12824
+ state_copy, state_mask,
12825
+ kv_head, n_kv, cb, il);
12302
12826
 
12303
- // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
12304
- cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y);
12827
+ if (il == n_layer - 1) {
12828
+ // skip computing output for unused tokens
12829
+ struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
12830
+ cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
12831
+ inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
12305
12832
  }
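
On the last layer, build_mamba now prunes the hidden states with lm_ggml_get_rows so that only tokens whose output is requested are carried into the output projection. The row-gather itself, shown on a small toy matrix:

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_embd = 2;
        std::vector<float> cur     = {10,10,  20,20,  30,30,  40,40}; // 4 tokens
        std::vector<int>   out_ids = {1, 3};                          // tokens that need outputs

        std::vector<float> pruned(out_ids.size() * n_embd);
        for (size_t r = 0; r < out_ids.size(); ++r) {
            for (int c = 0; c < n_embd; ++c) {
                pruned[r*n_embd + c] = cur[out_ids[r]*n_embd + c];
            }
        }
        for (float v : pruned) std::printf("%.0f ", v); // prints: 20 20 40 40
        std::printf("\n");
    }
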
12306
12833
 
12307
12834
  // residual
@@ -14167,8 +14694,8 @@ struct llm_build_context {
14167
14694
  };
14168
14695
 
14169
14696
  static struct lm_ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
14170
- llama_batch dummy;
14171
- dummy.n_tokens = 0;
14697
+ llama_ubatch dummy = {};
14698
+ dummy.equal_seqs = true;
14172
14699
 
14173
14700
  llm_build_cb cb = [&](struct lm_ggml_tensor * , const char * , int ) { };
14174
14701
 
@@ -14184,8 +14711,8 @@ static struct lm_ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, co
14184
14711
  }
14185
14712
 
14186
14713
  static struct lm_ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
14187
- llama_batch dummy;
14188
- dummy.n_tokens = 0;
14714
+ llama_ubatch dummy = {};
14715
+ dummy.equal_seqs = true;
14189
14716
 
14190
14717
  llm_build_cb cb = [&](struct lm_ggml_tensor * , const char * , int ) { };
14191
14718
 
@@ -14200,26 +14727,9 @@ static struct lm_ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
14200
14727
  return result;
14201
14728
  }
14202
14729
 
14203
- static struct lm_ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
14204
- llama_batch dummy;
14205
- dummy.n_tokens = 0;
14206
-
14207
- llm_build_cb cb = [&](struct lm_ggml_tensor * , const char * , int ) { };
14208
-
14209
- struct llm_build_context llm(lctx, dummy, cb, false);
14210
-
14211
- llm.init();
14212
-
14213
- struct lm_ggml_cgraph * result = llm.build_s_copy();
14214
-
14215
- llm.free();
14216
-
14217
- return result;
14218
- }
14219
-
14220
14730
  static struct lm_ggml_cgraph * llama_build_graph(
14221
14731
  llama_context & lctx,
14222
- const llama_batch & batch,
14732
+ const llama_ubatch & batch,
14223
14733
  bool worst_case) {
14224
14734
  const auto & model = lctx.model;
14225
14735
 
@@ -14489,7 +14999,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
14489
14999
  return relative_bucket;
14490
15000
  }
14491
15001
 
14492
- static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
15002
+ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
14493
15003
  //
14494
15004
  // set input data
14495
15005
  //
@@ -14528,10 +15038,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14528
15038
  for (int i = 0; i < n_tokens; ++i) {
14529
15039
  data[i] = i;
14530
15040
  }
14531
- } else if (batch.logits) {
15041
+ } else if (batch.output) {
14532
15042
  int32_t n_outputs = 0;
14533
15043
  for (int i = 0; i < n_tokens; ++i) {
14534
- if (batch.logits[i]) {
15044
+ if (batch.output[i]) {
14535
15045
  data[n_outputs++] = i;
14536
15046
  }
14537
15047
  }
@@ -14555,8 +15065,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14555
15065
  if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
14556
15066
  // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
14557
15067
  if (cparams.causal_attn && !lctx.is_encoding) {
14558
- const int64_t n_kv = kv_self.n;
14559
- const int64_t n_tokens = batch.n_tokens;
15068
+ const int64_t n_kv = kv_self.n;
15069
+ const int64_t n_tokens = batch.n_tokens;
15070
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
15071
+ const int64_t n_seqs = batch.n_seqs;
14560
15072
 
14561
15073
 
14562
15074
  float * data = nullptr;
@@ -14576,32 +15088,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14576
15088
  // of the correct sequence for each token of the batch.
14577
15089
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
14578
15090
  for (int h = 0; h < 1; ++h) {
14579
- for (int j = 0; j < n_tokens; ++j) {
14580
- const llama_pos pos = batch.pos[j];
14581
- const llama_seq_id seq_id = batch.seq_id[j][0];
15091
+ for (int s = 0; s < n_seqs; ++s) {
15092
+ const llama_seq_id seq_id = batch.seq_id[s][0];
14582
15093
 
14583
- for (int i = 0; i < n_kv; ++i) {
14584
- float f;
14585
- if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
14586
- f = -INFINITY;
14587
- } else {
14588
- if (hparams.use_alibi) {
14589
- f = -std::abs(lctx.kv_self.cells[i].pos - pos);
15094
+ for (int j = 0; j < n_seq_tokens; ++j) {
15095
+ const llama_pos pos = batch.pos[s*n_seq_tokens + j];
15096
+
15097
+ for (int i = 0; i < n_kv; ++i) {
15098
+ float f;
15099
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
15100
+ f = -INFINITY;
14590
15101
  } else {
14591
- f = 0.0f;
15102
+ if (hparams.use_alibi) {
15103
+ f = -std::abs(kv_self.cells[i].pos - pos);
15104
+ } else {
15105
+ f = 0.0f;
15106
+ }
14592
15107
  }
14593
- }
14594
15108
 
14595
- if (data) {
14596
- data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
14597
- }
15109
+ if (data) {
15110
+ data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
15111
+ }
14598
15112
 
14599
- // may need to cut off old tokens for sliding window
14600
- if (data_swa) {
14601
- if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
14602
- f = -INFINITY;
15113
+ // may need to cut off old tokens for sliding window
15114
+ if (data_swa) {
15115
+ if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
15116
+ f = -INFINITY;
15117
+ }
15118
+ data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
14603
15119
  }
14604
- data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
14605
15120
  }
14606
15121
  }
14607
15122
  }
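
The rewritten causal-mask loop walks sequences first and tokens within each sequence second; because the ubatch stores its sequences back-to-back, the offset h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i addresses the same row-major [n_tokens, n_kv] layout as before, with row index s*n_seq_tokens + j. A small self-contained check of that equivalence, with toy sizes and h fixed to 0:

#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_kv = 8, n_seq_tokens = 3, n_seqs = 2;
    const int64_t n_tokens = n_seq_tokens * n_seqs;

    for (int64_t s = 0; s < n_seqs; ++s) {
        for (int64_t j = 0; j < n_seq_tokens; ++j) {
            for (int64_t i = 0; i < n_kv; ++i) {
                // offset written by the per-sequence loop above (h == 0)
                const int64_t seq_major = s*(n_kv*n_seq_tokens) + j*n_kv + i;
                // offset of the same entry in a flat [n_tokens, n_kv] mask
                const int64_t flat_row  = s*n_seq_tokens + j;
                const int64_t flat      = flat_row*n_kv + i;
                assert(seq_major == flat && flat < n_tokens*n_kv);
            }
        }
    }
    return 0;
}
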
@@ -14623,8 +15138,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14623
15138
  }
14624
15139
  }
14625
15140
  } else {
15141
+ const int64_t n_tokens = batch.n_tokens;
15142
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
15143
+ const int64_t n_seqs = batch.n_seqs;
14626
15144
  // when using kv cache, the mask needs to match the kv cache size
14627
- const int64_t n_tokens = batch.n_tokens;
14628
15145
  const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
14629
15146
 
14630
15147
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
@@ -14632,27 +15149,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14632
15149
  float * data = (float *) lctx.inp_KQ_mask->data;
14633
15150
 
14634
15151
  for (int h = 0; h < 1; ++h) {
14635
- for (int j = 0; j < n_tokens; ++j) {
14636
- const llama_seq_id seq_id = batch.seq_id[j][0];
14637
-
14638
- for (int i = 0; i < n_tokens; ++i) {
14639
- float f = -INFINITY;
14640
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
14641
- if (batch.seq_id[i][s] == seq_id) {
14642
- if (hparams.use_alibi) {
14643
- f = -std::abs(batch.pos[i] - batch.pos[j]);
14644
- } else {
14645
- f = 0.0f;
15152
+ for (int s1 = 0; s1 < n_seqs; ++s1) {
15153
+ const llama_seq_id seq_id = batch.seq_id[s1][0];
15154
+
15155
+ for (int j = 0; j < n_seq_tokens; ++j) {
15156
+ const int32_t tj = s1*n_seq_tokens + j;
15157
+
15158
+ for (int s0 = 0; s0 < n_seqs; ++s0) {
15159
+ for (int i = 0; i < n_seq_tokens; ++i) {
15160
+ const int32_t ti = s0*n_seq_tokens + i;
15161
+ float f = -INFINITY;
15162
+
15163
+ for (int s = 0; s < batch.n_seq_id[s0]; ++s) {
15164
+ if (batch.seq_id[s0][s] == seq_id) {
15165
+ if (hparams.use_alibi) {
15166
+ f = -std::abs(batch.pos[ti] - batch.pos[tj]);
15167
+ } else {
15168
+ f = 0.0f;
15169
+ }
15170
+ break;
15171
+ }
14646
15172
  }
14647
- break;
15173
+
15174
+ data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
14648
15175
  }
14649
15176
  }
14650
15177
 
14651
- data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
14652
- }
14653
-
14654
- for (int i = n_tokens; i < n_stride; ++i) {
14655
- data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
15178
+ for (int i = n_tokens; i < n_stride; ++i) {
15179
+ data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
15180
+ }
14656
15181
  }
14657
15182
  }
14658
15183
  }
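
The non-causal branch applies the same flattening on both sides: tj = s1*n_seq_tokens + j is the query row, ti = s0*n_seq_tokens + i the key column, tokens that do not share a sequence stay at -INFINITY, and with ALiBi the entry is the negative position distance. A toy construction of such a block-diagonal mask, with hypothetical positions and no stride padding:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_seq_tokens = 2, n_seqs = 2;
    const int n_tokens = n_seq_tokens * n_seqs;

    // hypothetical positions, sequences stored back-to-back: seq 0 then seq 1
    const std::vector<int> pos    = {0, 1, 0, 1};
    const std::vector<int> seq_of = {0, 0, 1, 1};
    const bool use_alibi = true;

    std::vector<float> mask(n_tokens * n_tokens, -INFINITY);

    for (int tj = 0; tj < n_tokens; ++tj) {        // query rows
        for (int ti = 0; ti < n_tokens; ++ti) {    // key columns
            if (seq_of[ti] == seq_of[tj]) {
                mask[tj*n_tokens + ti] = use_alibi ? -std::abs((float)(pos[ti] - pos[tj])) : 0.0f;
            }
        }
    }

    for (int tj = 0; tj < n_tokens; ++tj) {
        for (int ti = 0; ti < n_tokens; ++ti) {
            printf("%6.1f", mask[tj*n_tokens + ti]);
        }
        printf("\n");
    }
    return 0;
}
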
@@ -14660,7 +15185,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14660
15185
  }
14661
15186
 
14662
15187
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
14663
- const int64_t n_tokens = batch.n_tokens;
15188
+ const int64_t n_tokens = batch.n_tokens;
15189
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
15190
+ const int64_t n_seqs = batch.n_seqs;
14664
15191
 
14665
15192
  LM_GGML_ASSERT(lctx.inp_mean);
14666
15193
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -14669,12 +15196,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14669
15196
  memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * lm_ggml_element_size(lctx.inp_mean));
14670
15197
 
14671
15198
  std::vector<uint64_t> sum(n_tokens, 0);
14672
- for (int i = 0; i < n_tokens; ++i) {
14673
- const llama_seq_id seq_id = batch.seq_id[i][0];
14674
15199
 
15200
+ for (int s = 0; s < n_seqs; ++s) {
15201
+ const llama_seq_id seq_id = batch.seq_id[s][0];
15202
+
15203
+ // TODO: adapt limits to n_seqs when batch.equal_seqs is true
14675
15204
  LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
14676
15205
 
14677
- sum[seq_id] += 1;
15206
+ sum[seq_id] += batch.n_seq_tokens;
14678
15207
  }
14679
15208
 
14680
15209
  std::vector<float> div(n_tokens, 0.0f);
@@ -14685,14 +15214,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14685
15214
  }
14686
15215
  }
14687
15216
 
14688
- for (int i = 0; i < n_tokens; ++i) {
14689
- const llama_seq_id seq_id = batch.seq_id[i][0];
14690
- data[seq_id*n_tokens + i] = div[seq_id];
15217
+ for (int s = 0; s < n_seqs; ++s) {
15218
+ const llama_seq_id seq_id = batch.seq_id[s][0];
15219
+
15220
+ for (int i = 0; i < n_seq_tokens; ++i) {
15221
+ data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
15222
+ }
14691
15223
  }
14692
15224
  }
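
With equal-length sequences the mean-pooling matrix can be filled one sequence at a time: each sequence adds n_seq_tokens to its row count, and every one of its token columns receives 1 over that count. A standalone sketch of that construction, assuming toy sizes and the row-index-equals-seq_id convention asserted above:

#include <cstdio>
#include <vector>

int main() {
    const int n_seq_tokens = 3, n_seqs = 2;
    const int n_tokens = n_seq_tokens * n_seqs;
    const std::vector<int> seq_id = {0, 1};   // one primary seq_id per sequence

    std::vector<float> inp_mean(n_tokens * n_tokens, 0.0f);

    std::vector<unsigned> sum(n_tokens, 0);
    for (int s = 0; s < n_seqs; ++s) {
        sum[seq_id[s]] += n_seq_tokens;       // sum[seq_id] += batch.n_seq_tokens;
    }

    std::vector<float> div(n_tokens, 0.0f);
    for (int i = 0; i < n_tokens; ++i) {
        if (sum[i] > 0) { div[i] = 1.0f / (float) sum[i]; }
    }

    for (int s = 0; s < n_seqs; ++s) {
        for (int i = 0; i < n_seq_tokens; ++i) {
            inp_mean[seq_id[s]*n_tokens + s*n_seq_tokens + i] = div[seq_id[s]];
        }
    }

    printf("row 0, col 0 = %.3f (expected 1/3)\n", inp_mean[0]);
    return 0;
}
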
14693
15225
 
14694
15226
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
14695
- const int64_t n_tokens = batch.n_tokens;
15227
+ const int64_t n_tokens = batch.n_tokens;
15228
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
15229
+ const int64_t n_seqs = batch.n_seqs;
14696
15230
 
14697
15231
  LM_GGML_ASSERT(lctx.inp_cls);
14698
15232
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -14700,20 +15234,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14700
15234
  uint32_t * data = (uint32_t *) lctx.inp_cls->data;
14701
15235
  memset(lctx.inp_cls->data, 0, n_tokens * lm_ggml_element_size(lctx.inp_cls));
14702
15236
 
14703
- for (int i = 0; i < n_tokens; ++i) {
14704
- const llama_seq_id seq_id = batch.seq_id[i][0];
14705
- const llama_pos pos = batch.pos[i];
15237
+ for (int s = 0; s < n_seqs; ++s) {
15238
+ const llama_seq_id seq_id = batch.seq_id[s][0];
14706
15239
 
15240
+ // TODO: adapt limits to n_seqs when batch.equal_seqs is true
14707
15241
  LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
14708
15242
 
14709
- if (pos == 0) {
14710
- data[seq_id] = i;
15243
+ for (int i = 0; i < n_seq_tokens; ++i) {
15244
+ const llama_pos pos = batch.pos[s*n_seq_tokens + i];
15245
+
15246
+ if (pos == 0) {
15247
+ data[seq_id] = s*n_seq_tokens + i;
15248
+ }
14711
15249
  }
14712
15250
  }
14713
15251
  }
14714
15252
 
14715
15253
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
14716
- const int64_t n_tokens = batch.n_tokens;
15254
+ const int64_t n_tokens = batch.n_tokens;
15255
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
15256
+ const int64_t n_seqs = batch.n_seqs;
14717
15257
 
14718
15258
  LM_GGML_ASSERT(lctx.inp_cls);
14719
15259
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -14724,15 +15264,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14724
15264
  std::vector<int> last_pos(n_tokens, -1);
14725
15265
  std::vector<int> last_row(n_tokens, -1);
14726
15266
 
14727
- for (int i = 0; i < n_tokens; ++i) {
14728
- const llama_seq_id seq_id = batch.seq_id[i][0];
14729
- const llama_pos pos = batch.pos[i];
15267
+ for (int s = 0; s < n_seqs; ++s) {
15268
+ const llama_seq_id seq_id = batch.seq_id[s][0];
14730
15269
 
15270
+ // TODO: adapt limits to n_seqs when batch.equal_seqs is true
14731
15271
  LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
14732
15272
 
14733
- if (pos >= last_pos[seq_id]) {
14734
- last_pos[seq_id] = pos;
14735
- last_row[seq_id] = i;
15273
+ for (int i = 0; i < n_seq_tokens; ++i) {
15274
+ const llama_pos pos = batch.pos[s*n_seq_tokens + i];
15275
+
15276
+ if (pos >= last_pos[seq_id]) {
15277
+ last_pos[seq_id] = pos;
15278
+ last_row[seq_id] = s*n_seq_tokens + i;
15279
+ }
14736
15280
  }
14737
15281
  }
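
CLS and LAST pooling follow the same pattern: walk each sequence of the ubatch, compute the flattened token index s*n_seq_tokens + i, and record either the index of the pos == 0 token (CLS) or the index holding the greatest position (LAST). A condensed sketch of the LAST selection with hypothetical positions:

#include <cstdio>
#include <vector>

int main() {
    const int n_seq_tokens = 3, n_seqs = 2;
    const std::vector<int> seq_id = {0, 1};
    // positions per flattened token index (sequence 0 first, then sequence 1)
    const std::vector<int> pos = {4, 5, 6, 0, 1, 2};

    std::vector<int> last_pos(n_seqs, -1);
    std::vector<int> last_row(n_seqs, -1);

    for (int s = 0; s < n_seqs; ++s) {
        for (int i = 0; i < n_seq_tokens; ++i) {
            const int ti = s*n_seq_tokens + i;
            if (pos[ti] >= last_pos[seq_id[s]]) {
                last_pos[seq_id[s]] = pos[ti];
                last_row[seq_id[s]] = ti;
            }
        }
    }

    printf("seq 0 -> row %d, seq 1 -> row %d\n", last_row[0], last_row[1]);  // 2 and 5
    return 0;
}
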
14738
15282
 
@@ -14750,41 +15294,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14750
15294
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
14751
15295
  float * data = (float *) lctx.inp_s_mask->data;
14752
15296
 
14753
- // states which are not affected by the current batch are left untouched
15297
+ // clear unused states
14754
15298
  for (int i = 0; i < n_kv; ++i) {
14755
- llama_seq_id seq_id = i + lctx.kv_self.head;
14756
- llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
14757
- bool has_self_seq = kv_cell.has_seq_id(seq_id);
15299
+ uint32_t cell_id = i + kv_self.head;
15300
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
14758
15301
 
14759
- data[i] = (float) has_self_seq;
15302
+ data[i] = (float) (kv_cell.src >= 0);
14760
15303
 
14761
- // ensure current sequences will be kept
14762
- if (!has_self_seq && kv_cell.pos >= 0) {
14763
- kv_cell.seq_id.insert(seq_id);
15304
+ // only clear once
15305
+ if (kv_cell.src < 0) {
15306
+ kv_cell.src = cell_id;
14764
15307
  }
14765
15308
  }
14766
15309
  }
14767
- // For Mamba (and other recurrent architectures),
14768
- // update the correct state(s)/sequence(s) for each token of the batch.
14769
- // Like with the KQ_mask, if a token in the batch has multiple sequences,
14770
- // they are assumed to be equivalent (not here, but in lm_ggml_ssm_scan and lm_ggml_ssm_conv).
14771
- if (lctx.inp_s_seq) {
14772
- const int64_t n_tokens = batch.n_tokens;
14773
15310
 
14774
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
14775
- int32_t * data = (int32_t *) lctx.inp_s_seq->data;
15311
+ if (lctx.inp_s_copy) {
15312
+ LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
15313
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
14776
15314
 
14777
- for (int j = 0; j < n_tokens; ++j) {
14778
- const int32_t n_seq = batch.n_seq_id[j];
14779
- LM_GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
15315
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
15316
+ for (uint32_t i = 0; i < n_kv; ++i) {
15317
+ const uint32_t cell_id = i + kv_self.head;
15318
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
14780
15319
 
14781
- for (int i = 0; i < n_kv; ++i) {
14782
- if (i < n_seq) {
14783
- // for this type of model, the head is the minimum seq_id of the batch
14784
- data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
14785
- } else {
14786
- data[j*n_kv + i] = -1;
14787
- }
15320
+ // prevent out-of-bound sources
15321
+ if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
15322
+ kv_cell.src = cell_id;
15323
+ }
15324
+
15325
+ data[i] = kv_cell.src;
15326
+
15327
+ // ensure copy only happens once
15328
+ if (kv_cell.src != (int32_t) cell_id) {
15329
+ kv_cell.src = cell_id;
14788
15330
  }
14789
15331
  }
14790
15332
  }
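
For recurrent caches the meaning of a cell's src field changes with this diff: a negative src marks a cleared state, inp_s_mask zeroes such states out, inp_s_copy records which cell each state should be gathered from, and src is then reset to the cell's own id so the copy happens at most once. A hypothetical, condensed sketch of that bookkeeping, with both passes folded into one loop over stand-in cells:

#include <cstdint>
#include <cstdio>
#include <vector>

struct cell_state { int32_t src = -1; };   // stand-in for llama_kv_cell::src

int main() {
    const uint32_t head = 2, n_kv = 3, size = 8;
    std::vector<cell_state> cells(size);
    cells[3].src = 6;    // this state is gathered from cell 6
    cells[4].src = 99;   // out-of-bounds source, must be clamped to itself

    std::vector<float>   s_mask(n_kv);
    std::vector<int32_t> s_copy(n_kv);

    for (uint32_t i = 0; i < n_kv; ++i) {
        const uint32_t cell_id = i + head;
        cell_state & c = cells[cell_id];

        s_mask[i] = (float) (c.src >= 0);              // 0.0f clears the state
        if (c.src < 0) { c.src = (int32_t) cell_id; }  // only clear once

        if (c.src < 0 || (uint32_t) c.src >= size) {   // prevent out-of-bound sources
            c.src = (int32_t) cell_id;
        }
        s_copy[i] = c.src;                             // gather source for this state
        if (c.src != (int32_t) cell_id) { c.src = (int32_t) cell_id; }  // copy only once
    }

    for (uint32_t i = 0; i < n_kv; ++i) {
        printf("cell %u: mask %.0f, copy from %d\n", i + head, s_mask[i], s_copy[i]);
    }
    return 0;
}
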
@@ -14794,6 +15336,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14794
15336
  const int64_t n_tokens = batch.n_tokens;
14795
15337
 
14796
15338
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
15339
+ LM_GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
14797
15340
 
14798
15341
  int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
14799
15342
 
@@ -14829,6 +15372,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
14829
15372
  const int64_t n_tokens = batch.n_tokens;
14830
15373
 
14831
15374
  LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
15375
+ LM_GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
14832
15376
 
14833
15377
  float * data = (float *) lctx.inp_KQ_mask_cross->data;
14834
15378
 
@@ -14922,6 +15466,43 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
14922
15466
  return n_outputs_max;
14923
15467
  }
14924
15468
 
15469
+ // make the outputs have the same order they had in the user-provided batch
15470
+ static void llama_output_reorder(struct llama_context * ctx) {
15471
+ std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
15472
+ if (!out_ids.empty()) {
15473
+ uint32_t n_vocab = ctx->model.hparams.n_vocab;
15474
+ uint32_t n_embd = ctx->model.hparams.n_embd;
15475
+ int32_t n_outputs = ctx->n_outputs;
15476
+ LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
15477
+ // TODO: is there something more efficient which also minimizes swaps?
15478
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
15479
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
15480
+ int32_t j_min = i;
15481
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
15482
+ if (out_ids[j] < out_ids[j_min]) {
15483
+ j_min = j;
15484
+ }
15485
+ }
15486
+ if (j_min == i) { continue; }
15487
+ std::swap(out_ids[i], out_ids[j_min]);
15488
+ if (ctx->logits_size > 0) {
15489
+ for (uint32_t k = 0; k < n_vocab; k++) {
15490
+ std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
15491
+ }
15492
+ }
15493
+ if (ctx->embd_size > 0) {
15494
+ for (uint32_t k = 0; k < n_embd; k++) {
15495
+ std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
15496
+ }
15497
+ }
15498
+ }
15499
+ std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
15500
+ for (int32_t i = 0; i < n_outputs; ++i) {
15501
+ ctx->output_ids[out_ids[i]] = i;
15502
+ }
15503
+ out_ids.clear();
15504
+ }
15505
+ }
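
llama_output_reorder restores the caller's batch order lazily: out_ids maps each output row to the position the caller asked for, and a selection sort keeps the number of row swaps small instead of shuffling logits and embeddings through a temporary buffer. A simplified, self-contained version operating on a toy logits matrix:

#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const std::size_t n_vocab = 2;
    // out_ids[i] = original batch position of output row i
    std::vector<std::size_t> out_ids = {2, 0, 1};
    // one logits row per output, currently in ubatch (compute) order
    std::vector<float> logits = {2.0f, 2.1f,  0.0f, 0.1f,  1.0f, 1.1f};

    const int n_outputs = (int) out_ids.size();
    // selection sort, to minimize row swaps
    for (int i = 0; i < n_outputs - 1; ++i) {
        int j_min = i;
        for (int j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) { j_min = j; }
        }
        if (j_min == i) { continue; }
        std::swap(out_ids[i], out_ids[j_min]);
        for (std::size_t k = 0; k < n_vocab; ++k) {
            std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
        }
    }

    for (int i = 0; i < n_outputs; ++i) {
        printf("row %d -> batch pos %zu, logits[0] = %.1f\n", i, out_ids[i], logits[i*n_vocab]);
    }
    return 0;
}
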
14925
15506
 
14926
15507
  static void llama_graph_compute(
14927
15508
  llama_context & lctx,
@@ -14994,15 +15575,11 @@ static int llama_decode_internal(
14994
15575
 
14995
15576
  const auto n_ubatch = cparams.n_ubatch;
14996
15577
 
14997
- // TODO: simplify or deprecate
14998
- std::vector<llama_pos> pos;
14999
- std::vector<int32_t> n_seq_id;
15000
- std::vector<llama_seq_id *> seq_id_arr;
15001
- std::vector<std::vector<llama_seq_id>> seq_id;
15002
-
15003
15578
  // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
15004
15579
  const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
15005
15580
 
15581
+ lctx.embd_seq.clear();
15582
+
15006
15583
  // count outputs
15007
15584
  if (batch_all.logits && !embd_pooled) {
15008
15585
  for (uint32_t i = 0; i < n_tokens_all; ++i) {
@@ -15015,55 +15592,42 @@ static int llama_decode_internal(
15015
15592
  n_outputs = 1;
15016
15593
  }
15017
15594
 
15595
+ lctx.sbatch.from_batch(batch_all, n_embd,
15596
+ /* simple_split */ !kv_self.recurrent,
15597
+ /* logits_all */ n_outputs == n_tokens_all);
15598
+
15018
15599
  // reserve output buffer
15019
15600
  if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
15020
15601
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
15021
15602
  return -2;
15022
15603
  };
15023
15604
 
15024
- // set output mappings
15025
- if (batch_all.logits) {
15026
- int32_t i_logits = 0;
15027
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
15028
- if (batch_all.logits[i]) {
15029
- lctx.output_ids[i] = i_logits++;
15605
+ while (lctx.sbatch.n_tokens > 0) {
15606
+ llama_ubatch ubatch;
15607
+ if (kv_self.recurrent) {
15608
+ if (embd_pooled) {
15609
+ // Pooled embeddings cannot be split across ubatches (yet)
15610
+ ubatch = lctx.sbatch.split_seq(n_ubatch);
15611
+ } else {
15612
+ // recurrent model architectures are easier to implement
15613
+ // with equal-length sequences
15614
+ ubatch = lctx.sbatch.split_equal(n_ubatch);
15030
15615
  }
15616
+ } else {
15617
+ ubatch = lctx.sbatch.split_simple(n_ubatch);
15031
15618
  }
15032
- } else {
15033
- for (uint32_t i = 0; i < n_outputs; ++i) {
15034
- lctx.output_ids[i] = i;
15035
- }
15036
- }
15037
-
15038
- for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
15039
- const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
15040
- llama_batch u_batch = {
15041
- /* .n_tokens = */ (int32_t) n_tokens,
15042
- /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
15043
- /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr,
15044
- /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr,
15045
- /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr,
15046
- /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr,
15047
- /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr,
15048
- /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
15049
- /* .all_pos_1 = */ batch_all.all_pos_1,
15050
- /* .all_seq_id = */ batch_all.all_seq_id,
15051
- };
15619
+ const uint32_t n_tokens = ubatch.n_tokens;
15052
15620
 
15053
15621
  // count the outputs in this u_batch
15054
15622
  {
15055
15623
  int32_t n_outputs_new = 0;
15056
15624
 
15057
- if (u_batch.logits && !embd_pooled) {
15058
- for (uint32_t i = 0; i < n_tokens; i++) {
15059
- n_outputs_new += u_batch.logits[i] != 0;
15060
- }
15061
- } else if (n_outputs == n_tokens_all) {
15625
+ if (n_outputs == n_tokens_all) {
15062
15626
  n_outputs_new = n_tokens;
15063
15627
  } else {
15064
- // keep last output only
15065
- if (cur_token + n_tokens >= n_tokens_all) {
15066
- n_outputs_new = 1;
15628
+ LM_GGML_ASSERT(ubatch.output);
15629
+ for (uint32_t i = 0; i < n_tokens; i++) {
15630
+ n_outputs_new += (int32_t) (ubatch.output[i] != 0);
15067
15631
  }
15068
15632
  }
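
The decode loop now pulls one ubatch at a time from the sequence-aware sbatch, and the split method depends on the model: split_seq when pooled embeddings must stay in one ubatch, split_equal for recurrent models (equal-length sequences), split_simple otherwise. A sketch of just that dispatch policy, with a hypothetical enum standing in for the three sbatch methods:

#include <cstdio>

// hypothetical labels for the three sbatch split methods used above
enum class split_mode { simple, equal, seq };

// mirrors the dispatch in llama_decode_internal (names here are illustrative)
split_mode pick_split(bool recurrent, bool embd_pooled) {
    if (recurrent) {
        // pooled embeddings cannot be split across ubatches (yet)
        return embd_pooled ? split_mode::seq : split_mode::equal;
    }
    return split_mode::simple;
}

int main() {
    printf("%d %d %d\n",
           (int) pick_split(false, false),   // 0 = simple
           (int) pick_split(true,  false),   // 1 = equal
           (int) pick_split(true,  true));   // 2 = seq
    return 0;
}
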
15069
15633
 
@@ -15074,32 +15638,6 @@ static int llama_decode_internal(
15074
15638
  int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
15075
15639
  LM_GGML_ASSERT(n_threads > 0);
15076
15640
 
15077
- // helpers for smoother batch API transition
15078
- // after deprecating the llama_eval calls, these will be removed
15079
- if (u_batch.pos == nullptr) {
15080
- pos.resize(n_tokens);
15081
- for (uint32_t i = 0; i < n_tokens; i++) {
15082
- pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
15083
- }
15084
-
15085
- u_batch.pos = pos.data();
15086
- }
15087
-
15088
- if (u_batch.seq_id == nullptr) {
15089
- n_seq_id.resize(n_tokens);
15090
- seq_id.resize(n_tokens);
15091
- seq_id_arr.resize(n_tokens);
15092
- for (uint32_t i = 0; i < n_tokens; i++) {
15093
- n_seq_id[i] = 1;
15094
- seq_id[i].resize(1);
15095
- seq_id[i][0] = u_batch.all_seq_id;
15096
- seq_id_arr[i] = seq_id[i].data();
15097
- }
15098
-
15099
- u_batch.n_seq_id = n_seq_id.data();
15100
- u_batch.seq_id = seq_id_arr.data();
15101
- }
15102
-
15103
15641
  // non-causal masks do not use the KV cache
15104
15642
  if (hparams.causal_attn) {
15105
15643
  llama_kv_cache_update(&lctx);
@@ -15110,7 +15648,7 @@ static int llama_decode_internal(
15110
15648
  kv_self.head = 0;
15111
15649
  }
15112
15650
 
15113
- if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
15651
+ if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
15114
15652
  return 1;
15115
15653
  }
15116
15654
 
@@ -15129,7 +15667,7 @@ static int llama_decode_internal(
15129
15667
  lm_ggml_backend_sched_reset(lctx.sched);
15130
15668
  lm_ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
15131
15669
 
15132
- lm_ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
15670
+ lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
15133
15671
 
15134
15672
  // the output is always the last tensor in the graph
15135
15673
  struct lm_ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
@@ -15157,7 +15695,7 @@ static int llama_decode_internal(
15157
15695
 
15158
15696
  lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
15159
15697
 
15160
- llama_set_inputs(lctx, u_batch);
15698
+ llama_set_inputs(lctx, ubatch);
15161
15699
 
15162
15700
  llama_graph_compute(lctx, gf, n_threads);
15163
15701
 
@@ -15215,12 +15753,11 @@ static int llama_decode_internal(
15215
15753
  case LLAMA_POOLING_TYPE_CLS:
15216
15754
  case LLAMA_POOLING_TYPE_LAST:
15217
15755
  {
15218
- // extract sequence embeddings
15756
+ // extract sequence embeddings (cleared before processing each batch)
15219
15757
  auto & embd_seq_out = lctx.embd_seq;
15220
- embd_seq_out.clear();
15221
15758
 
15222
- for (uint32_t i = 0; i < n_tokens; i++) {
15223
- const llama_seq_id seq_id = u_batch.seq_id[i][0];
15759
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
15760
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
15224
15761
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
15225
15762
  continue;
15226
15763
  }
@@ -15237,6 +15774,25 @@ static int llama_decode_internal(
15237
15774
  n_outputs_prev += lctx.n_outputs;
15238
15775
  }
15239
15776
 
15777
+ // set output mappings
15778
+ {
15779
+ bool sorted_output = true;
15780
+
15781
+ LM_GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
15782
+
15783
+ for (size_t i = 0; i < n_outputs; ++i) {
15784
+ size_t out_id = lctx.sbatch.out_ids[i];
15785
+ lctx.output_ids[out_id] = i;
15786
+ if (out_id != i) {
15787
+ sorted_output = false;
15788
+ }
15789
+ }
15790
+
15791
+ if (sorted_output) {
15792
+ lctx.sbatch.out_ids.clear();
15793
+ }
15794
+ }
15795
+
15240
15796
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
15241
15797
  lctx.n_outputs = n_outputs;
15242
15798
 
@@ -15301,11 +15857,9 @@ static int llama_encode_internal(
15301
15857
 
15302
15858
  const int64_t n_embd = hparams.n_embd;
15303
15859
 
15304
- // TODO: simplify or deprecate
15305
- std::vector<llama_pos> pos;
15306
- std::vector<int32_t> n_seq_id;
15307
- std::vector<llama_seq_id *> seq_id_arr;
15308
- std::vector<std::vector<llama_seq_id>> seq_id;
15860
+ lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
15861
+
15862
+ const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
15309
15863
 
15310
15864
  // reserve output buffer
15311
15865
  if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
@@ -15323,36 +15877,10 @@ static int llama_encode_internal(
15323
15877
  const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
15324
15878
  LM_GGML_ASSERT(n_threads > 0);
15325
15879
 
15326
- // helpers for smoother batch API transition
15327
- // after deprecating the llama_eval calls, these will be removed
15328
- if (batch.pos == nullptr) {
15329
- pos.resize(n_tokens);
15330
- for (uint32_t i = 0; i < n_tokens; i++) {
15331
- pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
15332
- }
15333
-
15334
- batch.pos = pos.data();
15335
- }
15336
-
15337
- if (batch.seq_id == nullptr) {
15338
- n_seq_id.resize(n_tokens);
15339
- seq_id.resize(n_tokens);
15340
- seq_id_arr.resize(n_tokens);
15341
- for (uint32_t i = 0; i < n_tokens; i++) {
15342
- n_seq_id[i] = 1;
15343
- seq_id[i].resize(1);
15344
- seq_id[i][0] = batch.all_seq_id;
15345
- seq_id_arr[i] = seq_id[i].data();
15346
- }
15347
-
15348
- batch.n_seq_id = n_seq_id.data();
15349
- batch.seq_id = seq_id_arr.data();
15350
- }
15351
-
15352
15880
  lm_ggml_backend_sched_reset(lctx.sched);
15353
15881
  lm_ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
15354
15882
 
15355
- lm_ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
15883
+ lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
15356
15884
 
15357
15885
  // the output embeddings after the final encoder normalization
15358
15886
  struct lm_ggml_tensor * embd = nullptr;
@@ -15376,7 +15904,7 @@ static int llama_encode_internal(
15376
15904
 
15377
15905
  lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
15378
15906
 
15379
- llama_set_inputs(lctx, batch);
15907
+ llama_set_inputs(lctx, ubatch);
15380
15908
 
15381
15909
  llama_graph_compute(lctx, gf, n_threads);
15382
15910
 
@@ -15390,12 +15918,13 @@ static int llama_encode_internal(
15390
15918
  float * embd_out = lctx.embd_enc.data();
15391
15919
 
15392
15920
  lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
15921
+ LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
15393
15922
 
15394
15923
  // remember the sequence ids used during the encoding - needed for cross attention later
15395
15924
  lctx.seq_ids_enc.resize(n_tokens);
15396
15925
  for (uint32_t i = 0; i < n_tokens; i++) {
15397
- for (int s = 0; s < batch.n_seq_id[i]; s++) {
15398
- llama_seq_id seq_id = batch.seq_id[i][s];
15926
+ for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
15927
+ llama_seq_id seq_id = ubatch.seq_id[i][s];
15399
15928
  lctx.seq_ids_enc[i].insert(seq_id);
15400
15929
  }
15401
15930
  }
@@ -15420,8 +15949,10 @@ static int llama_encode_internal(
15420
15949
  auto & embd_seq_out = lctx.embd_seq;
15421
15950
  embd_seq_out.clear();
15422
15951
 
15952
+ LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
15953
+
15423
15954
  for (uint32_t i = 0; i < n_tokens; i++) {
15424
- const llama_seq_id seq_id = batch.seq_id[i][0];
15955
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
15425
15956
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
15426
15957
  continue;
15427
15958
  }
@@ -15699,32 +16230,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
15699
16230
  }
15700
16231
  }
15701
16232
 
15702
- if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
15703
- {
15704
- lm_ggml_backend_sched_reset(lctx.sched);
15705
-
15706
- lm_ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
15707
-
15708
- lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
15709
-
15710
- llama_set_s_copy(lctx);
15711
-
15712
- llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
15713
-
15714
- need_reserve = true;
15715
- }
15716
-
15717
- {
15718
- auto & kv_self = lctx.kv_self;
15719
-
15720
- kv_self.do_copy = false;
15721
-
15722
- for (uint32_t i = 0; i < kv_self.size; ++i) {
15723
- kv_self.cells[i].src = i;
15724
- }
15725
- }
15726
- }
15727
-
15728
16233
  // defragment the KV cache if needed
15729
16234
  if (lctx.kv_self.do_defrag) {
15730
16235
  llama_kv_cache_defrag_internal(lctx);
@@ -15738,10 +16243,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
15738
16243
  if (need_reserve) {
15739
16244
  // TODO: extract to a function
15740
16245
  // build worst-case graph
15741
- int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
15742
- int n_past = lctx.cparams.n_ctx - n_tokens;
16246
+ uint32_t n_seqs = 1; // TODO: worst-case number of sequences
16247
+ uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
15743
16248
  llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
15744
- lm_ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
16249
+ llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
16250
+ lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
15745
16251
 
15746
16252
  // initialize scheduler with the worst-case graph
15747
16253
  lm_ggml_backend_sched_reset(lctx.sched);
@@ -16337,12 +16843,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
16337
16843
  qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
16338
16844
 
16339
16845
  // sanity checks
16340
- //
16341
- // - qs.n_attention_wv == 0 for Mamba models
16342
- // - qs.n_attention_wv == model.hparams.n_layer for Transformer models
16343
- // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
16344
- //
16345
- LM_GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
16846
+ {
16847
+ const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
16848
+ // attention layers have a non-zero number of kv heads
16849
+ int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
16850
+ if (llama_model_has_encoder(&model)) {
16851
+ n_attn_layer *= 3;
16852
+ }
16853
+ LM_GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
16854
+ }
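
The quantization sanity check now derives the expected number of attention projections from the per-layer KV-head counts: a layer counts as attention if its n_head_kv entry is non-zero (state-space layers report zero), and encoder-decoder models contribute three times as many. A standalone version of that count over hypothetical head counts:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical per-layer KV head counts: 0 marks a non-attention (e.g. SSM) layer
    const std::vector<uint32_t> n_head_kv_arr = {8, 8, 0, 8, 0, 8};
    const int32_t n_layer = (int32_t) n_head_kv_arr.size();
    const bool has_encoder = false;

    int32_t n_attn_layer = n_layer -
        (int32_t) std::count(n_head_kv_arr.begin(), n_head_kv_arr.begin() + n_layer, 0u);
    if (has_encoder) {
        n_attn_layer *= 3;
    }

    printf("expected n_attention_wv = %d\n", n_attn_layer);  // 4
    return 0;
}
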
16346
16855
 
16347
16856
  size_t total_size_org = 0;
16348
16857
  size_t total_size_new = 0;
@@ -17037,12 +17546,6 @@ struct llama_context * llama_new_context_with_model(
17037
17546
  params.flash_attn = false;
17038
17547
  }
17039
17548
 
17040
- if (params.flash_attn && model->hparams.attn_soft_cap) {
17041
- LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
17042
- params.flash_attn = false;
17043
- }
17044
-
17045
-
17046
17549
  if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
17047
17550
  LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
17048
17551
  params.flash_attn = false;
@@ -17151,7 +17654,7 @@ struct llama_context * llama_new_context_with_model(
17151
17654
  lm_ggml_type type_v = params.type_v;
17152
17655
 
17153
17656
  // Mamba only needs a constant number of KV cache cells per sequence
17154
- if (model->arch == LLM_ARCH_MAMBA) {
17657
+ if (llama_model_is_recurrent(model)) {
17155
17658
  // Mamba needs at least as many KV cells as there are sequences kept at any time
17156
17659
  kv_size = std::max((uint32_t) 1, params.n_seq_max);
17157
17660
  // it's probably best to keep as much precision as possible for the states
@@ -17383,10 +17886,11 @@ struct llama_context * llama_new_context_with_model(
17383
17886
  }
17384
17887
 
17385
17888
  // build worst-case graph
17386
- int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
17387
- int n_past = cparams.n_ctx - n_tokens;
17889
+ uint32_t n_seqs = 1; // TODO: worst-case number of sequences
17890
+ uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
17388
17891
  llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
17389
- lm_ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
17892
+ llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
17893
+ lm_ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
17390
17894
 
17391
17895
  // initialize scheduler with the worst-case graph
17392
17896
  if (!lm_ggml_backend_sched_reserve(ctx->sched, gf)) {
@@ -17626,6 +18130,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
17626
18130
  return model->hparams.dec_start_token_id;
17627
18131
  }
17628
18132
 
18133
+ bool llama_model_is_recurrent(const struct llama_model * model) {
18134
+ switch (model->arch) {
18135
+ case LLM_ARCH_MAMBA: return true;
18136
+ default: return false;
18137
+ }
18138
+ }
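
With llama_model_is_recurrent in place, code that used to test for LLM_ARCH_MAMBA directly, such as the KV-cache sizing earlier in this diff, can branch on the predicate instead; a recurrent model only needs one state cell per tracked sequence. A hedged sketch of that sizing rule (the non-recurrent fallback to the full context size is a simplification):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// KV-cache sizing rule from the context-creation hunk above (simplified):
// recurrent models keep one state cell per tracked sequence, not one per context slot
uint32_t pick_kv_size(bool is_recurrent, uint32_t n_ctx, uint32_t n_seq_max) {
    if (is_recurrent) {
        return std::max<uint32_t>(1, n_seq_max);
    }
    return n_ctx;   // simplification: the attention cache spans the context
}

int main() {
    printf("attention model: %u cells, recurrent model: %u cells\n",
           pick_kv_size(false, 4096, 4), pick_kv_size(true, 4096, 4));
    return 0;
}
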
18139
+
17629
18140
  uint32_t llama_model_quantize(
17630
18141
  const char * fname_inp,
17631
18142
  const char * fname_out,
@@ -17947,7 +18458,9 @@ struct llama_data_write {
17947
18458
  write_string(rng_str);
17948
18459
  }
17949
18460
 
17950
- void write_output_ids(const struct llama_context * ctx) {
18461
+ void write_output_ids(struct llama_context * ctx) {
18462
+ llama_output_reorder(ctx);
18463
+
17951
18464
  const uint32_t n_outputs = ctx->n_outputs;
17952
18465
 
17953
18466
  std::vector<int32_t> output_pos;
@@ -18235,8 +18748,11 @@ struct llama_data_read {
18235
18748
 
18236
18749
  llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
18237
18750
 
18238
- llama_batch batch = llama_batch_init(cell_count, 0, 1);
18751
+ llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
18239
18752
  batch.n_tokens = cell_count;
18753
+ batch.n_seq_tokens = cell_count;
18754
+ batch.n_seqs = 1;
18755
+
18240
18756
  for (uint32_t i = 0; i < cell_count; ++i) {
18241
18757
  llama_pos pos;
18242
18758
  uint32_t n_seq_id;
@@ -18250,11 +18766,10 @@ struct llama_data_read {
18250
18766
  }
18251
18767
 
18252
18768
  batch.pos[i] = pos;
18253
- batch.n_seq_id[i] = 1;
18254
- batch.seq_id[i][0] = dest_seq_id;
18255
18769
  }
18770
+ batch.n_seq_id[0] = 1;
18771
+ batch.seq_id[0] = &dest_seq_id;
18256
18772
  if (!llama_kv_cache_find_slot(kv_self, batch)) {
18257
- llama_batch_free(batch);
18258
18773
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
18259
18774
  return false;
18260
18775
  }
@@ -18266,9 +18781,6 @@ struct llama_data_read {
18266
18781
  LM_GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
18267
18782
  LM_GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
18268
18783
  LM_GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
18269
-
18270
- // Cleanup
18271
- llama_batch_free(batch);
18272
18784
  } else {
18273
18785
  // whole KV cache restore
18274
18786
 
@@ -18300,6 +18812,15 @@ struct llama_data_read {
18300
18812
  }
18301
18813
 
18302
18814
  cell.seq_id.insert(seq_id);
18815
+
18816
+ if (kv_self.recurrent) {
18817
+ int32_t & tail = kv_self.cells[seq_id].tail;
18818
+ if (tail != -1) {
18819
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
18820
+ return false;
18821
+ }
18822
+ tail = i;
18823
+ }
18303
18824
  }
18304
18825
  }
18305
18826
 
@@ -18307,6 +18828,14 @@ struct llama_data_read {
18307
18828
  kv_self.used = cell_count;
18308
18829
  }
18309
18830
 
18831
+ if (kv_self.recurrent) {
18832
+ for (uint32_t i = 0; i < cell_count; ++i) {
18833
+ uint32_t cell_id = kv_self.head + i;
18834
+ // make sure the recurrent states will keep their restored state
18835
+ kv_self.cells[cell_id].src = cell_id;
18836
+ }
18837
+ }
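
When a saved session is restored for a recurrent model, two extra invariants are enforced: a sequence may own at most one tail cell, and each restored cell's src is pointed back at itself so a pending copy cannot overwrite the state. A minimal sketch of those checks over hypothetical restored cells:

#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_cell_state { int32_t src = -1; int32_t tail = -1; };  // stand-in fields

int main() {
    const uint32_t head = 0, cell_count = 2;
    std::vector<kv_cell_state> cells(4);
    const std::vector<int32_t> restored_seq_id = {0, 1};  // one sequence per restored cell

    // a sequence may own only one tail cell
    for (uint32_t i = 0; i < cell_count; ++i) {
        const int32_t seq_id = restored_seq_id[i];
        int32_t & tail = cells[seq_id].tail;
        if (tail != -1) {
            printf("duplicate tail for seq_id %d\n", seq_id);  // the restore would fail here
            return 1;
        }
        tail = (int32_t) (head + i);
    }

    // make sure the restored states are not overwritten by a pending copy
    for (uint32_t i = 0; i < cell_count; ++i) {
        const uint32_t cell_id = head + i;
        cells[cell_id].src = (int32_t) cell_id;
    }

    printf("tails: seq 0 -> cell %d, seq 1 -> cell %d\n", cells[0].tail, cells[1].tail);
    return 0;
}
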
18838
+
18310
18839
  return true;
18311
18840
  }
18312
18841
 
@@ -18894,7 +19423,18 @@ struct llama_batch llama_batch_get_one(
18894
19423
  }
18895
19424
 
18896
19425
  struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
18897
- llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
19426
+ llama_batch batch = {
19427
+ /*n_tokens =*/ 0,
19428
+ /*tokens =*/ nullptr,
19429
+ /*embd =*/ nullptr,
19430
+ /*pos =*/ nullptr,
19431
+ /*n_seq_id =*/ nullptr,
19432
+ /*seq_id =*/ nullptr,
19433
+ /*logits =*/ nullptr,
19434
+ /*all_pos_0 =*/ 0,
19435
+ /*all_pos_1 =*/ 0,
19436
+ /*all_seq_id =*/ 0,
19437
+ };
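
The positional brace initializer is replaced with one commented field per line, which keeps the zeros and null pointers self-describing as llama_batch grows. Caller code is unaffected; a typical allocate/release pair with the public API still looks like the following (header name assumed from the packaged sources):

#include "llama.h"   // header name assumed from the packaged cpp/ sources

int main() {
    // room for 512 tokens, token-id input (embd == 0), at most 4 seq ids per token
    llama_batch batch = llama_batch_init(512, /*embd =*/ 0, /*n_seq_max =*/ 4);

    // ... fill batch.token, batch.pos, batch.seq_id and batch.logits here ...

    llama_batch_free(batch);
    return 0;
}
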
18898
19438
 
18899
19439
  if (embd) {
18900
19440
  batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
@@ -18980,6 +19520,10 @@ void llama_synchronize(struct llama_context * ctx) {
18980
19520
  float * llama_get_logits(struct llama_context * ctx) {
18981
19521
  llama_synchronize(ctx);
18982
19522
 
19523
+ // reorder logits for backward compatibility
19524
+ // TODO: maybe deprecate this
19525
+ llama_output_reorder(ctx);
19526
+
18983
19527
  return ctx->logits;
18984
19528
  }
18985
19529
 
@@ -19024,6 +19568,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
19024
19568
  float * llama_get_embeddings(struct llama_context * ctx) {
19025
19569
  llama_synchronize(ctx);
19026
19570
 
19571
+ // reorder embeddings for backward compatibility
19572
+ // TODO: maybe deprecate this
19573
+ llama_output_reorder(ctx);
19574
+
19027
19575
  return ctx->embd;
19028
19576
  }
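
Both accessors now run llama_output_reorder before returning the raw buffers, so callers indexing logits or embeddings by their original batch position keep working even when rows were produced in sbatch order; once out_ids is cleared the call is a no-op. The same lazy-fixup-on-access idea in miniature, with a hypothetical pending permutation applied on first read:

#include <cstddef>
#include <vector>

// lazy fixup on access, in the spirit of the accessors above:
// a pending permutation is applied once, then cleared
struct outputs {
    std::vector<float>       data;            // one value per output row (toy width 1)
    std::vector<std::size_t> pending_order;   // empty means already in user order

    float * get() {
        if (!pending_order.empty()) {
            std::vector<float> tmp(data.size());
            for (std::size_t i = 0; i < pending_order.size(); ++i) {
                tmp[pending_order[i]] = data[i];   // row i belongs at user position pending_order[i]
            }
            data.swap(tmp);
            pending_order.clear();                 // subsequent calls are no-ops
        }
        return data.data();
    }
};

int main() {
    outputs o;
    o.data          = {10.0f, 20.0f, 30.0f};
    o.pending_order = {2, 0, 1};
    float * p = o.get();   // data is now {20, 30, 10}; later calls return it unchanged
    (void) p;
    return 0;
}
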
19029
19577