@fugood/llama.node 1.1.4 → 1.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. package/lib/binding.ts +8 -0
  2. package/package.json +14 -14
  3. package/scripts/llama.cpp.patch +17 -13
  4. package/src/LlamaCompletionWorker.cpp +2 -0
  5. package/src/LlamaContext.cpp +3 -0
  6. package/src/llama.cpp/common/arg.cpp +80 -10
  7. package/src/llama.cpp/common/chat.cpp +52 -8
  8. package/src/llama.cpp/common/chat.h +7 -2
  9. package/src/llama.cpp/common/common.cpp +1 -0
  10. package/src/llama.cpp/common/common.h +16 -6
  11. package/src/llama.cpp/common/speculative.cpp +135 -54
  12. package/src/llama.cpp/common/speculative.h +8 -1
  13. package/src/llama.cpp/ggml/CMakeLists.txt +4 -2
  14. package/src/llama.cpp/ggml/include/ggml.h +37 -1
  15. package/src/llama.cpp/ggml/src/CMakeLists.txt +12 -1
  16. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +61 -0
  17. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +96 -8
  18. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +3196 -0
  19. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +20 -0
  20. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14 -1
  21. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +207 -9
  22. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -7
  23. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  24. package/src/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  25. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +263 -0
  26. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +11 -0
  27. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +19 -4
  28. package/src/llama.cpp/include/llama.h +9 -4
  29. package/src/llama.cpp/src/llama-arch.cpp +105 -0
  30. package/src/llama.cpp/src/llama-arch.h +12 -0
  31. package/src/llama.cpp/src/llama-batch.cpp +1 -1
  32. package/src/llama.cpp/src/llama-chat.cpp +33 -1
  33. package/src/llama.cpp/src/llama-chat.h +2 -0
  34. package/src/llama.cpp/src/llama-context.cpp +19 -10
  35. package/src/llama.cpp/src/llama-context.h +4 -1
  36. package/src/llama.cpp/src/llama-graph.cpp +175 -148
  37. package/src/llama.cpp/src/llama-graph.h +60 -23
  38. package/src/llama.cpp/src/llama-hparams.h +5 -3
  39. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +6 -2
  40. package/src/llama.cpp/src/llama-kv-cache-unified.h +1 -1
  41. package/src/llama.cpp/src/llama-memory-hybrid.cpp +2 -1
  42. package/src/llama.cpp/src/llama-memory-hybrid.h +1 -0
  43. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  44. package/src/llama.cpp/src/llama-model-loader.h +3 -2
  45. package/src/llama.cpp/src/llama-model.cpp +949 -75
  46. package/src/llama.cpp/src/llama-model.h +24 -4
  47. package/src/llama.cpp/src/llama-quant.cpp +40 -4
  48. package/src/llama.cpp/src/llama-vocab.cpp +49 -1
  49. package/src/llama.cpp/src/llama-vocab.h +1 -0
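
The hunks that follow are from package/src/llama.cpp/src/llama-graph.cpp (item 36 above). In summary: build_moe_ffn gains an overload with optional per-expert bias tensors and a probs_in input, absorbing the removed build_moe_ffn_from_probs; a SOFTMAX_WEIGHT expert-gating mode is handled; build_attn_mha and the new build_attn_with_sinks thread an optional attention-sinks tensor through the graph; and build_rs switches to precomputed s_copy_main/s_copy_extra views. As a rough illustration only (not part of the package), a model graph builder might call the extended build_moe_ffn overload along these lines; the model.layers[il].* member names are assumptions for illustration, not taken from this diff:

    // illustrative sketch: extended build_moe_ffn call with per-expert biases
    ggml_tensor * moe_out = build_moe_ffn(cur,
            model.layers[il].ffn_gate_inp,  model.layers[il].ffn_gate_inp_b,
            model.layers[il].ffn_up_exps,   model.layers[il].ffn_up_exps_b,
            model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
            model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
            nullptr,                               // exp_probs_b: no selection bias
            n_expert, n_expert_used,
            LLM_FFN_SILU,                          // expert activation
            true, false, 0.0f,                     // norm_w, scale_w, w_scale
            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
            il,
            nullptr);                              // probs_in: compute logits from gate_inp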
@@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs_unq = ubatch->n_seqs_unq;
 
     if (cparams.embeddings && (
-                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
-                )) {
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
+            )) {
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
 
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
-        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
-            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[i][s];
-                const int32_t seq_idx = ubatch->seq_idx[seq_id];
-
-                data[seq_idx] = i;
-            }
-        }
-    }
+        std::vector<int> target_pos(n_seqs_unq, -1);
+        std::vector<int> target_row(n_seqs_unq, -1);
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        GGML_ASSERT(cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
-
-        uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
-
-        std::vector<int> last_pos(n_seqs_unq, -1);
-        std::vector<int> last_row(n_seqs_unq, -1);
+        bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_pos pos = ubatch->pos[i];
@@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
                 const llama_seq_id seq_id = ubatch->seq_id[i][s];
                 const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
-                if (pos >= last_pos[seq_idx]) {
-                    last_pos[seq_idx] = pos;
-                    last_row[seq_idx] = i;
+                if (
+                    (target_pos[seq_idx] == -1) ||
+                    ( last && pos >= target_pos[seq_idx]) ||
+                    (!last && pos <  target_pos[seq_idx])
+                ) {
+                    target_pos[seq_idx] = pos;
+                    target_row[seq_idx] = i;
                 }
             }
         }
 
         for (int s = 0; s < n_seqs_unq; ++s) {
-            if (last_row[s] >= 0) {
-                data[s] = last_row[s];
+            if (target_row[s] >= 0) {
+                data[s] = target_row[s];
             }
         }
     }
@@ -751,6 +740,8 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_reglu(ctx0, cur);
                 cb(cur, "ffn_reglu", il);
             } break;
+        default:
+            GGML_ABORT("fatal error");
     }
 
     if (gate && type_gate == LLM_FFN_PAR) {
@@ -760,8 +751,8 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     if (down) {
         cur = build_lora_mm(down, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
@@ -796,13 +787,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         bool scale_w,
         float w_scale,
         llama_expert_gating_func_type gating_op,
-        int il) const {
+        int il,
+        ggml_tensor * probs_in) const {
+    return build_moe_ffn(
+        cur,
+        gate_inp,  /* gate_inp_b  */ nullptr,
+        up_exps,   /* up_exps_b   */ nullptr,
+        gate_exps, /* gate_exps_b */ nullptr,
+        down_exps, /* down_exps_b */ nullptr,
+        exp_probs_b,
+        n_expert,
+        n_expert_used,
+        type_op,
+        norm_w,
+        scale_w,
+        w_scale,
+        gating_op,
+        il,
+        probs_in
+    );
+}
+
+ggml_tensor * llm_graph_context::build_moe_ffn(
+        ggml_tensor * cur,
+        ggml_tensor * gate_inp,
+        ggml_tensor * gate_inp_b,
+        ggml_tensor * up_exps,
+        ggml_tensor * up_exps_b,
+        ggml_tensor * gate_exps,
+        ggml_tensor * gate_exps_b,
+        ggml_tensor * down_exps,
+        ggml_tensor * down_exps_b,
+        ggml_tensor * exp_probs_b,
+        int64_t n_expert,
+        int64_t n_expert_used,
+        llm_ffn_op_type type_op,
+        bool norm_w,
+        bool scale_w,
+        float w_scale,
+        llama_expert_gating_func_type gating_op,
+        int il,
+        ggml_tensor * probs_in) const {
     const int64_t n_embd = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
 
-    ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
-    cb(logits, "ffn_moe_logits", il);
+    ggml_tensor * logits = nullptr;
+
+    if (probs_in == nullptr) {
+        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
+        cb(logits, "ffn_moe_logits", il);
+    } else {
+        logits = probs_in;
+    }
+
+    if (gate_inp_b) {
+        logits = ggml_add(ctx0, logits, gate_inp_b);
+        cb(logits, "ffn_moe_logits_biased", il);
+    }
 
     ggml_tensor * probs = nullptr;
     switch (gating_op) {
@@ -814,6 +856,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             {
                 probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
             } break;
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
+            {
+                probs = logits; // [n_expert, n_tokens]
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -842,6 +888,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
+    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
+        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+        cb(weights, "ffn_moe_weights_softmax", il);
+    }
+
     if (norm_w) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
 
@@ -870,6 +923,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
+    if (up_exps_b) {
+        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
+        cb(up, "ffn_moe_up_biased", il);
+    }
+
     ggml_tensor * experts = nullptr;
     if (gate_exps) {
         cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
@@ -878,6 +936,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cur = up;
     }
 
+    if (gate_exps_b) {
+        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
+        cb(cur, "ffn_moe_gate_biased", il);
+    }
+
     switch (type_op) {
         case LLM_FFN_SILU:
             if (gate_exps) {
@@ -895,6 +958,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
+        case LLM_FFN_SWIGLU_OAI_MOE:
+            {
+                // TODO: move to hparams?
+                constexpr float alpha = 1.702f;
+                constexpr float limit = 7.0f;
+                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
+                cb(cur, "ffn_moe_swiglu_oai", il);
+            } break;
+        case LLM_FFN_RELU:
+            if (gate_exps) {
+                cur = ggml_reglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_reglu", il);
+            } else {
+                cur = ggml_relu(ctx0, cur);
+                cb(cur, "ffn_moe_relu", il);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -902,6 +981,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
+    if (down_exps_b) {
+        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
+        cb(experts, "ffn_moe_down_biased", il);
+    }
+
     if (!weight_before_ffn) {
         experts = ggml_mul(ctx0, experts, weights);
         cb(cur, "ffn_moe_weighted", il);
@@ -938,100 +1022,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }
 
-ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
-        ggml_tensor * cur,
-        ggml_tensor * probs,
-        ggml_tensor * up_exps,
-        ggml_tensor * gate_exps,
-        ggml_tensor * down_exps,
-        ggml_tensor * exp_probs_b,
-        int64_t n_expert,
-        int64_t n_expert_used,
-        llama_expert_gating_func_type gating_op,
-        int il) const {
-    const int64_t n_embd = cur->ne[0];
-    const int64_t n_tokens = cur->ne[1];
-
-    // add experts selection bias - introduced in DeepSeek V3
-    // leave probs unbiased as it's later used to get expert weights
-    ggml_tensor * selection_probs = probs;
-    if (exp_probs_b != nullptr) {
-        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
-        cb(selection_probs, "ffn_moe_probs_biased", il);
-    }
-
-    // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
-    cb(selected_experts, "ffn_moe_topk", il);
-
-    ggml_tensor * weights = ggml_get_rows(ctx0,
-        ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-    cb(weights, "ffn_moe_weights", il);
-
-    weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
-    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
-        weights = ggml_soft_max(ctx0, weights);
-    } else {
-        weights = ggml_sigmoid(ctx0, weights);
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
-        cb(weights_sum, "ffn_moe_weights_sum", il);
-
-        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_norm", il);
-    }
-
-    weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
-
-    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
-
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * experts = nullptr;
-    cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(cur, "ffn_moe_gate", il);
-
-    cur = ggml_reglu_split(ctx0, cur, up);
-    cb(cur, "ffn_moe_reglu", il);
-
-    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
-    cb(experts, "ffn_moe_down", il);
-
-    experts = ggml_mul(ctx0, experts, weights);
-    cb(cur, "ffn_moe_weighted", il);
-
-    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
-
-    assert(n_expert_used > 0);
-
-    // order the views before the adds
-    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
-        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
-
-        ggml_build_forward_expand(gf, cur_experts[i]);
-    }
-
-    // aggregate experts
-    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
-    // to avoid potentially a large number of add nodes during warmup
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14753
-    ggml_tensor * moe_out = cur_experts[0];
-
-    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
-        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
-    }
-
-    if (n_expert_used == 1) {
-        // avoid returning a non-contiguous tensor
-        moe_out = ggml_cont(ctx0, moe_out);
-    }
-
-    cb(moe_out, "ffn_moe_out", il);
-
-    return moe_out;
-}
-
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
@@ -1234,6 +1224,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * v_mla,
+        ggml_tensor * sinks,
         float kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
@@ -1270,7 +1261,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+        ggml_flash_attn_ext_add_sinks(cur, sinks);
+        ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
 
         if (v_mla) {
 #if 0
@@ -1318,6 +1310,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         }
 
         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        ggml_soft_max_add_sinks(kq, sinks);
 
         if (!v_trans) {
             // note: avoid this branch
@@ -1388,7 +1381,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1476,13 +1469,13 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
@@ -1505,6 +1498,32 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
+    return build_attn_with_sinks(
+        inp,
+        wo,
+        wo_b,
+        q_cur,
+        k_cur,
+        v_cur,
+        kq_b,
+        v_mla,
+        nullptr,
+        kq_scale,
+        il);
+}
+
+ggml_tensor * llm_graph_context::build_attn_with_sinks(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        ggml_tensor * sinks,
+        float kq_scale,
+        int il) const {
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
@@ -1542,7 +1561,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1596,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1655,16 +1674,17 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
 ggml_tensor * llm_graph_context::build_rs(
         ggml_tensor * s,
-        ggml_tensor * state_copy,
+        ggml_tensor * state_copy_main,
+        ggml_tensor * state_copy_extra,
         int32_t state_size,
         int32_t n_seqs,
-        uint32_t n_kv,
-        uint32_t kv_head,
-        uint32_t kv_size,
+        uint32_t n_rs,
+        uint32_t rs_head,
+        uint32_t rs_size,
         int32_t rs_zero,
         const llm_graph_get_rows_fn & get_state_rows) const {
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1672,39 +1692,44 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
     // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // {state_size, kv_size} -> {state_size, n_seqs}
-    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+    // {state_size, rs_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
     ggml_build_forward_expand(gf, output_states);
 
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy extra states which won't be changed further (between n_seqs and n_rs)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
 
     return output_states;
 }
 
 static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
         ggml_context * ctx0,
+        const llama_ubatch & ubatch,
         const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = mctx_cur->get_n_rs();
+    const int64_t n_rs = mctx_cur->get_n_rs();
+    const int64_t n_seqs = ubatch.n_seqs;
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
+
     return inp;
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
@@ -1717,7 +1742,9 @@ ggml_tensor * llm_graph_context::build_rs(
         const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+            kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+            get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1764,7 +1791,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);