@fugood/llama.node 1.3.0-rc.6 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (147)
  1. package/CMakeLists.txt +12 -2
  2. package/package.json +14 -14
  3. package/scripts/llama.cpp.patch +8 -9
  4. package/src/llama.cpp/common/CMakeLists.txt +2 -0
  5. package/src/llama.cpp/common/arg.cpp +39 -1001
  6. package/src/llama.cpp/common/arg.h +2 -2
  7. package/src/llama.cpp/common/chat.cpp +216 -2
  8. package/src/llama.cpp/common/chat.h +1 -0
  9. package/src/llama.cpp/common/common.cpp +33 -0
  10. package/src/llama.cpp/common/common.h +13 -0
  11. package/src/llama.cpp/common/download.cpp +1054 -0
  12. package/src/llama.cpp/common/download.h +55 -0
  13. package/src/llama.cpp/common/json-schema-to-grammar.cpp +19 -3
  14. package/src/llama.cpp/ggml/CMakeLists.txt +3 -1
  15. package/src/llama.cpp/ggml/include/ggml-hexagon.h +19 -0
  16. package/src/llama.cpp/ggml/include/ggml.h +2 -0
  17. package/src/llama.cpp/ggml/src/CMakeLists.txt +7 -3
  18. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +10 -3
  19. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  20. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  21. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  22. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -1
  23. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +0 -5
  24. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +172 -35
  25. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +82 -21
  26. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +25 -25
  27. package/src/llama.cpp/include/llama.h +7 -3
  28. package/src/llama.cpp/src/CMakeLists.txt +95 -0
  29. package/src/llama.cpp/src/llama-arch.cpp +108 -0
  30. package/src/llama.cpp/src/llama-arch.h +11 -0
  31. package/src/llama.cpp/src/llama-batch.cpp +63 -31
  32. package/src/llama.cpp/src/llama-batch.h +12 -1
  33. package/src/llama.cpp/src/llama-chat.cpp +32 -0
  34. package/src/llama.cpp/src/llama-chat.h +1 -0
  35. package/src/llama.cpp/src/llama-context.cpp +44 -16
  36. package/src/llama.cpp/src/llama-context.h +5 -5
  37. package/src/llama.cpp/src/llama-cparams.h +1 -0
  38. package/src/llama.cpp/src/llama-graph.cpp +12 -7
  39. package/src/llama.cpp/src/llama-hparams.cpp +11 -1
  40. package/src/llama.cpp/src/llama-hparams.h +6 -0
  41. package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +3 -1
  42. package/src/llama.cpp/src/llama-kv-cache.cpp +56 -21
  43. package/src/llama.cpp/src/llama-kv-cache.h +2 -4
  44. package/src/llama.cpp/src/llama-kv-cells.h +44 -2
  45. package/src/llama.cpp/src/llama-memory-recurrent.cpp +18 -14
  46. package/src/llama.cpp/src/llama-memory-recurrent.h +2 -2
  47. package/src/llama.cpp/src/llama-model.cpp +350 -13194
  48. package/src/llama.cpp/src/llama-model.h +9 -2
  49. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  50. package/src/llama.cpp/src/llama-vocab.cpp +5 -0
  51. package/src/llama.cpp/src/llama-vocab.h +1 -0
  52. package/src/llama.cpp/src/models/apertus.cpp +125 -0
  53. package/src/llama.cpp/src/models/arcee.cpp +135 -0
  54. package/src/llama.cpp/src/models/arctic.cpp +138 -0
  55. package/src/llama.cpp/src/models/arwkv7.cpp +86 -0
  56. package/src/llama.cpp/src/models/baichuan.cpp +122 -0
  57. package/src/llama.cpp/src/models/bailingmoe.cpp +144 -0
  58. package/src/llama.cpp/src/models/bailingmoe2.cpp +135 -0
  59. package/src/llama.cpp/src/models/bert.cpp +176 -0
  60. package/src/llama.cpp/src/models/bitnet.cpp +160 -0
  61. package/src/llama.cpp/src/models/bloom.cpp +101 -0
  62. package/src/llama.cpp/src/models/chameleon.cpp +178 -0
  63. package/src/llama.cpp/src/models/chatglm.cpp +132 -0
  64. package/src/llama.cpp/src/models/codeshell.cpp +111 -0
  65. package/src/llama.cpp/src/models/cogvlm.cpp +100 -0
  66. package/src/llama.cpp/src/models/cohere2-iswa.cpp +131 -0
  67. package/src/llama.cpp/src/models/command-r.cpp +122 -0
  68. package/src/llama.cpp/src/models/dbrx.cpp +123 -0
  69. package/src/llama.cpp/src/models/deci.cpp +135 -0
  70. package/src/llama.cpp/src/models/deepseek.cpp +144 -0
  71. package/src/llama.cpp/src/models/deepseek2.cpp +236 -0
  72. package/src/llama.cpp/src/models/dots1.cpp +134 -0
  73. package/src/llama.cpp/src/models/dream.cpp +105 -0
  74. package/src/llama.cpp/src/models/ernie4-5-moe.cpp +150 -0
  75. package/src/llama.cpp/src/models/ernie4-5.cpp +111 -0
  76. package/src/llama.cpp/src/models/exaone.cpp +114 -0
  77. package/src/llama.cpp/src/models/exaone4.cpp +123 -0
  78. package/src/llama.cpp/src/models/falcon-h1.cpp +113 -0
  79. package/src/llama.cpp/src/models/falcon.cpp +120 -0
  80. package/src/llama.cpp/src/models/gemma-embedding.cpp +120 -0
  81. package/src/llama.cpp/src/models/gemma.cpp +112 -0
  82. package/src/llama.cpp/src/models/gemma2-iswa.cpp +125 -0
  83. package/src/llama.cpp/src/models/gemma3-iswa.cpp +131 -0
  84. package/src/llama.cpp/src/models/gemma3n-iswa.cpp +377 -0
  85. package/src/llama.cpp/src/models/glm4-moe.cpp +153 -0
  86. package/src/llama.cpp/src/models/glm4.cpp +127 -0
  87. package/src/llama.cpp/src/models/gpt2.cpp +105 -0
  88. package/src/llama.cpp/src/models/gptneox.cpp +144 -0
  89. package/src/llama.cpp/src/models/granite-hybrid.cpp +196 -0
  90. package/src/llama.cpp/src/models/granite.cpp +211 -0
  91. package/src/llama.cpp/src/models/graph-context-mamba.cpp +283 -0
  92. package/src/llama.cpp/src/models/grok.cpp +159 -0
  93. package/src/llama.cpp/src/models/grovemoe.cpp +141 -0
  94. package/src/llama.cpp/src/models/hunyuan-dense.cpp +132 -0
  95. package/src/llama.cpp/src/models/hunyuan-moe.cpp +154 -0
  96. package/src/llama.cpp/src/models/internlm2.cpp +120 -0
  97. package/src/llama.cpp/src/models/jais.cpp +86 -0
  98. package/src/llama.cpp/src/models/jamba.cpp +106 -0
  99. package/src/llama.cpp/src/models/lfm2.cpp +173 -0
  100. package/src/llama.cpp/src/models/llada-moe.cpp +122 -0
  101. package/src/llama.cpp/src/models/llada.cpp +99 -0
  102. package/src/llama.cpp/src/models/llama-iswa.cpp +174 -0
  103. package/src/llama.cpp/src/models/llama.cpp +155 -0
  104. package/src/llama.cpp/src/models/mamba.cpp +55 -0
  105. package/src/llama.cpp/src/models/minicpm3.cpp +199 -0
  106. package/src/llama.cpp/src/models/minimax-m2.cpp +124 -0
  107. package/src/llama.cpp/src/models/models.h +481 -0
  108. package/src/llama.cpp/src/models/mpt.cpp +126 -0
  109. package/src/llama.cpp/src/models/nemotron-h.cpp +121 -0
  110. package/src/llama.cpp/src/models/nemotron.cpp +122 -0
  111. package/src/llama.cpp/src/models/neo-bert.cpp +104 -0
  112. package/src/llama.cpp/src/models/olmo.cpp +121 -0
  113. package/src/llama.cpp/src/models/olmo2.cpp +150 -0
  114. package/src/llama.cpp/src/models/olmoe.cpp +124 -0
  115. package/src/llama.cpp/src/models/openai-moe-iswa.cpp +123 -0
  116. package/src/llama.cpp/src/models/openelm.cpp +124 -0
  117. package/src/llama.cpp/src/models/orion.cpp +123 -0
  118. package/src/llama.cpp/src/models/pangu-embedded.cpp +121 -0
  119. package/src/llama.cpp/src/models/phi2.cpp +121 -0
  120. package/src/llama.cpp/src/models/phi3.cpp +152 -0
  121. package/src/llama.cpp/src/models/plamo.cpp +110 -0
  122. package/src/llama.cpp/src/models/plamo2.cpp +316 -0
  123. package/src/llama.cpp/src/models/plm.cpp +168 -0
  124. package/src/llama.cpp/src/models/qwen.cpp +108 -0
  125. package/src/llama.cpp/src/models/qwen2.cpp +117 -0
  126. package/src/llama.cpp/src/models/qwen2moe.cpp +151 -0
  127. package/src/llama.cpp/src/models/qwen2vl.cpp +117 -0
  128. package/src/llama.cpp/src/models/qwen3.cpp +117 -0
  129. package/src/llama.cpp/src/models/qwen3moe.cpp +124 -0
  130. package/src/llama.cpp/src/models/qwen3vl-moe.cpp +149 -0
  131. package/src/llama.cpp/src/models/qwen3vl.cpp +141 -0
  132. package/src/llama.cpp/src/models/refact.cpp +94 -0
  133. package/src/llama.cpp/src/models/rwkv6-base.cpp +162 -0
  134. package/src/llama.cpp/src/models/rwkv6.cpp +94 -0
  135. package/src/llama.cpp/src/models/rwkv6qwen2.cpp +86 -0
  136. package/src/llama.cpp/src/models/rwkv7-base.cpp +135 -0
  137. package/src/llama.cpp/src/models/rwkv7.cpp +90 -0
  138. package/src/llama.cpp/src/models/seed-oss.cpp +124 -0
  139. package/src/llama.cpp/src/models/smallthinker.cpp +120 -0
  140. package/src/llama.cpp/src/models/smollm3.cpp +128 -0
  141. package/src/llama.cpp/src/models/stablelm.cpp +146 -0
  142. package/src/llama.cpp/src/models/starcoder.cpp +100 -0
  143. package/src/llama.cpp/src/models/starcoder2.cpp +121 -0
  144. package/src/llama.cpp/src/models/t5-dec.cpp +166 -0
  145. package/src/llama.cpp/src/models/t5-enc.cpp +96 -0
  146. package/src/llama.cpp/src/models/wavtokenizer-dec.cpp +149 -0
  147. package/src/llama.cpp/src/models/xverse.cpp +108 -0
package/src/llama.cpp/src/llama-hparams.cpp
@@ -60,6 +60,16 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
     return n_head/n_head_kv;
 }
 
+uint32_t llama_hparams::n_embd_inp() const {
+    uint32_t n_embd_inp = n_embd;
+
+    if (n_deepstack_layers > 0) {
+        n_embd_inp += n_embd * n_deepstack_layers;
+    }
+
+    return n_embd_inp;
+}
+
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     const uint32_t n_head_kv = this->n_head_kv(il);
 
@@ -148,7 +158,7 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
 }
 
 uint32_t llama_hparams::n_pos_per_embd() const {
-    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+    return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
 }
 
 bool llama_hparams::is_swa(uint32_t il) const {
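For context on the new llama_hparams::n_embd_inp() helper above: the input embedding width grows by one extra n_embd slice per deepstack layer. A minimal standalone sketch of that arithmetic, with invented values for n_embd and n_deepstack_layers:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_embd             = 2048; // assumed base embedding width
    const uint32_t n_deepstack_layers = 3;    // assumed deepstack depth for a qwen3vl-style model

    uint32_t n_embd_inp = n_embd;
    if (n_deepstack_layers > 0) {
        n_embd_inp += n_embd * n_deepstack_layers; // one extra n_embd slice per deepstack layer
    }

    printf("n_embd_inp = %u\n", n_embd_inp); // prints 8192
    return 0;
}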
package/src/llama.cpp/src/llama-hparams.h
@@ -183,6 +183,9 @@ struct llama_hparams {
     std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
     std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
 
+    // qwen3vl deepstack
+    uint32_t n_deepstack_layers = 0;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@@ -224,6 +227,9 @@
 
     uint32_t n_gqa(uint32_t il = 0) const;
 
+    // dimension of main + auxiliary input embeddings
+    uint32_t n_embd_inp() const;
+
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
 
package/src/llama.cpp/src/llama-kv-cache-iswa.cpp
@@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+    // note: the SWA cache is always padded to 256 for performance
+    // https://github.com/ggml-org/llama.cpp/issues/17037
+    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
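The reordered padding above clamps the SWA size first and only then rounds it up to a multiple of 256. A small sketch of that arithmetic, using a local pad() helper in place of ggml's GGML_PAD macro (round up to a multiple of n) and made-up sizes:

#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad(uint32_t x, uint32_t n) {
    return (x + n - 1) / n * n; // round x up to a multiple of n
}

int main() {
    const uint32_t size_base = 4096; // assumed total KV size
    const uint32_t n_swa     = 1000; // assumed sliding-window size
    const uint32_t n_ubatch  = 512;  // assumed micro-batch size

    // old: pad to n_pad first, then clamp (the clamp could undo the padding)
    // new: clamp first, then always pad the SWA cache to a multiple of 256
    const uint32_t size_swa = pad(std::min(size_base, n_swa + n_ubatch), 256);

    printf("size_swa = %u\n", size_swa); // 1512 rounded up to 1536
    return 0;
}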
package/src/llama.cpp/src/llama-kv-cache.cpp
@@ -8,6 +8,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cmath>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(
 
     const uint32_t n_layer_kv = hparams.n_layer_kv();
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
            }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
            return ctx;
        }
 
-        return it->second;
+        return it->second.get();
    };
 
    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache(
    }
 
    // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer for kv cache");
        }
@@ -179,7 +183,7 @@
        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
 
        ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }
 
    {
@@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
    }
 
    if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
            ggml_backend_buffer_clear(buf.get(), 0);
        }
    }
@@ -334,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
            llama_pos pos = v_cells[s0].pos_get(i);
            llama_pos shift = v_cells[s0].get_shift(i);
 
+            llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
+
            if (shift != 0) {
                pos -= shift;
                assert(pos >= 0);
@@ -345,6 +351,8 @@
            if (shift != 0) {
                v_cells[s1].pos_add(i, shift);
            }
+
+            v_cells[s1].ext_set(i, ext);
        }
    }
 
@@ -379,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
 
 void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
 
    auto & cells = v_cells[seq_to_stream[seq_id]];
    auto & head = v_heads[seq_to_stream[seq_id]];
@@ -423,6 +432,7 @@
 
 void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
 
    auto & cells = v_cells[seq_to_stream[seq_id]];
 
@@ -472,8 +482,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
    }
    return ret;
 }
@@ -896,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
        cells.pos_set(idx, ubatch.pos[i]);
 
+        if (ubatch.is_pos_2d()) {
+            llama_kv_cell_ext ext {
+                /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
+                /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+            };
+            cells.ext_set(idx, ext);
+        }
+
        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
            cells.seq_add(idx, ubatch.seq_id[i][s]);
        }
@@ -957,10 +975,14 @@ bool llama_kv_cache::get_has_shift() const {
 uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
    uint32_t result = 0;
 
+    // pad the n_kv value so that the graph remains constant across batches and can be reused
+    // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
+    const uint32_t n_pad_cur = std::max(n_pad, 256u);
+
    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        const auto & cells = v_cells[sinfo.strm[s]];
 
-        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+        result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
    }
 
    return result;
@@ -1239,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
 
            const llama_pos p1 = ubatch->pos[i];
 
+            // for M-RoPE
+            const bool is_2d = ubatch->is_pos_2d();
+            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
+
            const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
            for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1258,6 +1285,14 @@
                    continue;
                }
 
+                // M-RoPE causal mask
+                if (causal_attn && is_2d && p0 == p1) {
+                    const auto & p0_ext = cells.ext_get(j);
+                    if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                        continue;
+                    }
+                }
+
                // apply SWA if any
                if (is_masked_swa(p0, p1)) {
                    continue;
@@ -1298,7 +1333,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
    size_t size = 0;
 
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
        size += ggml_backend_buffer_get_size(buf.get());
    }
 
@@ -1340,7 +1375,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
    const auto & yarn_beta_slow = cparams.yarn_beta_slow;
 
    const auto & n_rot = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
        // @ngxson : this is a workaround
        // for M-RoPE, we want to rotate the whole vector when doing KV shift
        // a normal RoPE should work, we just need to use the correct ordering
@@ -1551,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
        io.write(&pos, sizeof(pos));
        io.write(&n_seq_id, sizeof(n_seq_id));
 
+        // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
+        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+
        for (const auto & seq_id : seq_ids) {
            io.write(&seq_id, sizeof(seq_id));
        }
@@ -1696,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
        return false;
    }
 
+    // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
+    // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
    apply_ubatch(sinfo, ubatch);
 
    const auto head_cur = sinfo.head();
@@ -2010,8 +2050,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_pos_bucket(dst, ubatch);
 }
-
-uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
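The 2D-position handling added above reads per-token x/y values out of the flattened ubatch.pos array, which stores one row of n_tokens entries per position dimension. A standalone sketch of that indexing with an invented 3-token batch; the row-to-meaning mapping simply mirrors the indices used in apply_ubatch() above and is not an authoritative description of the layout:

#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t;

int main() {
    const int n_tokens = 3;

    // flattened layout: one row of n_tokens entries per position dimension
    // row 0: temporal position t, row 1: read as y above, row 2: read as x above, row 3: unused here
    const std::vector<llama_pos> pos = {
        0, 1, 2,   // t
        0, 0, 1,   // y
        0, 1, 0,   // x
        0, 0, 0,
    };

    for (int i = 0; i < n_tokens; ++i) {
        const llama_pos x = pos[i + n_tokens*2]; // same index as /*.x =*/ in apply_ubatch()
        const llama_pos y = pos[i + n_tokens];   // same index as /*.y =*/ in apply_ubatch()
        printf("token %d: t=%d x=%d y=%d\n", i, pos[i], x, y);
    }
    return 0;
}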
package/src/llama.cpp/src/llama-kv-cache.h
@@ -19,8 +19,6 @@ struct llama_context;
 
 class llama_kv_cache : public llama_memory_i {
 public:
-    static uint32_t get_padding(const llama_cparams & cparams);
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());
@@ -217,8 +215,8 @@ private:
     // this is the SWA type of the cache - not to be confused with the model SWA type
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
package/src/llama.cpp/src/llama-kv-cells.h
@@ -5,9 +5,27 @@
 
 #include <bitset>
 #include <cassert>
-#include <vector>
-#include <set>
+#include <cstring>
 #include <map>
+#include <set>
+#include <vector>
+
+struct llama_kv_cell_ext {
+    // 2D spatial positions, typically used for M-RoPE
+    llama_pos x = 0;
+    llama_pos y = 0;
+
+    // return true if the current 2D spatial position is greater than other
+    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
+        return (y > oy) || (y == oy && x > ox);
+    }
+
+    void reset() {
+        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
+
+        memset(this, 0, sizeof(*this));
+    }
+};
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -16,6 +34,7 @@ public:
    void reset() {
        for (uint32_t i = 0; i < pos.size(); ++i) {
            pos[i] = -1;
+            ext[i].reset();
            shift[i] = 0;
            seq[i].reset();
        }
@@ -43,6 +62,7 @@ public:
 
    void resize(uint32_t n) {
        pos.resize(n);
+        ext.resize(n);
        shift.resize(n);
        seq.resize(n);
 
@@ -108,6 +128,7 @@ public:
            const auto idx = i + j;
 
            res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
            res.seq[j] = seq[idx];
 
            assert(shift[idx] == 0);
@@ -126,6 +147,7 @@ public:
            const auto idx = idxs[j];
 
            res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
            res.seq[j] = seq[idx];
 
            assert(shift[idx] == 0);
@@ -154,6 +176,7 @@ public:
            }
 
            pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
            seq[idx] = other.seq[j];
 
            if (pos[idx] != -1) {
@@ -184,6 +207,7 @@ public:
            }
 
            pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
            seq[idx] = other.seq[j];
 
            if (pos[idx] != -1) {
@@ -203,6 +227,7 @@ public:
        seq[i].reset();
 
        pos[i] = -1;
+        ext[i].reset();
        shift[i] = 0;
 
        used.erase(i);
@@ -221,6 +246,7 @@ public:
 
        if (seq[i].none()) {
            pos[i] = -1;
+            ext[i].reset();
            shift[i] = 0;
 
            used.erase(i);
@@ -250,6 +276,7 @@ public:
        seq[i].reset();
 
        pos[i] = -1;
+        ext[i].reset();
        shift[i] = 0;
 
        used.erase(i);
@@ -340,6 +367,13 @@ public:
        return pos[i];
    }
 
+    const llama_kv_cell_ext & ext_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return ext[i];
+    }
+
    // note: call only if the cell is not empty
    llama_pos get_shift(uint32_t i) const {
        assert(i < pos.size());
@@ -368,6 +402,11 @@ public:
        used.insert(i);
    }
 
+    void ext_set(uint32_t i, llama_kv_cell_ext p) {
+        assert(i < ext.size());
+        ext[i] = p;
+    }
+
    // pos[i] = pos[i] + d
    // sets "has_shift" to true
    // note: call only if the cell is not empty
@@ -424,6 +463,9 @@ private:
 
    std::vector<llama_pos> pos;
 
+    // stores extra info per cell
+    std::vector<llama_kv_cell_ext> ext;
+
    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
    //
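The is_2d_gt() predicate above orders 2D positions row-major (y compared first, then x), which is what the new M-RoPE causal mask in llama-kv-cache.cpp relies on. A self-contained check of that ordering, with llama_pos aliased locally so the snippet compiles on its own:

#include <cassert>

using llama_pos = int;

struct cell_ext {
    llama_pos x = 0;
    llama_pos y = 0;

    // same comparison as llama_kv_cell_ext::is_2d_gt(): y first, then x
    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
        return (y > oy) || (y == oy && x > ox);
    }
};

int main() {
    const cell_ext c{/*x=*/3, /*y=*/1};

    assert( c.is_2d_gt(2, 1)); // same row, larger x: "later", so the causal mask skips it
    assert( c.is_2d_gt(5, 0)); // a larger y wins regardless of x
    assert(!c.is_2d_gt(3, 1)); // an equal position is not greater, so it stays visible
    return 0;
}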
package/src/llama.cpp/src/llama-memory-recurrent.cpp
@@ -7,6 +7,7 @@
 
 #include <algorithm>
 #include <cassert>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
    cells.clear();
    cells.resize(mem_size);
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
    // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
@@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
                return nullptr;
            }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
            return ctx;
        }
 
-        return it->second;
+        return it->second.get();
    };
 
    r_l.resize(n_layer);
@@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
    }
 
    // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer for rs cache");
        }
        ggml_backend_buffer_clear(buf, 0);
        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }
 
    {
@@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
    used = 0;
 
    if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
            ggml_backend_buffer_clear(buf.get(), 0);
        }
    }
@@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
    }
    return ret;
 }
@@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {
 
 size_t llama_memory_recurrent::total_size() const {
    size_t size = 0;
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
        size += ggml_backend_buffer_get_size(buf.get());
    }
 
package/src/llama.cpp/src/llama-memory-recurrent.h
@@ -109,8 +109,8 @@ private:
 
     const uint32_t n_seq_max = 1;
 
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     size_t total_size() const;
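Both the llama_kv_cache and llama_memory_recurrent constructors above switch the buft -> ctx map to a comparator keyed on the backend name, so iteration (and therefore allocation) order is deterministic instead of depending on pointer values. A minimal sketch of that pattern, with plain C strings standing in for ggml_backend_buffer_type_t and unique_ptr<int> standing in for ggml_context_ptr:

#include <cstdio>
#include <cstring>
#include <map>
#include <memory>

struct name_comparator {
    bool operator()(const char * lhs, const char * rhs) const {
        return strcmp(lhs, rhs) < 0; // strict weak ordering by name gives a well-defined iteration order
    }
};

int main() {
    std::map<const char *, std::unique_ptr<int>, name_comparator> ctx_map;

    ctx_map.emplace("Metal", std::make_unique<int>(1));
    ctx_map.emplace("CPU",   std::make_unique<int>(2));

    // iterate with structured bindings, as the allocation loops above do
    for (auto & [name, ctx] : ctx_map) {
        printf("%s -> %d\n", name, *ctx); // CPU first, then Metal, regardless of insertion order
    }
    return 0;
}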