@fugood/llama.node 1.3.0-rc.5 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -85,6 +85,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
  { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
  { LLM_ARCH_PLM, "plm" },
  { LLM_ARCH_BAILINGMOE, "bailingmoe" },
+ { LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
  { LLM_ARCH_DOTS1, "dots1" },
  { LLM_ARCH_ARCEE, "arcee" },
  { LLM_ARCH_ERNIE4_5, "ernie4_5" },

@@ -135,6 +136,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
  { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
  { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
  { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
+ { LLM_KV_EXPERT_GROUP_COUNT, "%s.expert_group_count" },
+ { LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" },
  { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
  { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
  { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },

@@ -1946,6 +1949,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
  { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
  },
  },
+ {
+ LLM_ARCH_BAILINGMOE2,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+ { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
+ { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
+ { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
+ { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
+ { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
+ { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
+ { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+ },
+ },
  {
  LLM_ARCH_DOTS1,
  {

@@ -89,6 +89,7 @@ enum llm_arch {
  LLM_ARCH_WAVTOKENIZER_DEC,
  LLM_ARCH_PLM,
  LLM_ARCH_BAILINGMOE,
+ LLM_ARCH_BAILINGMOE2,
  LLM_ARCH_DOTS1,
  LLM_ARCH_ARCEE,
  LLM_ARCH_ERNIE4_5,

@@ -139,6 +140,8 @@ enum llm_kv {
  LLM_KV_EXPERT_COUNT,
  LLM_KV_EXPERT_USED_COUNT,
  LLM_KV_EXPERT_SHARED_COUNT,
+ LLM_KV_EXPERT_GROUP_COUNT,
+ LLM_KV_EXPERT_GROUP_USED_COUNT,
  LLM_KV_EXPERT_WEIGHTS_SCALE,
  LLM_KV_EXPERT_WEIGHTS_NORM,
  LLM_KV_EXPERT_GATING_FUNC,

@@ -123,7 +123,7 @@ private:
  uint32_t n_seq_max;
  uint32_t n_outputs;

- std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+ std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id

  std::vector<llama_pos> pos;
  std::vector<int32_t> n_seq_id;

@@ -63,6 +63,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
  { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
  { "yandex", LLM_CHAT_TEMPLATE_YANDEX },
  { "bailing", LLM_CHAT_TEMPLATE_BAILING },
+ { "bailing-think", LLM_CHAT_TEMPLATE_BAILING_THINK },
+ { "bailing2", LLM_CHAT_TEMPLATE_BAILING2 },
  { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
  { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
  { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },

@@ -191,6 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
  return LLM_CHAT_TEMPLATE_YANDEX;
  } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
  return LLM_CHAT_TEMPLATE_BAILING;
+ } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("\"HUMAN\"") && tmpl_contains("<think>")) {
+ return LLM_CHAT_TEMPLATE_BAILING_THINK;
+ } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("<role>HUMAN</role>") && tmpl_contains("<|role_end|>")) {
+ return LLM_CHAT_TEMPLATE_BAILING2;
  } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
  return LLM_CHAT_TEMPLATE_LLAMA4;
  } else if (tmpl_contains("<|endofuserprompt|>")) {
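Note: summarizing the detection logic above, the marker substrings map to the chat templates as follows (derived directly from the tmpl_contains checks in this hunk):

    "<role>ASSISTANT</role>" + "'HUMAN'"                              -> LLM_CHAT_TEMPLATE_BAILING
    "<role>ASSISTANT</role>" + "\"HUMAN\"" + "<think>"                -> LLM_CHAT_TEMPLATE_BAILING_THINK
    "<role>ASSISTANT</role>" + "<role>HUMAN</role>" + "<|role_end|>"  -> LLM_CHAT_TEMPLATE_BAILING2

Because the original Bailing check runs first and keys on the single-quoted 'HUMAN' marker, the two new branches only trigger for templates that use the double-quoted or tagged HUMAN markers shown above.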
@@ -644,8 +650,8 @@ int32_t llm_chat_apply_template(
  if (add_ass) {
  ss << " Ассистент:[SEP]";
  }
- } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
- // Bailing (Ling) template
+ } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING || tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
+ // Bailing (Ling/Ring) template
  for (auto message : chat) {
  std::string role(message->role);

@@ -658,6 +664,33 @@ int32_t llm_chat_apply_template(
  ss << "<role>" << role << "</role>" << message->content;
  }

+ if (add_ass) {
+ ss << "<role>ASSISTANT</role>";
+
+ if (tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
+ ss << "<think>";
+ }
+ }
+ } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING2) {
+ // Bailing2 (Ling 2.0) template
+ bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
+
+ if (!has_system) {
+ ss << "<role>SYSTEM</role>detailed thinking off<|role_end|>";
+ }
+
+ for (auto message : chat) {
+ std::string role(message->role);
+
+ if (role == "user") {
+ role = "HUMAN";
+ } else {
+ std::transform(role.begin(), role.end(), role.begin(), ::toupper);
+ }
+
+ ss << "<role>" << role << "</role>" << message->content << "<|role_end|>";
+ }
+
  if (add_ass) {
  ss << "<role>ASSISTANT</role>";
  }
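Note: as a reference for the new LLM_CHAT_TEMPLATE_BAILING2 branch above, a single user turn (with a made-up message "Hello", no system message, and add_ass enabled) would render roughly as:

    <role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role>Hello<|role_end|><role>ASSISTANT</role>

For LLM_CHAT_TEMPLATE_BAILING_THINK, the add_ass path of the shared Bailing branch additionally appends "<think>" after the trailing "<role>ASSISTANT</role>".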
@@ -42,6 +42,8 @@ enum llm_chat_template {
  LLM_CHAT_TEMPLATE_MEGREZ,
  LLM_CHAT_TEMPLATE_YANDEX,
  LLM_CHAT_TEMPLATE_BAILING,
+ LLM_CHAT_TEMPLATE_BAILING_THINK,
+ LLM_CHAT_TEMPLATE_BAILING2,
  LLM_CHAT_TEMPLATE_LLAMA4,
  LLM_CHAT_TEMPLATE_SMOLVLM,
  LLM_CHAT_TEMPLATE_DOTS1,

@@ -268,9 +268,7 @@ llama_context::llama_context(
  if (pipeline_parallel) {
  LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
  }
- }

- if (!hparams.vocab_only) {
  llama_memory_context_ptr mctx;
  if (memory) {
  LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);

@@ -343,7 +341,14 @@
  {
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
  if (!gf) {
- throw std::runtime_error("failed to allocate compute pp buffers");
+ if (pipeline_parallel) {
+ LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
+ gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ }
+ if (!gf) {
+ throw std::runtime_error("failed to allocate compute pp buffers");
+ }
  }

  n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());

@@ -2346,7 +2351,8 @@ llama_context * llama_init_from_model(
  return nullptr;
  }

- if (params.pooling_type != model->hparams.pooling_type) {
+ if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
+ params.pooling_type != model->hparams.pooling_type) {
  //user-specified pooling-type is different from the model default
  LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
  model->hparams.pooling_type, params.pooling_type);

@@ -810,6 +810,9 @@ ggml_tensor * llm_graph_context::build_ffn(
  GGML_ABORT("fatal error");
  }

+ //expand here so that we can fuse ffn gate
+ ggml_build_forward_expand(gf, cur);
+
  if (gate && type_gate == LLM_FFN_PAR) {
  cur = ggml_mul(ctx0, cur, tmp);
  cb(cur, "ffn_gate_par", il);

@@ -950,6 +953,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  cb(selection_probs, "ffn_moe_probs_biased", il);
  }

+ // select top n_group_used expert groups
+ // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
+ if (hparams.n_expert_groups > 1 && n_tokens > 0) {
+ const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
+
+ // organize experts into n_expert_groups
+ ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
+
+ ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+ group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
+
+ // get top n_group_used expert groups
+ group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
+ group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
+
+ ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
+ cb(expert_groups, "ffn_moe_group_topk", il);
+
+ // mask out the other groups
+ selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
+ selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
+ selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
+ cb(selection_probs, "ffn_moe_probs_masked", il);
+ }
+
  // select experts
  ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
  cb(selected_experts->src[0], "ffn_moe_argsort", il);
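Note: the masking above is easier to follow outside of the ggml graph. Below is a minimal standalone C++ sketch of the same group-limited routing idea (DeepSeek-V3 style) for a single token, using plain arrays; all names are illustrative only, not llama.cpp API, and it assumes each group holds at least two experts:

    // Minimal sketch of group-limited expert routing (DeepSeek-V3 style) for one token.
    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <limits>
    #include <numeric>
    #include <vector>

    static std::vector<int> group_limited_topk(const std::vector<float> & probs,
                                               int n_expert_groups, int n_group_used, int n_expert_used) {
        const int n_expert        = (int) probs.size();
        const int n_exp_per_group = n_expert / n_expert_groups;

        // score each group by the sum of its two largest expert probabilities
        // (mirrors ggml_top_k(selection_groups, 2) followed by ggml_sum_rows)
        std::vector<float> group_scores(n_expert_groups);
        for (int g = 0; g < n_expert_groups; ++g) {
            std::vector<float> grp(probs.begin() + g*n_exp_per_group, probs.begin() + (g + 1)*n_exp_per_group);
            std::partial_sort(grp.begin(), grp.begin() + 2, grp.end(), std::greater<float>());
            group_scores[g] = grp[0] + grp[1];
        }

        // keep the indices of the n_group_used best-scoring groups
        std::vector<int> group_idx(n_expert_groups);
        std::iota(group_idx.begin(), group_idx.end(), 0);
        std::partial_sort(group_idx.begin(), group_idx.begin() + n_group_used, group_idx.end(),
                          [&](int a, int b) { return group_scores[a] > group_scores[b]; });

        // mask experts outside the selected groups to -inf
        // (mirrors ggml_scale_bias(..., 0.0f, -INFINITY) + ggml_set_rows)
        std::vector<float> masked(n_expert, -std::numeric_limits<float>::infinity());
        for (int i = 0; i < n_group_used; ++i) {
            const int g = group_idx[i];
            std::copy(probs.begin() + g*n_exp_per_group, probs.begin() + (g + 1)*n_exp_per_group,
                      masked.begin() + g*n_exp_per_group);
        }

        // final top-k over the masked scores selects the experts actually used
        std::vector<int> expert_idx(n_expert);
        std::iota(expert_idx.begin(), expert_idx.end(), 0);
        std::partial_sort(expert_idx.begin(), expert_idx.begin() + n_expert_used, expert_idx.end(),
                          [&](int a, int b) { return masked[a] > masked[b]; });
        expert_idx.resize(n_expert_used);
        return expert_idx;
    }

    int main() {
        // 8 experts in 2 groups of 4; keep 1 group, then the top 2 experts within it
        const std::vector<float> probs = {0.01f, 0.30f, 0.05f, 0.02f,   // group 0 (top-2 sum: 0.35)
                                          0.20f, 0.25f, 0.01f, 0.16f};  // group 1 (top-2 sum: 0.45)
        for (int e : group_limited_topk(probs, /*n_expert_groups*/ 2, /*n_group_used*/ 1, /*n_expert_used*/ 2)) {
            std::printf("selected expert %d\n", e);  // prints experts 5 and 4
        }
        return 0;
    }

The actual graph expresses the same masking with ggml_top_k / ggml_get_rows / ggml_set_rows so that it runs batched over all tokens on the backend.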
@@ -981,6 +1009,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
  cb(weights_sum, "ffn_moe_weights_sum", il);

+ // Avoid division by zero, clamp to smallest number representable by F16
+ weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
+ cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
+
  weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
  cb(weights, "ffn_moe_weights_norm", il);

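Note: the clamp constant above is the smallest positive normal float16 value, 2^-14 = 1/16384 = 0.00006103515625, so even if every selected expert weight underflows to zero the normalization that follows divides by at least one representable F16 value instead of by zero.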
@@ -1061,6 +1093,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  GGML_ABORT("fatal error");
  }

+ //expand here so that we can fuse ffn gate
+ ggml_build_forward_expand(gf, cur);
+
  experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
  cb(experts, "ffn_moe_down", il);

@@ -72,6 +72,8 @@ struct llama_hparams {
  uint32_t n_ff_chexp = 0;
  uint32_t n_expert_shared = 0;
  uint32_t n_norm_groups = 0;
+ uint32_t n_expert_groups = 0;
+ uint32_t n_group_used = 0;
  uint32_t n_group_experts = 0;

  float expert_group_scale = 0.05f;

@@ -8,6 +8,7 @@
  #include <algorithm>
  #include <cassert>
  #include <cmath>
+ #include <cstring>
  #include <limits>
  #include <map>
  #include <stdexcept>

@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(

  const uint32_t n_layer_kv = hparams.n_layer_kv();

+ // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+ struct ggml_backend_buft_comparator {
+ bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+ return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+ }
+ };
+ std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
  // create a context for each buffer type
- std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
  auto it = ctx_map.find(buft);
  if (it == ctx_map.end()) {

@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
  return nullptr;
  }

- ctx_map[buft] = ctx;
- ctxs.emplace_back(ctx);
+ ctx_map.emplace(buft, ctx);

  return ctx;
  }

- return it->second;
+ return it->second.get();
  };

  GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);

@@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache(
  }

  // allocate tensors and initialize the buffers to avoid NaNs in the padding
- for (auto it : ctx_map) {
- auto * buft = it.first;
- auto * ctx = it.second;
-
- ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+ for (auto & [buft, ctx] : ctx_map) {
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
  if (!buf) {
  throw std::runtime_error("failed to allocate buffer for kv cache");
  }

@@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache(
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);

  ggml_backend_buffer_clear(buf, 0);
- bufs.emplace_back(buf);
+ ctxs_bufs.emplace_back(std::move(ctx), buf);
  }

  {

@@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
  }

  if (data) {
- for (auto & buf : bufs) {
+ for (auto & [_, buf] : ctxs_bufs) {
  ggml_backend_buffer_clear(buf.get(), 0);
  }
  }

@@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {

  std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
  std::map<ggml_backend_buffer_type_t, size_t> ret;
- for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
- ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+ for (const auto & [_, buf] : ctxs_bufs) {
+ ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
  }
  return ret;
  }

@@ -957,10 +961,14 @@ bool llama_kv_cache::get_has_shift() const {
  uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
  uint32_t result = 0;

+ // pad the n_kv value so that the graph remains constant across batches and can be reused
+ // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
+ const uint32_t n_pad_cur = std::max(n_pad, 256u);
+
  for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  const auto & cells = v_cells[sinfo.strm[s]];

- result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+ result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
  }

  return result;
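Note: assuming GGML_PAD rounds its first argument up to a multiple of the second, the new lower bound of 256 means that, for example, a stream with used_max_p1 = 300 yields GGML_PAD(300, 256) = 512, and the returned n_kv stays at 512 (capped at cells.size()) until usage crosses the next 256-cell boundary, so the graph shape only changes at those boundaries instead of with every batch.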
@@ -1298,7 +1306,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
  size_t llama_kv_cache::total_size() const {
  size_t size = 0;

- for (const auto & buf : bufs) {
+ for (const auto & [_, buf] : ctxs_bufs) {
  size += ggml_backend_buffer_get_size(buf.get());
  }

@@ -2010,8 +2018,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
  void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  kv->set_input_pos_bucket(dst, ubatch);
  }
-
- uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
- // the FA kernels require padding to avoid extra runtime boundary checks
- return cparams.flash_attn ? 256u : 32u;
- }

@@ -19,8 +19,6 @@ struct llama_context;

  class llama_kv_cache : public llama_memory_i {
  public:
- static uint32_t get_padding(const llama_cparams & cparams);
-
  struct stream_copy_info {
  bool empty() const {
  assert(ssrc.size() == sdst.size());

@@ -217,8 +215,8 @@ private:
  // this is the SWA type of the cache - not to be confused with the model SWA type
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

- std::vector<ggml_context_ptr> ctxs;
- std::vector<ggml_backend_buffer_ptr> bufs;
+ // ggml contexts for the KV cache along with the allocated backend buffers:
+ std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

  // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
  // note: this is not part of the KV state and it's only used to speed-up the find_slot() method

@@ -7,6 +7,7 @@

  #include <algorithm>
  #include <cassert>
+ #include <cstring>
  #include <limits>
  #include <map>
  #include <stdexcept>

@@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
  cells.clear();
  cells.resize(mem_size);

+ // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+ struct ggml_backend_buft_comparator {
+ bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+ return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+ }
+ };
+ std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
  // create a context for each buffer type
- std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
  auto it = ctx_map.find(buft);
  if (it == ctx_map.end()) {

@@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
  return nullptr;
  }

- ctx_map[buft] = ctx;
- ctxs.emplace_back(ctx);
+ ctx_map.emplace(buft, ctx);

  return ctx;
  }

- return it->second;
+ return it->second.get();
  };

  r_l.resize(n_layer);

@@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
  }

  // allocate tensors and initialize the buffers to avoid NaNs in the padding
- for (auto it : ctx_map) {
- auto * buft = it.first;
- auto * ctx = it.second;
-
- ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+ for (auto & [buft, ctx] : ctx_map) {
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
  if (!buf) {
  throw std::runtime_error("failed to allocate buffer for rs cache");
  }
  ggml_backend_buffer_clear(buf, 0);
  LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
- bufs.emplace_back(buf);
+ ctxs_bufs.emplace_back(std::move(ctx), buf);
  }

  {

@@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
  used = 0;

  if (data) {
- for (auto & buf : bufs) {
+ for (auto & [_, buf] : ctxs_bufs) {
  ggml_backend_buffer_clear(buf.get(), 0);
  }
  }

@@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {

  std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
  std::map<ggml_backend_buffer_type_t, size_t> ret;
- for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
- ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+ for (const auto & [_, buf] : ctxs_bufs) {
+ ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
  }
  return ret;
  }

@@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {

  size_t llama_memory_recurrent::total_size() const {
  size_t size = 0;
- for (const auto & buf : bufs) {
+ for (const auto & [_, buf] : ctxs_bufs) {
  size += ggml_backend_buffer_get_size(buf.get());
  }

@@ -109,8 +109,8 @@ private:

  const uint32_t n_seq_max = 1;

- std::vector<ggml_context_ptr> ctxs;
- std::vector<ggml_backend_buffer_ptr> bufs;
+ // ggml contexts for the KV cache along with the allocated backend buffers:
+ std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

  size_t total_size() const;