cui-llama.rn 1.1.0 → 1.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/common.cpp +19 -6
- package/cpp/ggml-aarch64.c +6 -21
- package/cpp/ggml-metal.m +154 -26
- package/cpp/ggml.c +115 -195
- package/cpp/ggml.h +5 -7
- package/cpp/llama-impl.h +10 -4
- package/cpp/llama-sampling.cpp +16 -14
- package/cpp/llama.cpp +1048 -500
- package/cpp/llama.h +3 -0
- package/package.json +1 -1
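The bulk of this release is the batch-splitting rework in `package/cpp/llama.cpp` shown below: `llama_batch` inputs are now staged through a new internal `llama_sbatch`, which hands out `llama_ubatch` micro-batches (`split_simple`, `split_equal`, `split_seq`) carrying per-sequence metadata for the KV cache and recurrent-state (Mamba-like) models. As a rough orientation aid only, the hedged sketch below shows how these internal pieces fit together; the surrounding setup (`batch`, `hparams`, `kv_self`, `n_ubatch`) is assumed, and this is not part of the package's public API.

```cpp
// Sketch of the internal flow introduced in this diff (assumed context, not public API):
// fill an llama_sbatch from a llama_batch, then consume it as llama_ubatch chunks.
llama_sbatch sbatch;
sbatch.from_batch(batch, hparams.n_embd, /*simple_split=*/!kv_self.recurrent, /*logits_all=*/false);

while (sbatch.n_tokens > 0) {
    // recurrent models need equal-length sequences per ubatch; others can split simply
    llama_ubatch ubatch = kv_self.recurrent ? sbatch.split_equal(n_ubatch)
                                            : sbatch.split_simple(n_ubatch);
    if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
        break; // no free slot for this micro-batch
    }
    // ... build and evaluate the graph for `ubatch` ...
}
```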
package/cpp/llama.cpp
CHANGED
@@ -2527,10 +2527,29 @@ struct llama_layer {
     struct lm_ggml_tensor * ffn_down_scale;
 };
 
+// very similar to llama_batch,
+// but has more metadata about sequences
+struct llama_ubatch {
+    bool equal_seqs;
+    // TODO: whole_seqs for embeddings?
+
+    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+    uint32_t n_seq_tokens; // tokens per sequence
+    uint32_t n_seqs;
+
+    llama_token * token; // [n_tokens]
+    float * embd; // [n_embd, n_tokens]
+    llama_pos * pos; // [n_tokens]
+    int32_t * n_seq_id; // [n_seqs]
+    llama_seq_id ** seq_id; // [n_seqs]
+    int8_t * output; // [n_tokens]
+};
+
 struct llama_kv_cell {
     llama_pos pos = -1;
     llama_pos delta = 0;
-    int32_t src =
+    int32_t src = -1; // used by recurrent state models to copy states
+    int32_t tail = -1;
 
     std::set<llama_seq_id> seq_id;
 
@@ -2551,7 +2570,6 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
-    bool do_copy = false;
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
     bool v_trans = true; // the value tensor is transposed
 
@@ -2714,6 +2732,340 @@ struct llama_model {
     }
 };
 
+struct llama_sbatch_seq {
+    int32_t n_seq_id;
+    llama_seq_id * seq_id;
+    size_t offset;
+    size_t length;
+
+    // helper for smoother batch API transition -- can be deprecated in the future
+    llama_seq_id all_seq_id; // used if seq_id == NULL
+};
+
+// sequence-length-aware batch splitting
+struct llama_sbatch {
+    // tokens left in this batch
+    size_t n_tokens;
+
+    size_t n_embd;
+
+    bool logits_all; // TODO: remove once lctx.logits_all is removed too
+
+    // sorted indices into the batch
+    std::vector<size_t> ids;
+    // batch indices of the output
+    std::vector<size_t> out_ids;
+    std::vector<llama_sbatch_seq> seq;
+    const llama_batch * batch = nullptr;
+
+    // buffers for the ubatch
+    std::vector<llama_token> ubatch_token;
+    std::vector<float> ubatch_embd;
+    std::vector<llama_pos> ubatch_pos;
+    std::vector<int32_t> ubatch_n_seq_id;
+    std::vector<llama_seq_id *> ubatch_seq_id;
+    std::vector<int8_t> ubatch_output;
+
+    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
+        // clear empty sequences
+        // the previous ubatch is assumed to be gone,
+        // so nothing should refer to values in these sequences anymore.
+        for (size_t i = seq.size(); i-- > 0;) {
+            if (seq[i].length == 0) {
+                seq.pop_back();
+            } else {
+                break;
+            }
+        }
+        ubatch_token.resize(!has_embd ? n_ubatch : 0);
+        ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
+        ubatch_pos.resize(n_ubatch);
+        ubatch_n_seq_id.resize(n_ubatch);
+        ubatch_seq_id.resize(n_ubatch);
+        ubatch_output.resize(n_ubatch);
+        llama_ubatch ubatch = {
+            /*equal_seqs   =*/ true,
+            /*n_tokens     =*/ 0,
+            /*n_seq_tokens =*/ 0,
+            /*n_seqs       =*/ 0,
+            /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
+            /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
+            /*pos          =*/ ubatch_pos.data(),
+            /*n_seq_id     =*/ ubatch_n_seq_id.data(),
+            /*seq_id       =*/ ubatch_seq_id.data(),
+            /*output       =*/ ubatch_output.data(),
+        };
+        return ubatch;
+    }
+
+    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
+        LM_GGML_ASSERT(batch != nullptr);
+        LM_GGML_ASSERT(length <= seq.length);
+        // Can only add sequences of equal lengths to a batch,
+        // otherwise it isn't clear to which sequence a token belongs
+        LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
+        LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
+        // NOTE: loops are separated for cache-friendliness
+        if (batch->token) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.token = batch->token + seq.offset;
+            }
+        } else {
+            ubatch.token = nullptr;
+        }
+        if (batch->embd) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    memcpy(
+                        ubatch.embd + n_embd * (ubatch.n_tokens + i),
+                        batch->embd + n_embd * ids[seq.offset + i],
+                        n_embd * sizeof(float)
+                    );
+                }
+            } else {
+                // simple split
+                ubatch.embd = batch->embd + (n_embd * seq.offset);
+            }
+        } else {
+            ubatch.embd = nullptr;
+        }
+        // from here on, the else branches are deprecated;
+        // they are helpers for smoother batch API transition
+        if (batch->pos) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.pos = batch->pos + seq.offset;
+            }
+        } else {
+            for (size_t i = 0; i < length; ++i) {
+                llama_pos bi = ids[seq.offset + i];
+                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+            }
+        }
+        if (ubatch.equal_seqs) {
+            ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
+            if (seq.seq_id) {
+                ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
+            } else {
+                LM_GGML_ASSERT(seq.n_seq_id == 1);
+                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
+            }
+        } else {
+            // simple split
+            if (batch->n_seq_id) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.n_seq_id = batch->n_seq_id + seq.offset;
+                }
+            } else {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
+                }
+            }
+            if (batch->seq_id) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.seq_id = batch->seq_id + seq.offset;
+                }
+            } else {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
+                }
+            }
+        }
+        if (logits_all) {
+            for (size_t i = 0; i < length; ++i) {
+                ubatch.output[ubatch.n_tokens + i] = 1;
+                out_ids.push_back(ids[seq.offset + i]);
+            }
+        } else if (batch->logits) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    size_t id = ids[seq.offset + i];
+                    int8_t is_output = batch->logits[id];
+                    ubatch.output[ubatch.n_tokens + i] = is_output;
+                    if (is_output) { out_ids.push_back(id); }
+                }
+            } else {
+                // simple split
+                ubatch.output = batch->logits + seq.offset;
+                for (size_t i = 0; i < length; ++i) {
+                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
+                }
+            }
+        } else {
+            // only get last output
+            for (size_t i = 0; i < length; ++i) {
+                size_t id = ids[seq.offset + i];
+                int8_t is_last = id == ids.size() - 1;
+                ubatch.output[ubatch.n_tokens + i] = is_last;
+                if (is_last) { out_ids.push_back(id); }
+            }
+        }
+        if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
+            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
+        }
+        ubatch.n_tokens += length;
+        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
+        seq.offset += length;
+        seq.length -= length;
+        n_tokens -= length;
+        LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
+    }
+
+    // simple split, unknown number of sequences of unequal lengths
+    llama_ubatch split_simple(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        ubatch.equal_seqs = false;
+        if (!seq.empty()) {
+            llama_sbatch_seq & s = seq[0];
+            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+            LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
+            add_seq_to_ubatch(ubatch, s, length);
+        }
+        return ubatch;
+    }
+
+    // make batches of equal-length sequences
+    llama_ubatch split_equal(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        if (!seq.empty()) {
+            size_t length = 0;
+            size_t n_tokens_in_ubatch = 0;
+            LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
+            // smallest first, because it's easier to split this way;
+            // starting from the end to pop in constant time.
+            for (size_t i = seq.size(); i-- > 0;) {
+                llama_sbatch_seq & s = seq[i];
+                LM_GGML_ASSERT(s.length > 0);
+                if (length == 0) {
+                    length = s.length < n_ubatch ? s.length : n_ubatch;
+                }
+                add_seq_to_ubatch(ubatch, s, length);
+                n_tokens_in_ubatch += length;
+                // shared prompts can't be mixed with any of their sequences,
+                // so it's safer to compute them in their own ubatch
+                if (s.n_seq_id > 1) { break; }
+                // stop when there isn't enough space for another sequence
+                if (length + n_tokens_in_ubatch > n_ubatch) { break; }
+            }
+        }
+        return ubatch;
+    }
+
+    // sequence-wise split
+    llama_ubatch split_seq(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        if (!seq.empty()) {
+            llama_sbatch_seq & s = seq[seq.size() - 1];
+            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+            LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
+            add_seq_to_ubatch(ubatch, s, length);
+        }
+        return ubatch;
+    }
+
+    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
+        LM_GGML_ASSERT(batch.n_tokens >= 0);
+        this->batch = &batch;
+        this->n_embd = n_embd;
+        this->logits_all = logits_all;
+
+        n_tokens = batch.n_tokens;
+        ids.resize(n_tokens);
+        out_ids.clear();
+        // TODO: reserve out_ids and seq
+
+        for (size_t i = 0; i < n_tokens; ++i) {
+            ids[i] = i;
+        }
+        if (simple_split) {
+            seq.resize(1);
+            llama_sbatch_seq & s = seq[0];
+            s.n_seq_id = 0;
+            s.seq_id = nullptr;
+            s.offset = 0;
+            s.length = n_tokens;
+            s.all_seq_id = batch.all_seq_id;
+            return;
+        }
+        std::sort(ids.begin(), ids.end(),
+            [&batch](size_t a, size_t b) {
+                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
+                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
+                // sort by seq_id, then by pos
+                if (n_seq_a == n_seq_b) {
+                    if (batch.seq_id) {
+                        for (int32_t i = 0; i < n_seq_a; ++i) {
+                            llama_seq_id seq_id_a = batch.seq_id[a][i];
+                            llama_seq_id seq_id_b = batch.seq_id[b][i];
+                            // smaller seq_ids go first
+                            if (seq_id_a != seq_id_b) {
+                                return seq_id_a < seq_id_b;
+                            }
+                        }
+                    }
+                    // when all else is equal, sort by pos
+                    if (batch.pos) {
+                        return batch.pos[a] < batch.pos[b];
+                    }
+                    // no pos, sort by id (assuming batch.all_pos_1 is positive)
+                    return a < b;
+                }
+                // shared prompts go first
+                return n_seq_a > n_seq_b;
+            }
+        );
+        // init seq
+        llama_sbatch_seq * last_seq = nullptr;
+
+        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
+            for (size_t i = 0; i < n_tokens; ++i) {
+                const size_t bi = ids[i];
+                const int32_t n_seqs = batch.n_seq_id[bi];
+                llama_seq_id * seq_ids = batch.seq_id[bi];
+                if (last_seq != nullptr) {
+                    bool same = n_seqs == last_seq->n_seq_id;
+                    for (int32_t j = 0; same && j < n_seqs; ++j) {
+                        if (seq_ids[j] != last_seq->seq_id[j]) {
+                            same = false;
+                        }
+                    }
+                    if (same) {
+                        last_seq->length += 1;
+                        continue;
+                    }
+                }
+                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
+                seq.push_back(new_seq);
+                last_seq = &seq.back();
+            }
+        } else {
+            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+            seq.push_back(new_seq);
+        }
+        // keep shared prompts first at the end, then sort by length descending.
+        std::sort(seq.begin(), seq.end(),
+            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
+                if (a.n_seq_id == b.n_seq_id) {
+                    return a.length > b.length;
+                }
+                return a.n_seq_id < b.n_seq_id;
+            }
+        );
+    }
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -2735,6 +3087,7 @@ struct llama_context {
 
     struct llama_cparams cparams;
     struct llama_sampling sampling;
+    struct llama_sbatch sbatch;
     struct llama_kv_cache kv_self;
     struct llama_control_vector cvec;
 
@@ -2995,8 +3348,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llama_model_is_recurrent(&model);
     cache.v_trans = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -3009,13 +3361,6 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(kv_size);
 
-    if (cache.recurrent) {
-        // init state copy sources
-        for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].src = i;
-        }
-    }
-
     // count used buffer types
     std::map<lm_ggml_backend_buffer_type_t, int> buft_layer_count;
     if (offload) {
@@ -3083,46 +3428,162 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-
+       const struct llama_ubatch & batch) {
     const uint32_t n_tokens = batch.n_tokens;
+    const uint32_t n_seqs   = batch.n_seqs;
+    const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba),
-        // each
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        // each cache cell can store the state for a whole sequence.
+        // A slot should be always be contiguous.
+
+        // can only process batches with an equal number of new tokens in each sequence
+        LM_GGML_ASSERT(batch.equal_seqs);
+
+        int32_t min = cache.size - 1;
+        int32_t max = 0;
+
+        // everything should fit if all seq_ids are smaller than the max
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const uint32_t n_seq_id = batch.n_seq_id[s];
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[s][j];
+
+                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
+                    // too big seq_id
+                    // TODO: would it be possible to resize the cache instead?
+                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
+                    return false;
+                }
+                if (j > 0) {
+                    llama_kv_cell & seq = cache.cells[seq_id];
+                    if (seq.tail >= 0) {
+                        llama_kv_cell & cell = cache.cells[seq.tail];
+                        // clear cells from seq_ids that become shared
+                        // (should not normally happen, but let's handle it anyway)
+                        cell.seq_id.erase(seq_id);
+                        seq.tail = -1;
+                        if (cell.seq_id.empty()) {
+                            cell.pos = -1;
+                            cell.src = -1;
+                            cache.used -= 1;
+                        }
                 }
-
-
-
-
-
-
+                    }
+                }
+            }
+
+#ifndef NDEBUG
+        {
+            std::vector<int32_t> tails_verif;
+            tails_verif.assign(cache.size, -1);
+            for (uint32_t i = 0; i < cache.size; ++i) {
+                llama_kv_cell & cell = cache.cells[i];
+                for (llama_seq_id seq_id : cell.seq_id) {
+                    if (tails_verif[seq_id] != -1) {
+                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
                 }
-
-
+                    tails_verif[seq_id] = i;
+                }
+            }
+            for (uint32_t i = 0; i < cache.size; ++i) {
+                if (tails_verif[i] != cache.cells[i].tail) {
+                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
+                }
+            }
+        }
+#endif
+
+        // find next empty cell
+        uint32_t next_empty_cell = cache.head;
+
+        for (uint32_t i = 0; i < cache.size; ++i) {
+            if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
+            llama_kv_cell & cell = cache.cells[next_empty_cell];
+            if (cell.is_empty()) { break; }
+            next_empty_cell += 1;
+        }
+
+        // find usable cell range
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
+            llama_kv_cell & seq_meta = cache.cells[seq_id];
+            bool has_cell = false;
+            if (seq_meta.tail >= 0) {
+                llama_kv_cell & cell = cache.cells[seq_meta.tail];
+                LM_GGML_ASSERT(cell.has_seq_id(seq_id));
+                // does this seq_id "own" the cell?
+                if (cell.seq_id.size() == 1) { has_cell = true; }
+            }
+            if (!has_cell) {
+                llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
+                LM_GGML_ASSERT(empty_cell.is_empty());
+                // copy old tail into the empty cell
+                if (seq_meta.tail >= 0) {
+                    llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
+                    empty_cell.pos = orig_cell.pos;
+                    empty_cell.src = orig_cell.src;
+                    orig_cell.seq_id.erase(seq_id);
+                    empty_cell.seq_id.insert(seq_id); // will be overwritten
+                }
+                seq_meta.tail = next_empty_cell;
+                // find next empty cell
+                if (s + 1 < n_seqs) {
+                    next_empty_cell += 1;
+                    for (uint32_t i = 0; i < cache.size; ++i) {
+                        if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
+                        llama_kv_cell & cell = cache.cells[next_empty_cell];
+                        if (cell.is_empty()) { break; }
+                        next_empty_cell += 1;
                 }
-                cache.cells[seq_id].pos = batch.pos[i];
-                // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
-            } else {
-                // too big seq_id
-                // TODO: would it be possible to resize the KV cache size instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                return false;
             }
         }
+            if (min > seq_meta.tail) { min = seq_meta.tail; }
+            if (max < seq_meta.tail) { max = seq_meta.tail; }
+        }
+
+        // gather and re-order
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            int32_t dst_id = s + min;
+            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            if (dst_id != src_id) {
+                llama_kv_cell & dst_cell = cache.cells[dst_id];
+                llama_kv_cell & src_cell = cache.cells[src_id];
+
+                std::swap(dst_cell.pos, src_cell.pos);
+                std::swap(dst_cell.src, src_cell.src);
+                std::swap(dst_cell.seq_id, src_cell.seq_id);
+
+                // swap tails (assuming they NEVER overlap)
+                for (const llama_seq_id seq_id : src_cell.seq_id) {
+                    cache.cells[seq_id].tail = src_id;
+                }
+                for (const llama_seq_id seq_id : dst_cell.seq_id) {
+                    cache.cells[seq_id].tail = dst_id;
+                }
+            }
+        }
+
+        // update the pos of the used seqs
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            int32_t cell_id = s + min;
+            llama_kv_cell & cell = cache.cells[cell_id];
+
+            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+                // What should happen when the pos backtracks or skips a value?
+                // Clearing the state mid-batch would require special-casing which isn't done.
+                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
+                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+            }
+            cell.pos = last_pos;
+            cell.seq_id.clear();
+            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[s][j];
+                cell.seq_id.insert(seq_id);
+                cache.cells[seq_id].tail = cell_id;
+            }
         }
 
         // allow getting the range of used cells, from head to head + n
@@ -3130,7 +3591,7 @@ static bool llama_kv_cache_find_slot(
         cache.n = max - min + 1;
 
         // sanity check
-        return
+        return cache.n >= n_seqs;
     }
     // otherwise, one cell per token.
 
@@ -3168,11 +3629,14 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
-    for (uint32_t
-
+    for (uint32_t s = 0; s < n_seqs; s++) {
+        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
+            uint32_t k = s*n_seq_tokens + i;
+            cache.cells[cache.head + k].pos = batch.pos[k];
 
-
-
+            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            }
         }
     }
 
@@ -3198,6 +3662,8 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
     for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
+        cache.cells[i].src = -1;
+        cache.cells[i].tail = -1;
     }
     cache.head = 0;
     cache.used = 0;
@@ -3224,9 +3690,16 @@ static bool llama_kv_cache_seq_rm(
         return false;
     }
     if (0 <= seq_id) {
-
-        if (
-
+        int32_t & tail_id = cache.cells[seq_id].tail;
+        if (tail_id >= 0) {
+            const llama_kv_cell & cell = cache.cells[tail_id];
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                return false;
+            }
+            if (p0 <= cell.pos && p1 < cell.pos) {
+                tail_id = -1;
+            }
         }
     } else {
         // seq_id is negative, then the range should include everything or nothing
@@ -3250,6 +3723,7 @@ static bool llama_kv_cache_seq_rm(
             if (cache.cells[i].pos >= 0) cache.used--;
 
             cache.cells[i].pos = -1;
+            cache.cells[i].src = -1;
             if (new_head == cache.size) new_head = i;
         }
     }
@@ -3272,23 +3746,29 @@ static void llama_kv_cache_seq_cp(
 
     if (cache.recurrent) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-
-
-
-
-
-
-
-
-
-
-
+            llama_kv_cell & tail_src = cache.cells[seq_id_src];
+            llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
+            if (tail_dst.tail >= 0) {
+                // clear destination seq_id if it wasn't empty
+                llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
+
+                cell_dst.seq_id.erase(seq_id_dst);
+                tail_dst.tail = -1;
+                if (cell_dst.seq_id.empty()) {
+                    cell_dst.pos = -1;
+                    cell_dst.delta = -1;
+                    cell_dst.src = -1;
+                    cache.used -= 1;
+                }
             }
+            if (tail_src.tail >= 0) {
+                llama_kv_cell & cell_src = cache.cells[tail_src.tail];
 
-
-
-
+                cell_src.seq_id.insert(seq_id_dst);
+                tail_dst.tail = tail_src.tail;
+            }
         }
+
         return;
     }
     // otherwise, this is the KV cache of a Transformer-like model
@@ -3306,9 +3786,13 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     uint32_t new_head = cache.size;
 
     for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.recurrent && (llama_seq_id) i != seq_id) {
+            cache.cells[i].tail = -1;
+        }
         if (!cache.cells[i].has_seq_id(seq_id)) {
             if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
+            cache.cells[i].src = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
         } else {
@@ -3337,9 +3821,12 @@ static void llama_kv_cache_seq_add(
     if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-
-            if (
-                cell
+            const int32_t tail_id = cache.cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cache.cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos += delta;
+                }
             }
         }
         return;
@@ -3383,9 +3870,12 @@ static void llama_kv_cache_seq_div(
     if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be changed
        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-
-            if (
-                cell
+            const int32_t tail_id = cache.cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cache.cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos /= d;
+                }
             }
         }
         return;
@@ -3417,7 +3907,9 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
 }
 
 static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
-    cache.
+    if (!cache.recurrent) {
+        cache.do_defrag = true;
+    }
 }
 
 static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
@@ -6124,6 +6616,7 @@ static bool llm_load_tensors(
     const int64_t n_embd_gqa = n_embd_v_gqa;
     const int64_t n_vocab = hparams.n_vocab;
     const int64_t n_vocab_type = hparams.n_vocab_type;
+    const int64_t n_rot = hparams.n_rot;
     const int64_t n_expert = hparams.n_expert;
     const int64_t n_expert_used = hparams.n_expert_used;
     const int64_t n_ctx_train = hparams.n_ctx_train;
@@ -6181,7 +6674,7 @@ static bool llm_load_tensors(
 
                     layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                    layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {
+                    layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
 
                     if (n_expert == 0) {
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
@@ -7634,8 +8127,8 @@ static bool llm_load_tensors(
 
                     layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd +
-                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd +
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa});
 
                     layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
@@ -7712,7 +8205,7 @@ static bool llm_load_tensors(
                     layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
                     layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                    layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {
+                    layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                     layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                     layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
                     layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
@@ -7959,7 +8452,7 @@ static struct lm_ggml_tensor * llm_build_inp_embd(
         struct lm_ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-
+          const llama_ubatch & batch,
          struct lm_ggml_tensor * tok_embd,
         const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
@@ -8393,9 +8886,10 @@ static struct lm_ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        cur = lm_ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias
+        cur = lm_ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                                     hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
             lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
         }
 
@@ -8404,7 +8898,7 @@ static struct lm_ggml_tensor * llm_build_kqv(
         struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
             lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32);
@@ -8508,12 +9002,180 @@ static struct lm_ggml_tensor * llm_build_kv(
     return cur;
 }
 
+static struct lm_ggml_tensor * llm_build_copy_mask_state(
+        struct lm_ggml_context * ctx,
+         struct lm_ggml_cgraph * graph,
+         struct lm_ggml_tensor * s,
+         struct lm_ggml_tensor * state_copy,
+         struct lm_ggml_tensor * state_mask,
+                        int32_t n_state,
+                        int32_t kv_size,
+                        int32_t kv_head,
+                        int32_t n_kv,
+                        int32_t n_seqs) {
+    struct lm_ggml_tensor * states = lm_ggml_reshape_2d(ctx, s, n_state, kv_size);
+
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // this shrinks the tensors's ne[1] to n_kv
+    states = lm_ggml_get_rows(ctx, states, state_copy);
+
+    // clear states of sequences which are starting at the beginning of this batch
+    // FIXME: zero-out NANs?
+    states = lm_ggml_mul(ctx, states, state_mask);
+
+    // copy states which won't be changed further (between n_seqs and n_rs)
+    lm_ggml_build_forward_expand(graph,
+        lm_ggml_cpy(ctx,
+            lm_ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*lm_ggml_element_size(states)),
+            lm_ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*lm_ggml_element_size(s))));
+
+    // the part of the states that will be used and modified
+    return lm_ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
+}
+
+// TODO: split
+static struct lm_ggml_tensor * llm_build_mamba(
+        struct lm_ggml_context * ctx,
+          struct llama_context & lctx,
+            const llama_ubatch & batch,
+         struct lm_ggml_cgraph * graph,
+         struct lm_ggml_tensor * cur,
+         struct lm_ggml_tensor * state_copy,
+         struct lm_ggml_tensor * state_mask,
+                        int32_t kv_head,
+                        int32_t n_kv,
+            const llm_build_cb & cb,
+                            int il) {
+    const llama_model & model = lctx.model;
+    const llama_hparams & hparams = model.hparams;
+    const llama_kv_cache & kv = lctx.kv_self;
+    const int64_t d_conv = hparams.ssm_d_conv;
+    const int64_t d_inner = hparams.ssm_d_inner;
+    const int64_t d_state = hparams.ssm_d_state;
+    const int64_t dt_rank = hparams.ssm_dt_rank;
+    const int64_t n_seqs = batch.n_seqs;
+    // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
+    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+    // Use the same RMS norm as the final layer norm
+    const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+    const int64_t n_seq_tokens = batch.n_seq_tokens;
+
+    LM_GGML_ASSERT(n_seqs != 0);
+    LM_GGML_ASSERT(batch.equal_seqs);
+    LM_GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+
+    struct lm_ggml_tensor * conv_states_all = kv.k_l[il];
+    struct lm_ggml_tensor * ssm_states_all = kv.v_l[il];
+
+    // (ab)using the KV cache to store the states
+    struct lm_ggml_tensor * conv = llm_build_copy_mask_state(ctx,
+            graph, conv_states_all, state_copy, state_mask,
+            hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
+    conv = lm_ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
+    struct lm_ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
+            graph, ssm_states_all, state_copy, state_mask,
+            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
+    ssm = lm_ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs);
+
+    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+    cur = lm_ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+    struct lm_ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
+    // split the above in two
+    // => {d_inner, n_seq_tokens, n_seqs}
+    struct lm_ggml_tensor * x = lm_ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
+    struct lm_ggml_tensor * z = lm_ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*lm_ggml_element_size(xz));
+
+    // conv
+    {
+        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+        struct lm_ggml_tensor * conv_x = lm_ggml_concat(ctx, conv, lm_ggml_transpose(ctx, x), 0);
+
+        // copy last (d_conv - 1) columns back into the state cache
+        struct lm_ggml_tensor * last_conv = lm_ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+        lm_ggml_build_forward_expand(graph,
+            lm_ggml_cpy(ctx, last_conv,
+                lm_ggml_view_1d(ctx, conv_states_all,
+                    (d_conv - 1)*(d_inner)*(n_seqs),
+                    kv_head*(d_conv - 1)*(d_inner)*lm_ggml_element_size(conv_states_all))));
+
+        // 1D convolution
+        // The equivalent is to make a self-overlapping view of conv_x
+        // over d_conv columns at each stride in the 3rd dimension,
+        // then element-wise multiply that with the conv1d weight,
+        // then sum the elements of each row,
+        // (the last two steps are a dot product over rows (also doable with mul_mat))
+        // then permute away the ne[0] dimension,
+        // and then you're left with the resulting x tensor.
+        // For simultaneous sequences, all sequences need to have the same length.
+        x = lm_ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
+
+        // bias
+        x = lm_ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
+
+        x = lm_ggml_silu(ctx, x);
+    }
+
+    // ssm
+    {
+        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+        struct lm_ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
+        // split
+        struct lm_ggml_tensor * dt = lm_ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
+        struct lm_ggml_tensor * B = lm_ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], lm_ggml_element_size(x_db)*dt_rank);
+        struct lm_ggml_tensor * C = lm_ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], lm_ggml_element_size(x_db)*(dt_rank+d_state));
+
+        // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
+        if (ssm_dt_b_c_rms) {
+            dt = lm_ggml_rms_norm(ctx, dt, norm_rms_eps);
+            B = lm_ggml_rms_norm(ctx, B, norm_rms_eps);
+            C = lm_ggml_rms_norm(ctx, C, norm_rms_eps);
+        }
+
+        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+        dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
+        dt = lm_ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
+
+        // Custom operator to optimize the parallel associative scan
+        // as described in the Annex D of the Mamba paper.
+        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+        struct lm_ggml_tensor * y_ssm = lm_ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
+
+        // store last states
+        lm_ggml_build_forward_expand(graph,
+            lm_ggml_cpy(ctx,
+                lm_ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+                lm_ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*lm_ggml_element_size(ssm_states_all))));
+
+        struct lm_ggml_tensor * y = lm_ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
+
+        // TODO: skip computing output earlier for unused tokens
+
+        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
+        y = lm_ggml_add(ctx, y, lm_ggml_mul(ctx, x, model.layers[il].ssm_d));
+        y = lm_ggml_mul(ctx, y, lm_ggml_silu(ctx, lm_ggml_cont(ctx, z)));
+
+        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+        cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
+    }
+
+    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+    cur = lm_ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
+    cb(cur, "mamba_out", il);
+
+    return cur;
+}
+
 struct llm_build_context {
     const llama_model & model;
           llama_context & lctx;
     const llama_hparams & hparams;
     const llama_cparams & cparams;
-    const
+    const llama_ubatch & batch;
     const llama_kv_cache & kv_self;
 
     const int64_t n_embd;
@@ -8559,7 +9221,7 @@ struct llm_build_context {
     // TODO: consider making the entire interface noexcept
     llm_build_context(
         llama_context & lctx,
-    const
+    const llama_ubatch & batch,
     const llm_build_cb & cb,
                   bool worst_case) :
         model            (lctx.model),
@@ -8666,29 +9328,6 @@ struct llm_build_context {
         return gf;
     }
 
-    struct lm_ggml_cgraph * build_s_copy() {
-        struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
-
-        LM_GGML_ASSERT(kv_self.recurrent);
-
-        struct lm_ggml_tensor * state_copy = build_inp_s_copy();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct lm_ggml_tensor * conv_states = lm_ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            struct lm_ggml_tensor * ssm_states = lm_ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
-
-            conv_states = lm_ggml_get_rows(ctx0, conv_states, state_copy);
-            ssm_states = lm_ggml_get_rows(ctx0, ssm_states, state_copy);
-
-            // TODO: name the intermediate tensors with cb()
-
-            lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-            lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
-        }
-
-        return gf;
-    }
-
     struct lm_ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
@@ -8823,7 +9462,7 @@ struct llm_build_context {
     }
 
     struct lm_ggml_tensor * build_inp_s_copy() {
-        lctx.inp_s_copy = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32,
+        lctx.inp_s_copy = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_kv);
         cb(lctx.inp_s_copy, "inp_s_copy", -1);
         lm_ggml_set_input(lctx.inp_s_copy);
         return lctx.inp_s_copy;
@@ -8836,13 +9475,6 @@ struct llm_build_context {
         return lctx.inp_s_mask;
     }
 
-    struct lm_ggml_tensor * build_inp_s_seq() {
-        lctx.inp_s_seq = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_I32, n_kv, n_tokens);
-        cb(lctx.inp_s_seq, "inp_s_seq", -1);
-        lm_ggml_set_input(lctx.inp_s_seq);
-        return lctx.inp_s_seq;
-    }
-
     struct lm_ggml_cgraph * append_pooling(struct lm_ggml_cgraph * gf) {
         // find result_norm tensor for input
         struct lm_ggml_tensor * inp = nullptr;
@@ -12172,136 +12804,31 @@ struct llm_build_context {
     struct lm_ggml_cgraph * build_mamba() {
         struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        const int64_t d_model = n_embd;
-        const int64_t d_conv = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        LM_GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t dt_rank = hparams.ssm_dt_rank;
-        // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
-        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
-        // Use the same RMS norm as the final layer norm
-        const float norm_rms_eps = hparams.f_norm_rms_eps;
-
         struct lm_ggml_tensor * cur;
         struct lm_ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
+        struct lm_ggml_tensor * state_copy = build_inp_s_copy();
         struct lm_ggml_tensor * state_mask = build_inp_s_mask();
-        struct lm_ggml_tensor * state_seq = build_inp_s_seq();
 
         for (int il = 0; il < n_layer; ++il) {
-            // (ab)using the KV cache to store the states
-            struct lm_ggml_tensor * conv_states = lm_ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            struct lm_ggml_tensor * ssm_states = lm_ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
-
-            // clear states of sequences which are starting at the beginning of this batch
-            {
-                conv_states = lm_ggml_mul(ctx0,
-                    lm_ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
-                    state_mask);
-                ssm_states = lm_ggml_mul(ctx0,
-                    lm_ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
-                    state_mask);
-            }
-
-            conv_states = lm_ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
-            ssm_states = lm_ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
-
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-
-
-
-            // => {d_inner, n_tokens}
-            struct lm_ggml_tensor * x = lm_ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
-            struct lm_ggml_tensor * z = lm_ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], lm_ggml_element_size(xz)*d_inner);
-
-            // conv
-            {
-                // Custom operator which is needed only to ease simultaneous sequence processing.
-                // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
-                // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
-                // then element-wise multiply that with the conv1d weigth,
-                // then sum the elements of each row,
-                // (the last two steps are a dot product over rows (also doable with mul_mat))
-                // then permute away the ne[0] dimension,
-                // and then you're left with the resulting x tensor.
-                // The new conv_states is the last (d_conv - 1) columns
-                // of the last 3rd dimensional "layer" of the self-overlapping view.
-                // For simultaneous sequences, it's more complicated.
-                struct lm_ggml_tensor * x_conv = lm_ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
-
-                // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
-                lm_ggml_build_forward_expand(gf,
-                    lm_ggml_cpy(ctx0,
-                        lm_ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*lm_ggml_element_size(x_conv), (1+d_inner*n_tokens)*lm_ggml_element_size(x_conv)),
-                        lm_ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*lm_ggml_element_size(x_conv))));
-
-                // extract x from x_conv
-                x = lm_ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*lm_ggml_element_size(x_conv), 0);
-
-                // bias
-                x = lm_ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
-
-                x = lm_ggml_silu(ctx0, x);
-            }
-
-            // ssm
-            {
-                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
-                struct lm_ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x);
-                // split
-                struct lm_ggml_tensor * dt = lm_ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
-                struct lm_ggml_tensor * B = lm_ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], lm_ggml_element_size(x_db)*dt_rank);
-                struct lm_ggml_tensor * C = lm_ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], lm_ggml_element_size(x_db)*(dt_rank+d_state));
-
-                // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
-                if (ssm_dt_b_c_rms) {
-                    dt = lm_ggml_rms_norm(ctx0, dt, norm_rms_eps);
-                    B = lm_ggml_rms_norm(ctx0, B, norm_rms_eps);
-                    C = lm_ggml_rms_norm(ctx0, C, norm_rms_eps);
-                }
-
-                // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
-                dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
-                dt = lm_ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
-
-                // Custom operator to optimize the parallel associative scan
-                // as described in the Annex D of the Mamba paper.
-                // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
-                // because only a single tensor can be returned.
-                struct lm_ggml_tensor * y_ssm_states = lm_ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
-
-                // store last states (the second part of y_ssm_states)
-                lm_ggml_build_forward_expand(gf,
-                    lm_ggml_cpy(ctx0,
-                        lm_ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*lm_ggml_element_size(y_ssm_states)),
-                        lm_ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*lm_ggml_element_size(ssm_states))));
-
-                struct lm_ggml_tensor * y = lm_ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*lm_ggml_element_size(y_ssm_states), 0);
-
-                if (il == n_layer - 1) {
-                    // skip computing output for unused tokens
-                    struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
-                    x = lm_ggml_get_rows(ctx0, x, inp_out_ids);
-                    y = lm_ggml_get_rows(ctx0, y, inp_out_ids);
-                    z = lm_ggml_get_rows(ctx0, z, inp_out_ids);
-                    inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
-                }
-
-                // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
-                y = lm_ggml_add(ctx0, y, lm_ggml_mul(ctx0, x, model.layers[il].ssm_d));
-                y = lm_ggml_mul(ctx0, y, lm_ggml_silu(ctx0, z));
+            cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
+                    state_copy, state_mask,
+                    kv_head, n_kv, cb, il);
 
-
-
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL = lm_ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
 
             // residual
@@ -14167,8 +14694,8 @@ struct llm_build_context {
 };
 
 static struct lm_ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
-
-    dummy.
+    llama_ubatch dummy = {};
+    dummy.equal_seqs = true;
 
     llm_build_cb cb = [&](struct lm_ggml_tensor * , const char * , int ) { };
 
@@ -14184,8 +14711,8 @@ static struct lm_ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, co
 }
 
 static struct lm_ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
-
-    dummy.
+    llama_ubatch dummy = {};
+    dummy.equal_seqs = true;
 
     llm_build_cb cb = [&](struct lm_ggml_tensor * , const char * , int ) { };
 
@@ -14200,26 +14727,9 @@ static struct lm_ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }
 
-static struct lm_ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
-    llama_batch dummy;
-    dummy.n_tokens = 0;
-
-    llm_build_cb cb = [&](struct lm_ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct lm_ggml_cgraph * result = llm.build_s_copy();
-
-    llm.free();
-
-    return result;
-}
-
 static struct lm_ggml_cgraph * llama_build_graph(
          llama_context & lctx,
-
+    const llama_ubatch & batch,
                   bool worst_case) {
     const auto & model = lctx.model;
 
@@ -14489,7 +14999,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
|
|
14489
14999
|
return relative_bucket;
|
14490
15000
|
}
|
14491
15001
|
|
14492
|
-
static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
15002
|
+
static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
14493
15003
|
//
|
14494
15004
|
// set input data
|
14495
15005
|
//
|
@@ -14528,10 +15038,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14528
15038
|
for (int i = 0; i < n_tokens; ++i) {
|
14529
15039
|
data[i] = i;
|
14530
15040
|
}
|
14531
|
-
} else if (batch.logits) {
|
15041
|
+
} else if (batch.output) {
|
14532
15042
|
int32_t n_outputs = 0;
|
14533
15043
|
for (int i = 0; i < n_tokens; ++i) {
|
14534
|
-
if (batch.logits[i]) {
|
15044
|
+
if (batch.output[i]) {
|
14535
15045
|
data[n_outputs++] = i;
|
14536
15046
|
}
|
14537
15047
|
}
|
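
With llama_ubatch the per-token request flag is the int8_t output array (the counterpart of llama_batch::logits), and the loop above compacts it into the list of row indices that will receive logits. A minimal self-contained sketch of that compaction:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Per-token output flags, as in llama_ubatch::output (1 = logits requested).
    std::vector<int8_t> output = {0, 0, 1, 0, 1, 1};

    // Compact the flags into the indices of the rows that will be written out.
    std::vector<int32_t> out_rows;
    for (int32_t i = 0; i < (int32_t) output.size(); ++i) {
        if (output[i]) {
            out_rows.push_back(i);
        }
    }

    printf("n_outputs = %zu (first row index: %d)\n",
           out_rows.size(), out_rows.empty() ? -1 : out_rows[0]);
    return 0;
}
```
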
@@ -14555,8 +15065,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14555
15065
|
if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
|
14556
15066
|
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
14557
15067
|
if (cparams.causal_attn && !lctx.is_encoding) {
|
14558
|
-
const int64_t n_kv     = kv_self.n;
|
14559
|
-
const int64_t n_tokens = batch.n_tokens;
|
15068
|
+
const int64_t n_kv = kv_self.n;
|
15069
|
+
const int64_t n_tokens = batch.n_tokens;
|
15070
|
+
const int64_t n_seq_tokens = batch.n_seq_tokens;
|
15071
|
+
const int64_t n_seqs = batch.n_seqs;
|
14560
15072
|
|
14561
15073
|
|
14562
15074
|
float * data = nullptr;
|
@@ -14576,32 +15088,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14576
15088
|
// of the correct sequence for each token of the batch.
|
14577
15089
|
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
14578
15090
|
for (int h = 0; h < 1; ++h) {
|
14579
|
-
for (int
|
14580
|
-
const
|
14581
|
-
const llama_seq_id seq_id = batch.seq_id[j][0];
|
15091
|
+
for (int s = 0; s < n_seqs; ++s) {
|
15092
|
+
const llama_seq_id seq_id = batch.seq_id[s][0];
|
14582
15093
|
|
14583
|
-
for (int
|
14584
|
-
|
14585
|
-
|
14586
|
-
|
14587
|
-
|
14588
|
-
if (
|
14589
|
-
f = -
|
15094
|
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
15095
|
+
const llama_pos pos = batch.pos[s*n_seq_tokens + j];
|
15096
|
+
|
15097
|
+
for (int i = 0; i < n_kv; ++i) {
|
15098
|
+
float f;
|
15099
|
+
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
15100
|
+
f = -INFINITY;
|
14590
15101
|
} else {
|
14591
|
-
|
15102
|
+
if (hparams.use_alibi) {
|
15103
|
+
f = -std::abs(kv_self.cells[i].pos - pos);
|
15104
|
+
} else {
|
15105
|
+
f = 0.0f;
|
15106
|
+
}
|
14592
15107
|
}
|
14593
|
-
}
|
14594
15108
|
|
14595
|
-
|
14596
|
-
|
14597
|
-
|
15109
|
+
if (data) {
|
15110
|
+
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
15111
|
+
}
|
14598
15112
|
|
14599
|
-
|
14600
|
-
|
14601
|
-
|
14602
|
-
|
15113
|
+
// may need to cut off old tokens for sliding window
|
15114
|
+
if (data_swa) {
|
15115
|
+
if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
15116
|
+
f = -INFINITY;
|
15117
|
+
}
|
15118
|
+
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
14603
15119
|
}
|
14604
|
-
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
14605
15120
|
}
|
14606
15121
|
}
|
14607
15122
|
}
|
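
The rewritten causal mask iterates sequences first and the tokens of each sequence second, so the mask row of token j in sequence s starts at offset s*(n_kv*n_seq_tokens) + j*n_kv. A cell belonging to another sequence or lying in the future gets -INFINITY; otherwise the value is 0, or minus the position distance when ALiBi is in use. A self-contained sketch with a deliberately simplified cell type (one sequence per cell; everything below is illustrative, not the llama.cpp data structures):

```cpp
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

struct Cell { int pos; int seq_id; };  // simplified KV cell: one sequence per cell

int main() {
    const int n_seqs = 2, n_seq_tokens = 2, n_kv = 4;
    const bool use_alibi = false;

    // Two cached sequences, two cached tokens each.
    std::vector<Cell> cells = { {0, 0}, {1, 0}, {0, 1}, {1, 1} };
    // New tokens: positions per sequence (sequence s, token j), sequence-major.
    std::vector<int> pos    = { 2, 3, 2, 3 };   // [n_seqs * n_seq_tokens]
    std::vector<int> seq_of = { 0, 1 };         // seq_id of each sequence block

    std::vector<float> mask(n_seqs * n_seq_tokens * n_kv, 0.0f);
    for (int s = 0; s < n_seqs; ++s) {
        for (int j = 0; j < n_seq_tokens; ++j) {
            const int p = pos[s*n_seq_tokens + j];
            for (int i = 0; i < n_kv; ++i) {
                float f;
                if (cells[i].seq_id != seq_of[s] || cells[i].pos > p) {
                    f = -INFINITY;  // other sequence, or a future position
                } else {
                    f = use_alibi ? -(float) std::abs(cells[i].pos - p) : 0.0f;
                }
                mask[s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
            }
        }
    }
    printf("mask[0..3] = %.1f %.1f %.1f %.1f\n", mask[0], mask[1], mask[2], mask[3]);
    return 0;
}
```
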
@@ -14623,8 +15138,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14623
15138
|
}
|
14624
15139
|
}
|
14625
15140
|
} else {
|
15141
|
+
const int64_t n_tokens = batch.n_tokens;
|
15142
|
+
const int64_t n_seq_tokens = batch.n_seq_tokens;
|
15143
|
+
const int64_t n_seqs = batch.n_seqs;
|
14626
15144
|
// when using kv cache, the mask needs to match the kv cache size
|
14627
|
-
const int64_t n_tokens = batch.n_tokens;
|
14628
15145
|
const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
|
14629
15146
|
|
14630
15147
|
LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
@@ -14632,27 +15149,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14632
15149
|
float * data = (float *) lctx.inp_KQ_mask->data;
|
14633
15150
|
|
14634
15151
|
for (int h = 0; h < 1; ++h) {
|
14635
|
-
for (int j = 0; j < n_tokens; ++j) {
|
14636
|
-
const llama_seq_id seq_id = batch.seq_id[j][0];
|
14637
|
-
|
14638
|
-
for (int i = 0; i < n_tokens; ++i) {
|
14639
|
-
|
14640
|
-
|
14641
|
-
|
14642
|
-
|
14643
|
-
|
14644
|
-
|
14645
|
-
|
15152
|
+
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
15153
|
+
const llama_seq_id seq_id = batch.seq_id[s1][0];
|
15154
|
+
|
15155
|
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
15156
|
+
const int32_t tj = s1*n_seq_tokens + j;
|
15157
|
+
|
15158
|
+
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
15159
|
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
15160
|
+
const int32_t ti = s0*n_seq_tokens + i;
|
15161
|
+
float f = -INFINITY;
|
15162
|
+
|
15163
|
+
for (int s = 0; s < batch.n_seq_id[s0]; ++s) {
|
15164
|
+
if (batch.seq_id[s0][s] == seq_id) {
|
15165
|
+
if (hparams.use_alibi) {
|
15166
|
+
f = -std::abs(batch.pos[ti] - batch.pos[tj]);
|
15167
|
+
} else {
|
15168
|
+
f = 0.0f;
|
15169
|
+
}
|
15170
|
+
break;
|
15171
|
+
}
|
14646
15172
|
}
|
14647
|
-
|
15173
|
+
|
15174
|
+
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
14648
15175
|
}
|
14649
15176
|
}
|
14650
15177
|
|
14651
|
-
|
14652
|
-
|
14653
|
-
|
14654
|
-
for (int i = n_tokens; i < n_stride; ++i) {
|
14655
|
-
data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
|
15178
|
+
for (int i = n_tokens; i < n_stride; ++i) {
|
15179
|
+
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
15180
|
+
}
|
14656
15181
|
}
|
14657
15182
|
}
|
14658
15183
|
}
|
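
In the branch without a KV cache the mask is built over the ubatch itself: the flattened indices are tj = s1*n_seq_tokens + j and ti = s0*n_seq_tokens + i, a pair of tokens is visible only when they share a sequence id, and the columns from n_tokens up to n_stride are padded with -INFINITY. A compact sketch of that flattening, assuming a single seq_id per sequence block:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_seqs = 2, n_seq_tokens = 3;
    const int n_tokens = n_seqs * n_seq_tokens;
    const int n_stride = 8;  // the mask row may be longer than n_tokens; pad the tail

    std::vector<int> seq_id = { 0, 1 };  // seq_id of each sequence block

    std::vector<float> mask(n_tokens * n_stride, 0.0f);
    for (int s1 = 0; s1 < n_seqs; ++s1) {
        for (int j = 0; j < n_seq_tokens; ++j) {
            const int tj = s1*n_seq_tokens + j;
            for (int s0 = 0; s0 < n_seqs; ++s0) {
                for (int i = 0; i < n_seq_tokens; ++i) {
                    const int ti = s0*n_seq_tokens + i;
                    // visible only within the same sequence
                    mask[tj*n_stride + ti] = (seq_id[s0] == seq_id[s1]) ? 0.0f : -INFINITY;
                }
            }
            for (int i = n_tokens; i < n_stride; ++i) {
                mask[tj*n_stride + i] = -INFINITY;  // padding columns
            }
        }
    }
    printf("row 0, col 3 (other sequence): %.1f\n", mask[0*n_stride + 3]);
    return 0;
}
```
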
@@ -14660,7 +15185,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14660
15185
|
}
|
14661
15186
|
|
14662
15187
|
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
14663
|
-
const int64_t n_tokens = batch.n_tokens;
|
15188
|
+
const int64_t n_tokens = batch.n_tokens;
|
15189
|
+
const int64_t n_seq_tokens = batch.n_seq_tokens;
|
15190
|
+
const int64_t n_seqs = batch.n_seqs;
|
14664
15191
|
|
14665
15192
|
LM_GGML_ASSERT(lctx.inp_mean);
|
14666
15193
|
LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
|
@@ -14669,12 +15196,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14669
15196
|
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * lm_ggml_element_size(lctx.inp_mean));
|
14670
15197
|
|
14671
15198
|
std::vector<uint64_t> sum(n_tokens, 0);
|
14672
|
-
for (int i = 0; i < n_tokens; ++i) {
|
14673
|
-
const llama_seq_id seq_id = batch.seq_id[i][0];
|
14674
15199
|
|
15200
|
+
for (int s = 0; s < n_seqs; ++s) {
|
15201
|
+
const llama_seq_id seq_id = batch.seq_id[s][0];
|
15202
|
+
|
15203
|
+
// TODO: adapt limits to n_seqs when batch.equal_seqs is true
|
14675
15204
|
LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
|
14676
15205
|
|
14677
|
-
sum[seq_id] += 1;
|
15206
|
+
sum[seq_id] += batch.n_seq_tokens;
|
14678
15207
|
}
|
14679
15208
|
|
14680
15209
|
std::vector<float> div(n_tokens, 0.0f);
|
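
For MEAN pooling every sequence of the ubatch contributes batch.n_seq_tokens rows, so the pooling matrix simply stores 1/count in the columns that belong to that sequence. A small standalone sketch of how the per-sequence sums and divisors could be derived (single seq_id per sequence assumed; this is only the indexing idea, not the real tensors):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_seqs = 2, n_seq_tokens = 3;
    const int n_tokens = n_seqs * n_seq_tokens;
    std::vector<int> seq_id = { 0, 1 };  // first seq_id of each sequence block

    // how many rows each sequence contributes
    std::vector<uint64_t> sum(n_tokens, 0);
    for (int s = 0; s < n_seqs; ++s) {
        sum[seq_id[s]] += n_seq_tokens;
    }

    // mean-pooling weight = 1/count, broadcast over that sequence's rows
    std::vector<float> div(n_tokens, 0.0f);
    for (int i = 0; i < n_tokens; ++i) {
        if (sum[i] > 0) {
            div[i] = 1.0f / float(sum[i]);
        }
    }

    std::vector<float> inp_mean(n_tokens * n_tokens, 0.0f);
    for (int s = 0; s < n_seqs; ++s) {
        for (int i = 0; i < n_seq_tokens; ++i) {
            inp_mean[seq_id[s]*n_tokens + s*n_seq_tokens + i] = div[seq_id[s]];
        }
    }
    printf("weight for seq 0 rows: %.3f\n", inp_mean[0]);
    return 0;
}
```
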
@@ -14685,14 +15214,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14685
15214
|
}
|
14686
15215
|
}
|
14687
15216
|
|
14688
|
-
for (int i = 0; i < n_tokens; ++i) {
|
14689
|
-
const llama_seq_id seq_id = batch.seq_id[i][0];
|
14690
|
-
|
15217
|
+
for (int s = 0; s < n_seqs; ++s) {
|
15218
|
+
const llama_seq_id seq_id = batch.seq_id[s][0];
|
15219
|
+
|
15220
|
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
15221
|
+
data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
|
15222
|
+
}
|
14691
15223
|
}
|
14692
15224
|
}
|
14693
15225
|
|
14694
15226
|
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
14695
|
-
const int64_t n_tokens = batch.n_tokens;
|
15227
|
+
const int64_t n_tokens = batch.n_tokens;
|
15228
|
+
const int64_t n_seq_tokens = batch.n_seq_tokens;
|
15229
|
+
const int64_t n_seqs = batch.n_seqs;
|
14696
15230
|
|
14697
15231
|
LM_GGML_ASSERT(lctx.inp_cls);
|
14698
15232
|
LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
@@ -14700,20 +15234,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14700
15234
|
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
|
14701
15235
|
memset(lctx.inp_cls->data, 0, n_tokens * lm_ggml_element_size(lctx.inp_cls));
|
14702
15236
|
|
14703
|
-
for (int i = 0; i < n_tokens; ++i) {
|
14704
|
-
const llama_seq_id seq_id = batch.seq_id[i][0];
|
14705
|
-
const llama_pos pos = batch.pos[i];
|
15237
|
+
for (int s = 0; s < n_seqs; ++s) {
|
15238
|
+
const llama_seq_id seq_id = batch.seq_id[s][0];
|
14706
15239
|
|
15240
|
+
// TODO: adapt limits to n_seqs when batch.equal_seqs is true
|
14707
15241
|
LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
|
14708
15242
|
|
14709
|
-
|
14710
|
-
|
15243
|
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
15244
|
+
const llama_pos pos = batch.pos[s*n_seq_tokens + i];
|
15245
|
+
|
15246
|
+
if (pos == 0) {
|
15247
|
+
data[seq_id] = s*n_seq_tokens + i;
|
15248
|
+
}
|
14711
15249
|
}
|
14712
15250
|
}
|
14713
15251
|
}
|
14714
15252
|
|
14715
15253
|
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
|
14716
|
-
const int64_t n_tokens = batch.n_tokens;
|
15254
|
+
const int64_t n_tokens = batch.n_tokens;
|
15255
|
+
const int64_t n_seq_tokens = batch.n_seq_tokens;
|
15256
|
+
const int64_t n_seqs = batch.n_seqs;
|
14717
15257
|
|
14718
15258
|
LM_GGML_ASSERT(lctx.inp_cls);
|
14719
15259
|
LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
@@ -14724,15 +15264,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14724
15264
|
std::vector<int> last_pos(n_tokens, -1);
|
14725
15265
|
std::vector<int> last_row(n_tokens, -1);
|
14726
15266
|
|
14727
|
-
for (int i = 0; i < n_tokens; ++i) {
|
14728
|
-
const llama_seq_id seq_id = batch.seq_id[i][0];
|
14729
|
-
const llama_pos pos = batch.pos[i];
|
15267
|
+
for (int s = 0; s < n_seqs; ++s) {
|
15268
|
+
const llama_seq_id seq_id = batch.seq_id[s][0];
|
14730
15269
|
|
15270
|
+
// TODO: adapt limits to n_seqs when batch.equal_seqs is true
|
14731
15271
|
LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
|
14732
15272
|
|
14733
|
-
|
14734
|
-
|
14735
|
-
|
15273
|
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
15274
|
+
const llama_pos pos = batch.pos[s*n_seq_tokens + i];
|
15275
|
+
|
15276
|
+
if (pos >= last_pos[seq_id]) {
|
15277
|
+
last_pos[seq_id] = pos;
|
15278
|
+
last_row[seq_id] = s*n_seq_tokens + i;
|
15279
|
+
}
|
14736
15280
|
}
|
14737
15281
|
}
|
14738
15282
|
|
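
CLS pooling picks, per sequence, the row whose position is 0, while LAST pooling keeps updating towards the row with the greatest position; with the ubatch layout the row of token i in sequence s is simply s*n_seq_tokens + i. A short sketch of both selections under the same single-seq_id-per-sequence assumption:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int n_seqs = 2, n_seq_tokens = 3;
    const int n_tokens = n_seqs * n_seq_tokens;

    std::vector<int> seq_id = { 0, 1 };
    std::vector<int> pos    = { 0, 1, 2,   5, 6, 7 };  // positions, sequence-major

    std::vector<int> cls_row(n_tokens, -1);
    std::vector<int> last_pos(n_tokens, -1);
    std::vector<int> last_row(n_tokens, -1);

    for (int s = 0; s < n_seqs; ++s) {
        const int sid = seq_id[s];
        for (int i = 0; i < n_seq_tokens; ++i) {
            const int p   = pos[s*n_seq_tokens + i];
            const int row = s*n_seq_tokens + i;
            if (p == 0) {
                cls_row[sid] = row;        // CLS: the row at position 0
            }
            if (p >= last_pos[sid]) {
                last_pos[sid] = p;         // LAST: the row with the highest position
                last_row[sid] = row;
            }
        }
    }
    printf("seq 0: cls row %d, last row %d\n", cls_row[0], last_row[0]);
    printf("seq 1: cls row %d, last row %d\n", cls_row[1], last_row[1]);
    return 0;
}
```
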
@@ -14750,41 +15294,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14750
15294
|
LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
|
14751
15295
|
float * data = (float *) lctx.inp_s_mask->data;
|
14752
15296
|
|
14753
|
-
//
|
15297
|
+
// clear unused states
|
14754
15298
|
for (int i = 0; i < n_kv; ++i) {
|
14755
|
-
|
14756
|
-
llama_kv_cell & kv_cell
|
14757
|
-
bool has_self_seq = kv_cell.has_seq_id(seq_id);
|
15299
|
+
uint32_t cell_id = i + kv_self.head;
|
15300
|
+
llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
|
14758
15301
|
|
14759
|
-
data[i] = (float) has_self_seq;
|
15302
|
+
data[i] = (float) (kv_cell.src >= 0);
|
14760
15303
|
|
14761
|
-
//
|
14762
|
-
if (
|
14763
|
-
kv_cell.
|
15304
|
+
// only clear once
|
15305
|
+
if (kv_cell.src < 0) {
|
15306
|
+
kv_cell.src = cell_id;
|
14764
15307
|
}
|
14765
15308
|
}
|
14766
15309
|
}
|
14767
|
-
// For Mamba (and other recurrent architectures),
|
14768
|
-
// update the correct state(s)/sequence(s) for each token of the batch.
|
14769
|
-
// Like with the KQ_mask, if a token in the batch has multiple sequences,
|
14770
|
-
// they are assumed to be equivalent (not here, but in lm_ggml_ssm_scan and lm_ggml_ssm_conv).
|
14771
|
-
if (lctx.inp_s_seq) {
|
14772
|
-
const int64_t n_tokens = batch.n_tokens;
|
14773
15310
|
|
14774
|
-
|
14775
|
-
|
15311
|
+
if (lctx.inp_s_copy) {
|
15312
|
+
LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
|
15313
|
+
int32_t * data = (int32_t *) lctx.inp_s_copy->data;
|
14776
15314
|
|
14777
|
-
|
14778
|
-
|
14779
|
-
|
15315
|
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
15316
|
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
15317
|
+
const uint32_t cell_id = i + kv_self.head;
|
15318
|
+
llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
|
14780
15319
|
|
14781
|
-
|
14782
|
-
|
14783
|
-
|
14784
|
-
|
14785
|
-
|
14786
|
-
|
14787
|
-
|
15320
|
+
// prevent out-of-bound sources
|
15321
|
+
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
|
15322
|
+
kv_cell.src = cell_id;
|
15323
|
+
}
|
15324
|
+
|
15325
|
+
data[i] = kv_cell.src;
|
15326
|
+
|
15327
|
+
// ensure copy only happens once
|
15328
|
+
if (kv_cell.src != (int32_t) cell_id) {
|
15329
|
+
kv_cell.src = cell_id;
|
14788
15330
|
}
|
14789
15331
|
}
|
14790
15332
|
}
|
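
For recurrent models the state-copy sources are now fed through inp_s_copy instead of a separate s_copy graph: each cell in the [head, head + n_kv) window reports where its state should be copied from, an out-of-range source is clamped to the cell itself (turning the copy into a no-op), and src is reset afterwards so the copy is applied only once. A self-contained sketch of that bookkeeping with a stripped-down cell type (illustrative only):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct Cell { int32_t src = -1; };  // simplified recurrent KV cell

int main() {
    const uint32_t head = 2, n_kv = 3, kv_size = 8;
    std::vector<Cell> cells(kv_size);
    cells[2].src = 5;   // state copied from cell 5
    cells[3].src = -1;  // cleared state: copy from itself
    cells[4].src = 42;  // out of range: clamp to itself

    std::vector<int32_t> s_copy(n_kv);
    for (uint32_t i = 0; i < n_kv; ++i) {
        const uint32_t cell_id = i + head;
        Cell & cell = cells[cell_id];

        // prevent out-of-bound sources
        if (cell.src < 0 || (uint32_t) cell.src >= kv_size) {
            cell.src = (int32_t) cell_id;
        }

        s_copy[i] = cell.src;

        // ensure the copy only happens once
        if (cell.src != (int32_t) cell_id) {
            cell.src = (int32_t) cell_id;
        }
    }
    printf("s_copy = %d %d %d\n", s_copy[0], s_copy[1], s_copy[2]);
    return 0;
}
```
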
@@ -14794,6 +15336,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14794
15336
|
const int64_t n_tokens = batch.n_tokens;
|
14795
15337
|
|
14796
15338
|
LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
|
15339
|
+
LM_GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
|
14797
15340
|
|
14798
15341
|
int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
|
14799
15342
|
|
@@ -14829,6 +15372,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|
14829
15372
|
const int64_t n_tokens = batch.n_tokens;
|
14830
15373
|
|
14831
15374
|
LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
|
15375
|
+
LM_GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
|
14832
15376
|
|
14833
15377
|
float * data = (float *) lctx.inp_KQ_mask_cross->data;
|
14834
15378
|
|
@@ -14922,6 +15466,43 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
|
|
14922
15466
|
return n_outputs_max;
|
14923
15467
|
}
|
14924
15468
|
|
15469
|
+
// make the outputs have the same order they had in the user-provided batch
|
15470
|
+
static void llama_output_reorder(struct llama_context * ctx) {
|
15471
|
+
std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
|
15472
|
+
if (!out_ids.empty()) {
|
15473
|
+
uint32_t n_vocab = ctx->model.hparams.n_vocab;
|
15474
|
+
uint32_t n_embd = ctx->model.hparams.n_embd;
|
15475
|
+
int32_t n_outputs = ctx->n_outputs;
|
15476
|
+
LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
15477
|
+
// TODO: is there something more efficient which also minimizes swaps?
|
15478
|
+
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
15479
|
+
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
15480
|
+
int32_t j_min = i;
|
15481
|
+
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
15482
|
+
if (out_ids[j] < out_ids[j_min]) {
|
15483
|
+
j_min = j;
|
15484
|
+
}
|
15485
|
+
}
|
15486
|
+
if (j_min == i) { continue; }
|
15487
|
+
std::swap(out_ids[i], out_ids[j_min]);
|
15488
|
+
if (ctx->logits_size > 0) {
|
15489
|
+
for (uint32_t k = 0; k < n_vocab; k++) {
|
15490
|
+
std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
|
15491
|
+
}
|
15492
|
+
}
|
15493
|
+
if (ctx->embd_size > 0) {
|
15494
|
+
for (uint32_t k = 0; k < n_embd; k++) {
|
15495
|
+
std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
|
15496
|
+
}
|
15497
|
+
}
|
15498
|
+
}
|
15499
|
+
std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
|
15500
|
+
for (int32_t i = 0; i < n_outputs; ++i) {
|
15501
|
+
ctx->output_ids[out_ids[i]] = i;
|
15502
|
+
}
|
15503
|
+
out_ids.clear();
|
15504
|
+
}
|
15505
|
+
}
|
14925
15506
|
|
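
llama_output_reorder above undoes the permutation that sequence-aware splitting applied to the outputs. Selection sort is a reasonable fit here because every swap moves an entire logits (or embedding) row, so the number of swaps matters more than the number of comparisons. The same idea on a bare logits buffer:

```cpp
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const int n_vocab = 2;
    // out_ids[i] = position in the user's batch that produced output row i
    std::vector<size_t> out_ids = { 2, 0, 1 };
    std::vector<float>  logits  = { 2.0f, 2.1f,   0.0f, 0.1f,   1.0f, 1.1f };

    const int n_outputs = (int) out_ids.size();
    for (int i = 0; i < n_outputs - 1; ++i) {
        int j_min = i;
        for (int j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) { j_min = j; }
        }
        if (j_min == i) { continue; }
        std::swap(out_ids[i], out_ids[j_min]);
        for (int k = 0; k < n_vocab; ++k) {      // swap whole rows, not single elements
            std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
        }
    }

    // rows are now back in user-batch order
    for (int i = 0; i < n_outputs; ++i) {
        printf("row %d: %.1f %.1f\n", i, logits[i*n_vocab + 0], logits[i*n_vocab + 1]);
    }
    return 0;
}
```
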
14926
15507
|
static void llama_graph_compute(
|
14927
15508
|
llama_context & lctx,
|
@@ -14994,15 +15575,11 @@ static int llama_decode_internal(
|
|
14994
15575
|
|
14995
15576
|
const auto n_ubatch = cparams.n_ubatch;
|
14996
15577
|
|
14997
|
-
// TODO: simplify or deprecate
|
14998
|
-
std::vector<llama_pos> pos;
|
14999
|
-
std::vector<int32_t> n_seq_id;
|
15000
|
-
std::vector<llama_seq_id *> seq_id_arr;
|
15001
|
-
std::vector<std::vector<llama_seq_id>> seq_id;
|
15002
|
-
|
15003
15578
|
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
15004
15579
|
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
15005
15580
|
|
15581
|
+
lctx.embd_seq.clear();
|
15582
|
+
|
15006
15583
|
// count outputs
|
15007
15584
|
if (batch_all.logits && !embd_pooled) {
|
15008
15585
|
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
@@ -15015,55 +15592,42 @@ static int llama_decode_internal(
|
|
15015
15592
|
n_outputs = 1;
|
15016
15593
|
}
|
15017
15594
|
|
15595
|
+
lctx.sbatch.from_batch(batch_all, n_embd,
|
15596
|
+
/* simple_split */ !kv_self.recurrent,
|
15597
|
+
/* logits_all */ n_outputs == n_tokens_all);
|
15598
|
+
|
15018
15599
|
// reserve output buffer
|
15019
15600
|
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
15020
15601
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
|
15021
15602
|
return -2;
|
15022
15603
|
};
|
15023
15604
|
|
15024
|
-
|
15025
|
-
|
15026
|
-
|
15027
|
-
|
15028
|
-
|
15029
|
-
|
15605
|
+
while (lctx.sbatch.n_tokens > 0) {
|
15606
|
+
llama_ubatch ubatch;
|
15607
|
+
if (kv_self.recurrent) {
|
15608
|
+
if (embd_pooled) {
|
15609
|
+
// Pooled embeddings cannot be split across ubatches (yet)
|
15610
|
+
ubatch = lctx.sbatch.split_seq(n_ubatch);
|
15611
|
+
} else {
|
15612
|
+
// recurrent model architectures are easier to implement
|
15613
|
+
// with equal-length sequences
|
15614
|
+
ubatch = lctx.sbatch.split_equal(n_ubatch);
|
15030
15615
|
}
|
15616
|
+
} else {
|
15617
|
+
ubatch = lctx.sbatch.split_simple(n_ubatch);
|
15031
15618
|
}
|
15032
|
-
|
15033
|
-
for (uint32_t i = 0; i < n_outputs; ++i) {
|
15034
|
-
lctx.output_ids[i] = i;
|
15035
|
-
}
|
15036
|
-
}
|
15037
|
-
|
15038
|
-
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
|
15039
|
-
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
15040
|
-
llama_batch u_batch = {
|
15041
|
-
/* .n_tokens = */ (int32_t) n_tokens,
|
15042
|
-
/* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
|
15043
|
-
/* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr,
|
15044
|
-
/* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr,
|
15045
|
-
/* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr,
|
15046
|
-
/* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr,
|
15047
|
-
/* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr,
|
15048
|
-
/* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
|
15049
|
-
/* .all_pos_1 = */ batch_all.all_pos_1,
|
15050
|
-
/* .all_seq_id = */ batch_all.all_seq_id,
|
15051
|
-
};
|
15619
|
+
const uint32_t n_tokens = ubatch.n_tokens;
|
15052
15620
|
|
15053
15621
|
// count the outputs in this u_batch
|
15054
15622
|
{
|
15055
15623
|
int32_t n_outputs_new = 0;
|
15056
15624
|
|
15057
|
-
if (u_batch.logits) {
|
15058
|
-
for (uint32_t i = 0; i < n_tokens; i++) {
|
15059
|
-
n_outputs_new += u_batch.logits[i] != 0;
|
15060
|
-
}
|
15061
|
-
} else if (n_outputs == n_tokens_all) {
|
15625
|
+
if (n_outputs == n_tokens_all) {
|
15062
15626
|
n_outputs_new = n_tokens;
|
15063
15627
|
} else {
|
15064
|
-
|
15065
|
-
|
15066
|
-
n_outputs_new
|
15628
|
+
LM_GGML_ASSERT(ubatch.output);
|
15629
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
15630
|
+
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
15067
15631
|
}
|
15068
15632
|
}
|
15069
15633
|
|
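
Decoding now drains lctx.sbatch instead of slicing the batch into fixed windows: split_simple keeps the original token order, while recurrent models request split_equal (or split_seq for pooled embeddings) so that a ubatch holds n_seqs sequences of the same length. The sketch below illustrates only the grouping principle with hypothetical types; every sequence is simply truncated to the shortest remaining length, whereas the real llama_sbatch logic also handles sorting, output flags and embeddings:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

struct Tok { int seq_id; int pos; };

int main() {
    // a mixed batch: sequence 0 has 3 tokens, sequences 1 and 2 have 2 each
    std::vector<Tok> batch = {
        {0, 0}, {1, 0}, {0, 1}, {2, 0}, {1, 1}, {0, 2}, {2, 1},
    };

    // group token indices by sequence
    std::vector<std::vector<int>> groups;
    for (int i = 0; i < (int) batch.size(); ++i) {
        int g = -1;
        for (int j = 0; j < (int) groups.size(); ++j) {
            if (batch[groups[j][0]].seq_id == batch[i].seq_id) { g = j; break; }
        }
        if (g < 0) { groups.push_back({}); g = (int) groups.size() - 1; }
        groups[g].push_back(i);
    }

    // one "equal split": take the same number of tokens from every sequence
    size_t n_seq_tokens = groups[0].size();
    for (const auto & g : groups) { n_seq_tokens = std::min(n_seq_tokens, g.size()); }

    printf("ubatch: %zu sequences x %zu tokens each\n", groups.size(), n_seq_tokens);
    for (const auto & g : groups) {
        for (size_t j = 0; j < n_seq_tokens; ++j) {
            printf("  seq %d pos %d\n", batch[g[j]].seq_id, batch[g[j]].pos);
        }
    }
    return 0;
}
```
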
@@ -15074,32 +15638,6 @@ static int llama_decode_internal(
|
|
15074
15638
|
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
15075
15639
|
LM_GGML_ASSERT(n_threads > 0);
|
15076
15640
|
|
15077
|
-
// helpers for smoother batch API transition
|
15078
|
-
// after deprecating the llama_eval calls, these will be removed
|
15079
|
-
if (u_batch.pos == nullptr) {
|
15080
|
-
pos.resize(n_tokens);
|
15081
|
-
for (uint32_t i = 0; i < n_tokens; i++) {
|
15082
|
-
pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
|
15083
|
-
}
|
15084
|
-
|
15085
|
-
u_batch.pos = pos.data();
|
15086
|
-
}
|
15087
|
-
|
15088
|
-
if (u_batch.seq_id == nullptr) {
|
15089
|
-
n_seq_id.resize(n_tokens);
|
15090
|
-
seq_id.resize(n_tokens);
|
15091
|
-
seq_id_arr.resize(n_tokens);
|
15092
|
-
for (uint32_t i = 0; i < n_tokens; i++) {
|
15093
|
-
n_seq_id[i] = 1;
|
15094
|
-
seq_id[i].resize(1);
|
15095
|
-
seq_id[i][0] = u_batch.all_seq_id;
|
15096
|
-
seq_id_arr[i] = seq_id[i].data();
|
15097
|
-
}
|
15098
|
-
|
15099
|
-
u_batch.n_seq_id = n_seq_id.data();
|
15100
|
-
u_batch.seq_id = seq_id_arr.data();
|
15101
|
-
}
|
15102
|
-
|
15103
15641
|
// non-causal masks do not use the KV cache
|
15104
15642
|
if (hparams.causal_attn) {
|
15105
15643
|
llama_kv_cache_update(&lctx);
|
@@ -15110,7 +15648,7 @@ static int llama_decode_internal(
|
|
15110
15648
|
kv_self.head = 0;
|
15111
15649
|
}
|
15112
15650
|
|
15113
|
-
if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
|
15651
|
+
if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
|
15114
15652
|
return 1;
|
15115
15653
|
}
|
15116
15654
|
|
@@ -15129,7 +15667,7 @@ static int llama_decode_internal(
|
|
15129
15667
|
lm_ggml_backend_sched_reset(lctx.sched);
|
15130
15668
|
lm_ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
15131
15669
|
|
15132
|
-
lm_ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
|
15670
|
+
lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
|
15133
15671
|
|
15134
15672
|
// the output is always the last tensor in the graph
|
15135
15673
|
struct lm_ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
@@ -15157,7 +15695,7 @@ static int llama_decode_internal(
|
|
15157
15695
|
|
15158
15696
|
lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
15159
15697
|
|
15160
|
-
llama_set_inputs(lctx, u_batch);
|
15698
|
+
llama_set_inputs(lctx, ubatch);
|
15161
15699
|
|
15162
15700
|
llama_graph_compute(lctx, gf, n_threads);
|
15163
15701
|
|
@@ -15215,12 +15753,11 @@ static int llama_decode_internal(
|
|
15215
15753
|
case LLAMA_POOLING_TYPE_CLS:
|
15216
15754
|
case LLAMA_POOLING_TYPE_LAST:
|
15217
15755
|
{
|
15218
|
-
// extract sequence embeddings
|
15756
|
+
// extract sequence embeddings (cleared before processing each batch)
|
15219
15757
|
auto & embd_seq_out = lctx.embd_seq;
|
15220
|
-
embd_seq_out.clear();
|
15221
15758
|
|
15222
|
-
for (uint32_t i = 0; i < n_tokens; i++) {
|
15223
|
-
const llama_seq_id seq_id = u_batch.seq_id[i][0];
|
15759
|
+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
15760
|
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
15224
15761
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
15225
15762
|
continue;
|
15226
15763
|
}
|
@@ -15237,6 +15774,25 @@ static int llama_decode_internal(
|
|
15237
15774
|
n_outputs_prev += lctx.n_outputs;
|
15238
15775
|
}
|
15239
15776
|
|
15777
|
+
// set output mappings
|
15778
|
+
{
|
15779
|
+
bool sorted_output = true;
|
15780
|
+
|
15781
|
+
LM_GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
|
15782
|
+
|
15783
|
+
for (size_t i = 0; i < n_outputs; ++i) {
|
15784
|
+
size_t out_id = lctx.sbatch.out_ids[i];
|
15785
|
+
lctx.output_ids[out_id] = i;
|
15786
|
+
if (out_id != i) {
|
15787
|
+
sorted_output = false;
|
15788
|
+
}
|
15789
|
+
}
|
15790
|
+
|
15791
|
+
if (sorted_output) {
|
15792
|
+
lctx.sbatch.out_ids.clear();
|
15793
|
+
}
|
15794
|
+
}
|
15795
|
+
|
15240
15796
|
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
15241
15797
|
lctx.n_outputs = n_outputs;
|
15242
15798
|
|
@@ -15301,11 +15857,9 @@ static int llama_encode_internal(
|
|
15301
15857
|
|
15302
15858
|
const int64_t n_embd = hparams.n_embd;
|
15303
15859
|
|
15304
|
-
|
15305
|
-
|
15306
|
-
|
15307
|
-
std::vector<llama_seq_id *> seq_id_arr;
|
15308
|
-
std::vector<std::vector<llama_seq_id>> seq_id;
|
15860
|
+
lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
15861
|
+
|
15862
|
+
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
|
15309
15863
|
|
15310
15864
|
// reserve output buffer
|
15311
15865
|
if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
|
@@ -15323,36 +15877,10 @@ static int llama_encode_internal(
|
|
15323
15877
|
const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
15324
15878
|
LM_GGML_ASSERT(n_threads > 0);
|
15325
15879
|
|
15326
|
-
// helpers for smoother batch API transition
|
15327
|
-
// after deprecating the llama_eval calls, these will be removed
|
15328
|
-
if (batch.pos == nullptr) {
|
15329
|
-
pos.resize(n_tokens);
|
15330
|
-
for (uint32_t i = 0; i < n_tokens; i++) {
|
15331
|
-
pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
|
15332
|
-
}
|
15333
|
-
|
15334
|
-
batch.pos = pos.data();
|
15335
|
-
}
|
15336
|
-
|
15337
|
-
if (batch.seq_id == nullptr) {
|
15338
|
-
n_seq_id.resize(n_tokens);
|
15339
|
-
seq_id.resize(n_tokens);
|
15340
|
-
seq_id_arr.resize(n_tokens);
|
15341
|
-
for (uint32_t i = 0; i < n_tokens; i++) {
|
15342
|
-
n_seq_id[i] = 1;
|
15343
|
-
seq_id[i].resize(1);
|
15344
|
-
seq_id[i][0] = batch.all_seq_id;
|
15345
|
-
seq_id_arr[i] = seq_id[i].data();
|
15346
|
-
}
|
15347
|
-
|
15348
|
-
batch.n_seq_id = n_seq_id.data();
|
15349
|
-
batch.seq_id = seq_id_arr.data();
|
15350
|
-
}
|
15351
|
-
|
15352
15880
|
lm_ggml_backend_sched_reset(lctx.sched);
|
15353
15881
|
lm_ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
15354
15882
|
|
15355
|
-
lm_ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
|
15883
|
+
lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
|
15356
15884
|
|
15357
15885
|
// the output embeddings after the final encoder normalization
|
15358
15886
|
struct lm_ggml_tensor * embd = nullptr;
|
@@ -15376,7 +15904,7 @@ static int llama_encode_internal(
|
|
15376
15904
|
|
15377
15905
|
lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
15378
15906
|
|
15379
|
-
llama_set_inputs(lctx, batch);
|
15907
|
+
llama_set_inputs(lctx, ubatch);
|
15380
15908
|
|
15381
15909
|
llama_graph_compute(lctx, gf, n_threads);
|
15382
15910
|
|
@@ -15390,12 +15918,13 @@ static int llama_encode_internal(
|
|
15390
15918
|
float * embd_out = lctx.embd_enc.data();
|
15391
15919
|
|
15392
15920
|
lm_ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
15921
|
+
LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
15393
15922
|
|
15394
15923
|
// remember the sequence ids used during the encoding - needed for cross attention later
|
15395
15924
|
lctx.seq_ids_enc.resize(n_tokens);
|
15396
15925
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
15397
|
-
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
15398
|
-
llama_seq_id seq_id = batch.seq_id[i][s];
|
15926
|
+
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
|
15927
|
+
llama_seq_id seq_id = ubatch.seq_id[i][s];
|
15399
15928
|
lctx.seq_ids_enc[i].insert(seq_id);
|
15400
15929
|
}
|
15401
15930
|
}
|
@@ -15420,8 +15949,10 @@ static int llama_encode_internal(
|
|
15420
15949
|
auto & embd_seq_out = lctx.embd_seq;
|
15421
15950
|
embd_seq_out.clear();
|
15422
15951
|
|
15952
|
+
LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
15953
|
+
|
15423
15954
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
15424
|
-
const llama_seq_id seq_id = batch.seq_id[i][0];
|
15955
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
15425
15956
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
15426
15957
|
continue;
|
15427
15958
|
}
|
@@ -15699,32 +16230,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
|
15699
16230
|
}
|
15700
16231
|
}
|
15701
16232
|
|
15702
|
-
if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
|
15703
|
-
{
|
15704
|
-
lm_ggml_backend_sched_reset(lctx.sched);
|
15705
|
-
|
15706
|
-
lm_ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
|
15707
|
-
|
15708
|
-
lm_ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
15709
|
-
|
15710
|
-
llama_set_s_copy(lctx);
|
15711
|
-
|
15712
|
-
llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
|
15713
|
-
|
15714
|
-
need_reserve = true;
|
15715
|
-
}
|
15716
|
-
|
15717
|
-
{
|
15718
|
-
auto & kv_self = lctx.kv_self;
|
15719
|
-
|
15720
|
-
kv_self.do_copy = false;
|
15721
|
-
|
15722
|
-
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
15723
|
-
kv_self.cells[i].src = i;
|
15724
|
-
}
|
15725
|
-
}
|
15726
|
-
}
|
15727
|
-
|
15728
16233
|
// defragment the KV cache if needed
|
15729
16234
|
if (lctx.kv_self.do_defrag) {
|
15730
16235
|
llama_kv_cache_defrag_internal(lctx);
|
@@ -15738,10 +16243,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
|
15738
16243
|
if (need_reserve) {
|
15739
16244
|
// TODO: extract to a function
|
15740
16245
|
// build worst-case graph
|
15741
|
-
|
15742
|
-
|
16246
|
+
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
16247
|
+
uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
|
15743
16248
|
llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
15744
|
-
|
16249
|
+
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
16250
|
+
lm_ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
15745
16251
|
|
15746
16252
|
// initialize scheduler with the worst-case graph
|
15747
16253
|
lm_ggml_backend_sched_reset(lctx.sched);
|
@@ -16337,12 +16843,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|
16337
16843
|
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
16338
16844
|
|
16339
16845
|
// sanity checks
|
16340
|
-
|
16341
|
-
|
16342
|
-
|
16343
|
-
|
16344
|
-
|
16345
|
-
|
16846
|
+
{
|
16847
|
+
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
16848
|
+
// attention layers have a non-zero number of kv heads
|
16849
|
+
int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
16850
|
+
if (llama_model_has_encoder(&model)) {
|
16851
|
+
n_attn_layer *= 3;
|
16852
|
+
}
|
16853
|
+
LM_GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
|
16854
|
+
}
|
16346
16855
|
|
16347
16856
|
size_t total_size_org = 0;
|
16348
16857
|
size_t total_size_new = 0;
|
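
The new sanity check counts a layer as an attention layer when its KV-head count is non-zero (recurrent blocks report 0), and multiplies by three for encoder-decoder models, where each layer index owns an encoder self-attention, a decoder self-attention and a decoder cross-attention. The same counting in isolation (a sketch with plain vectors, not the real hparams type):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_layer = 6;
    // per-layer KV-head counts; 0 means the layer has no attention (e.g. a recurrent block)
    std::vector<int> n_head_kv_arr = { 8, 8, 0, 8, 0, 8 };
    const bool has_encoder = false;  // true for encoder-decoder models

    int n_attn_layer = n_layer - (int) std::count(n_head_kv_arr.begin(),
                                                  n_head_kv_arr.begin() + n_layer, 0);
    if (has_encoder) {
        n_attn_layer *= 3;  // encoder self-attn + decoder self-attn + decoder cross-attn
    }
    printf("attention layers: %d of %d\n", n_attn_layer, n_layer);
    return 0;
}
```
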
@@ -17037,12 +17546,6 @@ struct llama_context * llama_new_context_with_model(
|
|
17037
17546
|
params.flash_attn = false;
|
17038
17547
|
}
|
17039
17548
|
|
17040
|
-
if (params.flash_attn && model->hparams.attn_soft_cap) {
|
17041
|
-
LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
|
17042
|
-
params.flash_attn = false;
|
17043
|
-
}
|
17044
|
-
|
17045
|
-
|
17046
17549
|
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
|
17047
17550
|
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
|
17048
17551
|
params.flash_attn = false;
|
@@ -17151,7 +17654,7 @@ struct llama_context * llama_new_context_with_model(
|
|
17151
17654
|
lm_ggml_type type_v = params.type_v;
|
17152
17655
|
|
17153
17656
|
// Mamba only needs a constant number of KV cache cells per sequence
|
17154
|
-
if (model->arch == LLM_ARCH_MAMBA) {
|
17657
|
+
if (llama_model_is_recurrent(model)) {
|
17155
17658
|
// Mamba needs at least as many KV cells as there are sequences kept at any time
|
17156
17659
|
kv_size = std::max((uint32_t) 1, params.n_seq_max);
|
17157
17660
|
// it's probably best to keep as much precision as possible for the states
|
@@ -17383,10 +17886,11 @@ struct llama_context * llama_new_context_with_model(
|
|
17383
17886
|
}
|
17384
17887
|
|
17385
17888
|
// build worst-case graph
|
17386
|
-
|
17387
|
-
|
17889
|
+
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
17890
|
+
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
17388
17891
|
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
17389
|
-
|
17892
|
+
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
17893
|
+
lm_ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
|
17390
17894
|
|
17391
17895
|
// initialize scheduler with the worst-case graph
|
17392
17896
|
if (!lm_ggml_backend_sched_reserve(ctx->sched, gf)) {
|
@@ -17626,6 +18130,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
|
|
17626
18130
|
return model->hparams.dec_start_token_id;
|
17627
18131
|
}
|
17628
18132
|
|
18133
|
+
bool llama_model_is_recurrent(const struct llama_model * model) {
|
18134
|
+
switch (model->arch) {
|
18135
|
+
case LLM_ARCH_MAMBA: return true;
|
18136
|
+
default: return false;
|
18137
|
+
}
|
18138
|
+
}
|
18139
|
+
|
17629
18140
|
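
llama_model_is_recurrent centralizes what used to be a direct LLM_ARCH_MAMBA comparison; its main consumer in this diff is KV-cache sizing, where a recurrent model needs one state slot per tracked sequence rather than one cell per context token. A hedged sketch of that decision with stand-in types (the enum and helper below are illustrative, only the public llama_model_is_recurrent name is taken from the diff):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

enum Arch { ARCH_LLAMA, ARCH_MAMBA };  // hypothetical stand-in for llm_arch

static bool model_is_recurrent(Arch arch) {
    switch (arch) {
        case ARCH_MAMBA: return true;
        default:         return false;
    }
}

int main() {
    const Arch     arch      = ARCH_MAMBA;
    const uint32_t n_ctx     = 4096;  // requested context size
    const uint32_t n_seq_max = 4;     // maximum number of parallel sequences

    // recurrent models keep one state per sequence; attention models need a cell per token
    const uint32_t kv_size = model_is_recurrent(arch) ? std::max<uint32_t>(1, n_seq_max) : n_ctx;

    printf("kv_size = %u\n", kv_size);
    return 0;
}
```
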
uint32_t llama_model_quantize(
|
17630
18141
|
const char * fname_inp,
|
17631
18142
|
const char * fname_out,
|
@@ -17947,7 +18458,9 @@ struct llama_data_write {
|
|
17947
18458
|
write_string(rng_str);
|
17948
18459
|
}
|
17949
18460
|
|
17950
|
-
void write_output_ids(
|
18461
|
+
void write_output_ids(struct llama_context * ctx) {
|
18462
|
+
llama_output_reorder(ctx);
|
18463
|
+
|
17951
18464
|
const uint32_t n_outputs = ctx->n_outputs;
|
17952
18465
|
|
17953
18466
|
std::vector<int32_t> output_pos;
|
@@ -18235,8 +18748,11 @@ struct llama_data_read {
|
|
18235
18748
|
|
18236
18749
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
18237
18750
|
|
18238
|
-
|
18751
|
+
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
18239
18752
|
batch.n_tokens = cell_count;
|
18753
|
+
batch.n_seq_tokens = cell_count;
|
18754
|
+
batch.n_seqs = 1;
|
18755
|
+
|
18240
18756
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
18241
18757
|
llama_pos pos;
|
18242
18758
|
uint32_t n_seq_id;
|
@@ -18250,11 +18766,10 @@ struct llama_data_read {
|
|
18250
18766
|
}
|
18251
18767
|
|
18252
18768
|
batch.pos[i] = pos;
|
18253
|
-
batch.n_seq_id[i] = 1;
|
18254
|
-
batch.seq_id[i][0] = dest_seq_id;
|
18255
18769
|
}
|
18770
|
+
batch.n_seq_id[0] = 1;
|
18771
|
+
batch.seq_id[0] = &dest_seq_id;
|
18256
18772
|
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
18257
|
-
llama_batch_free(batch);
|
18258
18773
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
18259
18774
|
return false;
|
18260
18775
|
}
|
@@ -18266,9 +18781,6 @@ struct llama_data_read {
|
|
18266
18781
|
LM_GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
18267
18782
|
LM_GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
18268
18783
|
LM_GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
18269
|
-
|
18270
|
-
// Cleanup
|
18271
|
-
llama_batch_free(batch);
|
18272
18784
|
} else {
|
18273
18785
|
// whole KV cache restore
|
18274
18786
|
|
@@ -18300,6 +18812,15 @@ struct llama_data_read {
|
|
18300
18812
|
}
|
18301
18813
|
|
18302
18814
|
cell.seq_id.insert(seq_id);
|
18815
|
+
|
18816
|
+
if (kv_self.recurrent) {
|
18817
|
+
int32_t & tail = kv_self.cells[seq_id].tail;
|
18818
|
+
if (tail != -1) {
|
18819
|
+
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
|
18820
|
+
return false;
|
18821
|
+
}
|
18822
|
+
tail = i;
|
18823
|
+
}
|
18303
18824
|
}
|
18304
18825
|
}
|
18305
18826
|
|
@@ -18307,6 +18828,14 @@ struct llama_data_read {
|
|
18307
18828
|
kv_self.used = cell_count;
|
18308
18829
|
}
|
18309
18830
|
|
18831
|
+
if (kv_self.recurrent) {
|
18832
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
18833
|
+
uint32_t cell_id = kv_self.head + i;
|
18834
|
+
// make sure the recurrent states will keep their restored state
|
18835
|
+
kv_self.cells[cell_id].src = cell_id;
|
18836
|
+
}
|
18837
|
+
}
|
18838
|
+
|
18310
18839
|
return true;
|
18311
18840
|
}
|
18312
18841
|
|
@@ -18894,7 +19423,18 @@ struct llama_batch llama_batch_get_one(
|
|
18894
19423
|
}
|
18895
19424
|
|
18896
19425
|
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
18897
|
-
llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
|
19426
|
+
llama_batch batch = {
|
19427
|
+
/*n_tokens =*/ 0,
|
19428
|
+
/*tokens =*/ nullptr,
|
19429
|
+
/*embd =*/ nullptr,
|
19430
|
+
/*pos =*/ nullptr,
|
19431
|
+
/*n_seq_id =*/ nullptr,
|
19432
|
+
/*seq_id =*/ nullptr,
|
19433
|
+
/*logits =*/ nullptr,
|
19434
|
+
/*all_pos_0 =*/ 0,
|
19435
|
+
/*all_pos_1 =*/ 0,
|
19436
|
+
/*all_seq_id =*/ 0,
|
19437
|
+
};
|
18898
19438
|
|
18899
19439
|
if (embd) {
|
18900
19440
|
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
@@ -18980,6 +19520,10 @@ void llama_synchronize(struct llama_context * ctx) {
|
|
18980
19520
|
float * llama_get_logits(struct llama_context * ctx) {
|
18981
19521
|
llama_synchronize(ctx);
|
18982
19522
|
|
19523
|
+
// reorder logits for backward compatibility
|
19524
|
+
// TODO: maybe deprecate this
|
19525
|
+
llama_output_reorder(ctx);
|
19526
|
+
|
18983
19527
|
return ctx->logits;
|
18984
19528
|
}
|
18985
19529
|
|
@@ -19024,6 +19568,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
|
19024
19568
|
float * llama_get_embeddings(struct llama_context * ctx) {
|
19025
19569
|
llama_synchronize(ctx);
|
19026
19570
|
|
19571
|
+
// reorder embeddings for backward compatibility
|
19572
|
+
// TODO: maybe deprecate this
|
19573
|
+
llama_output_reorder(ctx);
|
19574
|
+
|
19027
19575
|
return ctx->embd;
|
19028
19576
|
}
|
19029
19577
|
|