@fugood/llama.node 1.3.0 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. package/package.json +14 -14
  2. package/scripts/llama.cpp.patch +8 -8
  3. package/src/llama.cpp/common/CMakeLists.txt +2 -0
  4. package/src/llama.cpp/common/arg.cpp +44 -999
  5. package/src/llama.cpp/common/arg.h +2 -2
  6. package/src/llama.cpp/common/chat.cpp +17 -2
  7. package/src/llama.cpp/common/common.cpp +33 -0
  8. package/src/llama.cpp/common/common.h +15 -1
  9. package/src/llama.cpp/common/download.cpp +1054 -0
  10. package/src/llama.cpp/common/download.h +55 -0
  11. package/src/llama.cpp/ggml/CMakeLists.txt +1 -1
  12. package/src/llama.cpp/ggml/include/ggml.h +2 -0
  13. package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
  14. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +29 -11
  15. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  16. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  17. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  18. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  19. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -1
  20. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +21 -21
  21. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +172 -75
  22. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +0 -4
  23. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +82 -21
  24. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +25 -25
  25. package/src/llama.cpp/include/llama.h +7 -3
  26. package/src/llama.cpp/src/CMakeLists.txt +95 -0
  27. package/src/llama.cpp/src/llama-arch.cpp +108 -0
  28. package/src/llama.cpp/src/llama-arch.h +11 -0
  29. package/src/llama.cpp/src/llama-batch.cpp +63 -31
  30. package/src/llama.cpp/src/llama-batch.h +12 -1
  31. package/src/llama.cpp/src/llama-chat.cpp +32 -0
  32. package/src/llama.cpp/src/llama-chat.h +1 -0
  33. package/src/llama.cpp/src/llama-context.cpp +36 -13
  34. package/src/llama.cpp/src/llama-context.h +5 -5
  35. package/src/llama.cpp/src/llama-cparams.h +1 -0
  36. package/src/llama.cpp/src/llama-graph.cpp +3 -3
  37. package/src/llama.cpp/src/llama-hparams.cpp +11 -1
  38. package/src/llama.cpp/src/llama-hparams.h +6 -0
  39. package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +3 -1
  40. package/src/llama.cpp/src/llama-kv-cache.cpp +33 -1
  41. package/src/llama.cpp/src/llama-kv-cells.h +44 -2
  42. package/src/llama.cpp/src/llama-memory-recurrent.cpp +4 -3
  43. package/src/llama.cpp/src/llama-model.cpp +320 -13171
  44. package/src/llama.cpp/src/llama-model.h +8 -0
  45. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  46. package/src/llama.cpp/src/llama-vocab.cpp +5 -0
  47. package/src/llama.cpp/src/llama-vocab.h +1 -0
  48. package/src/llama.cpp/src/models/apertus.cpp +125 -0
  49. package/src/llama.cpp/src/models/arcee.cpp +135 -0
  50. package/src/llama.cpp/src/models/arctic.cpp +138 -0
  51. package/src/llama.cpp/src/models/arwkv7.cpp +86 -0
  52. package/src/llama.cpp/src/models/baichuan.cpp +122 -0
  53. package/src/llama.cpp/src/models/bailingmoe.cpp +144 -0
  54. package/src/llama.cpp/src/models/bailingmoe2.cpp +135 -0
  55. package/src/llama.cpp/src/models/bert.cpp +176 -0
  56. package/src/llama.cpp/src/models/bitnet.cpp +160 -0
  57. package/src/llama.cpp/src/models/bloom.cpp +101 -0
  58. package/src/llama.cpp/src/models/chameleon.cpp +178 -0
  59. package/src/llama.cpp/src/models/chatglm.cpp +132 -0
  60. package/src/llama.cpp/src/models/codeshell.cpp +111 -0
  61. package/src/llama.cpp/src/models/cogvlm.cpp +100 -0
  62. package/src/llama.cpp/src/models/cohere2-iswa.cpp +131 -0
  63. package/src/llama.cpp/src/models/command-r.cpp +122 -0
  64. package/src/llama.cpp/src/models/dbrx.cpp +123 -0
  65. package/src/llama.cpp/src/models/deci.cpp +135 -0
  66. package/src/llama.cpp/src/models/deepseek.cpp +144 -0
  67. package/src/llama.cpp/src/models/deepseek2.cpp +236 -0
  68. package/src/llama.cpp/src/models/dots1.cpp +134 -0
  69. package/src/llama.cpp/src/models/dream.cpp +105 -0
  70. package/src/llama.cpp/src/models/ernie4-5-moe.cpp +150 -0
  71. package/src/llama.cpp/src/models/ernie4-5.cpp +110 -0
  72. package/src/llama.cpp/src/models/exaone.cpp +114 -0
  73. package/src/llama.cpp/src/models/exaone4.cpp +123 -0
  74. package/src/llama.cpp/src/models/falcon-h1.cpp +113 -0
  75. package/src/llama.cpp/src/models/falcon.cpp +120 -0
  76. package/src/llama.cpp/src/models/gemma-embedding.cpp +120 -0
  77. package/src/llama.cpp/src/models/gemma.cpp +112 -0
  78. package/src/llama.cpp/src/models/gemma2-iswa.cpp +125 -0
  79. package/src/llama.cpp/src/models/gemma3-iswa.cpp +131 -0
  80. package/src/llama.cpp/src/models/gemma3n-iswa.cpp +377 -0
  81. package/src/llama.cpp/src/models/glm4-moe.cpp +153 -0
  82. package/src/llama.cpp/src/models/glm4.cpp +127 -0
  83. package/src/llama.cpp/src/models/gpt2.cpp +105 -0
  84. package/src/llama.cpp/src/models/gptneox.cpp +144 -0
  85. package/src/llama.cpp/src/models/granite-hybrid.cpp +196 -0
  86. package/src/llama.cpp/src/models/granite.cpp +211 -0
  87. package/src/llama.cpp/src/models/graph-context-mamba.cpp +283 -0
  88. package/src/llama.cpp/src/models/grok.cpp +159 -0
  89. package/src/llama.cpp/src/models/grovemoe.cpp +141 -0
  90. package/src/llama.cpp/src/models/hunyuan-dense.cpp +132 -0
  91. package/src/llama.cpp/src/models/hunyuan-moe.cpp +154 -0
  92. package/src/llama.cpp/src/models/internlm2.cpp +120 -0
  93. package/src/llama.cpp/src/models/jais.cpp +86 -0
  94. package/src/llama.cpp/src/models/jamba.cpp +106 -0
  95. package/src/llama.cpp/src/models/lfm2.cpp +173 -0
  96. package/src/llama.cpp/src/models/llada-moe.cpp +122 -0
  97. package/src/llama.cpp/src/models/llada.cpp +99 -0
  98. package/src/llama.cpp/src/models/llama-iswa.cpp +174 -0
  99. package/src/llama.cpp/src/models/llama.cpp +155 -0
  100. package/src/llama.cpp/src/models/mamba.cpp +55 -0
  101. package/src/llama.cpp/src/models/minicpm3.cpp +199 -0
  102. package/src/llama.cpp/src/models/minimax-m2.cpp +124 -0
  103. package/src/llama.cpp/src/models/models.h +481 -0
  104. package/src/llama.cpp/src/models/mpt.cpp +126 -0
  105. package/src/llama.cpp/src/models/nemotron-h.cpp +121 -0
  106. package/src/llama.cpp/src/models/nemotron.cpp +122 -0
  107. package/src/llama.cpp/src/models/neo-bert.cpp +104 -0
  108. package/src/llama.cpp/src/models/olmo.cpp +121 -0
  109. package/src/llama.cpp/src/models/olmo2.cpp +150 -0
  110. package/src/llama.cpp/src/models/olmoe.cpp +124 -0
  111. package/src/llama.cpp/src/models/openai-moe-iswa.cpp +124 -0
  112. package/src/llama.cpp/src/models/openelm.cpp +124 -0
  113. package/src/llama.cpp/src/models/orion.cpp +123 -0
  114. package/src/llama.cpp/src/models/pangu-embedded.cpp +121 -0
  115. package/src/llama.cpp/src/models/phi2.cpp +121 -0
  116. package/src/llama.cpp/src/models/phi3.cpp +152 -0
  117. package/src/llama.cpp/src/models/plamo.cpp +110 -0
  118. package/src/llama.cpp/src/models/plamo2.cpp +316 -0
  119. package/src/llama.cpp/src/models/plm.cpp +168 -0
  120. package/src/llama.cpp/src/models/qwen.cpp +108 -0
  121. package/src/llama.cpp/src/models/qwen2.cpp +117 -0
  122. package/src/llama.cpp/src/models/qwen2moe.cpp +151 -0
  123. package/src/llama.cpp/src/models/qwen2vl.cpp +117 -0
  124. package/src/llama.cpp/src/models/qwen3.cpp +117 -0
  125. package/src/llama.cpp/src/models/qwen3moe.cpp +124 -0
  126. package/src/llama.cpp/src/models/qwen3vl-moe.cpp +149 -0
  127. package/src/llama.cpp/src/models/qwen3vl.cpp +141 -0
  128. package/src/llama.cpp/src/models/refact.cpp +94 -0
  129. package/src/llama.cpp/src/models/rwkv6-base.cpp +162 -0
  130. package/src/llama.cpp/src/models/rwkv6.cpp +94 -0
  131. package/src/llama.cpp/src/models/rwkv6qwen2.cpp +86 -0
  132. package/src/llama.cpp/src/models/rwkv7-base.cpp +135 -0
  133. package/src/llama.cpp/src/models/rwkv7.cpp +90 -0
  134. package/src/llama.cpp/src/models/seed-oss.cpp +124 -0
  135. package/src/llama.cpp/src/models/smallthinker.cpp +120 -0
  136. package/src/llama.cpp/src/models/smollm3.cpp +128 -0
  137. package/src/llama.cpp/src/models/stablelm.cpp +146 -0
  138. package/src/llama.cpp/src/models/starcoder.cpp +100 -0
  139. package/src/llama.cpp/src/models/starcoder2.cpp +121 -0
  140. package/src/llama.cpp/src/models/t5-dec.cpp +166 -0
  141. package/src/llama.cpp/src/models/t5-enc.cpp +96 -0
  142. package/src/llama.cpp/src/models/wavtokenizer-dec.cpp +149 -0
  143. package/src/llama.cpp/src/models/xverse.cpp +108 -0
@@ -1142,7 +1142,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_inp();
 
     auto inp = std::make_unique<llm_graph_input_embd>();
 
@@ -1279,7 +1279,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     //    return cur;
     //}
 
-    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
+    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
     const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
@@ -2035,7 +2035,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
 
     if (bidirectional) {
         relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
+        relative_position = std::abs(relative_position);
     } else {
         relative_position = -std::min<int32_t>(relative_position, 0);
     }
@@ -60,6 +60,16 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
     return n_head/n_head_kv;
 }
 
+uint32_t llama_hparams::n_embd_inp() const {
+    uint32_t n_embd_inp = n_embd;
+
+    if (n_deepstack_layers > 0) {
+        n_embd_inp += n_embd * n_deepstack_layers;
+    }
+
+    return n_embd_inp;
+}
+
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     const uint32_t n_head_kv = this->n_head_kv(il);
 
@@ -148,7 +158,7 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
 }
 
 uint32_t llama_hparams::n_pos_per_embd() const {
-    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+    return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
 }
 
 bool llama_hparams::is_swa(uint32_t il) const {
@@ -183,6 +183,9 @@ struct llama_hparams {
     std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
     std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
 
+    // qwen3vl deepstack
+    uint32_t n_deepstack_layers = 0;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@@ -224,6 +227,9 @@ struct llama_hparams {
 
     uint32_t n_gqa(uint32_t il = 0) const;
 
+    // dimension of main + auxiliary input embeddings
+    uint32_t n_embd_inp() const;
+
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
 
@@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+    // note: the SWA cache is always padded to 256 for performance
+    // https://github.com/ggml-org/llama.cpp/issues/17037
+    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
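GGML_PAD(x, n) rounds x up to a multiple of n, so the change above clamps the SWA cache size first and then pads the clamped value to a fixed 256 (previously the padding to n_pad happened inside the min()). A small sketch of that behaviour with assumed sizes; round_up() is a stand-in for the macro, not the ggml definition:

// Illustrative only: the sizes below are assumptions, and round_up() merely
// mimics "round up to the next multiple of n".
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t round_up(uint32_t x, uint32_t n) {
    return (x + n - 1) / n * n;
}

int main() {
    const uint32_t size_base = 8192; // assumed total KV cache size
    const uint32_t n_swa     = 4096; // assumed sliding-window size
    const uint32_t n_seq_max = 1;
    const uint32_t n_ubatch  = 512;

    // new ordering: clamp first, then pad the result to a fixed 256
    const uint32_t size_swa = round_up(std::min(size_base, n_swa*n_seq_max + n_ubatch), 256);

    std::printf("size_swa = %u\n", size_swa); // 4608 for these example values
    return 0;
}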
@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
             llama_pos pos = v_cells[s0].pos_get(i);
             llama_pos shift = v_cells[s0].get_shift(i);
 
+            llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
+
             if (shift != 0) {
                 pos -= shift;
                 assert(pos >= 0);
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
             if (shift != 0) {
                 v_cells[s1].pos_add(i, shift);
             }
+
+            v_cells[s1].ext_set(i, ext);
         }
     }
 
@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
 
 void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
     auto & head = v_heads[seq_to_stream[seq_id]];
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
 
 void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
 
@@ -900,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         cells.pos_set(idx, ubatch.pos[i]);
 
+        if (ubatch.is_pos_2d()) {
+            llama_kv_cell_ext ext {
+                /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
+                /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+            };
+            cells.ext_set(idx, ext);
+        }
+
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
             cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
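The hunk above shows how apply_ubatch() now records a token's 2D position in the KV cell when the batch carries M-RoPE positions. A sketch of the indexing, with the section layout (one n_tokens-sized section per position component, y at offset n_tokens and x at offset 2*n_tokens) inferred from the diff rather than taken from documented behaviour; all names here are illustrative:

// Sketch only: ext_sketch and read_pos_2d() are stand-ins, and the assumed
// [t | y | x] layout of the flat pos buffer is inferred from the indexing above.
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t;

struct ext_sketch {
    llama_pos x = 0;
    llama_pos y = 0;
};

static ext_sketch read_pos_2d(const std::vector<llama_pos> & pos, int32_t n_tokens, int32_t i) {
    return {
        /*.x =*/ pos[i + n_tokens*2],
        /*.y =*/ pos[i + n_tokens],
    };
}

int main() {
    const int32_t n_tokens = 2;
    // assumed sections: [ t0, t1 | y0, y1 | x0, x1 ]
    const std::vector<llama_pos> pos = { 0, 0, 3, 3, 5, 6 };

    for (int32_t i = 0; i < n_tokens; ++i) {
        const ext_sketch e = read_pos_2d(pos, n_tokens, i);
        std::printf("token %d -> x=%d y=%d\n", i, e.x, e.y);
    }
    return 0;
}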
@@ -1247,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
 
                 const llama_pos p1 = ubatch->pos[i];
 
+                // for M-RoPE
+                const bool is_2d = ubatch->is_pos_2d();
+                const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+                const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
+
                 const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
                 for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1266,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                         continue;
                     }
 
+                    // M-RoPE causal mask
+                    if (causal_attn && is_2d && p0 == p1) {
+                        const auto & p0_ext = cells.ext_get(j);
+                        if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                            continue;
+                        }
+                    }
+
                     // apply SWA if any
                     if (is_masked_swa(p0, p1)) {
                         continue;
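The new M-RoPE branch above only triggers when a cached key cell and the query token share the same primary position (p0 == p1); the 2D extension then decides visibility by raster order of (y, x). A small sketch of that tie-break, mirroring the llama_kv_cell_ext struct introduced later in this diff, with illustrative coordinates:

// Sketch only: the coordinate values are made up; the comparison logic matches
// the is_2d_gt() shown in the llama-kv-cells.h hunk below.
#include <cassert>
#include <cstdint>

using llama_pos = int32_t;

struct llama_kv_cell_ext {
    llama_pos x = 0;
    llama_pos y = 0;

    // true if this cell's 2D position is strictly greater than (ox, oy)
    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
        return (y > oy) || (y == oy && x > ox);
    }
};

int main() {
    const llama_kv_cell_ext key = { /*.x =*/ 4, /*.y =*/ 2 };

    // query at (x=3, y=2): the key sits later on the same row, so it gets masked
    assert(key.is_2d_gt(3, 2));

    // query at (x=1, y=3): the key is on an earlier row, so it stays visible
    assert(!key.is_2d_gt(1, 3));

    return 0;
}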
@@ -1348,7 +1375,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
     const auto & yarn_beta_slow = cparams.yarn_beta_slow;
 
     const auto & n_rot = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
         // @ngxson : this is a workaround
         // for M-RoPE, we want to rotate the whole vector when doing KV shift
         // a normal RoPE should work, we just need to use the correct ordering
@@ -1559,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
         io.write(&pos, sizeof(pos));
         io.write(&n_seq_id, sizeof(n_seq_id));
 
+        // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
+        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+
         for (const auto & seq_id : seq_ids) {
             io.write(&seq_id, sizeof(seq_id));
         }
@@ -1704,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             return false;
         }
 
+        // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
+        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);
 
         const auto head_cur = sinfo.head();
@@ -5,9 +5,27 @@
 
 #include <bitset>
 #include <cassert>
-#include <vector>
-#include <set>
+#include <cstring>
 #include <map>
+#include <set>
+#include <vector>
+
+struct llama_kv_cell_ext {
+    // 2D spatial positions, typically used for M-RoPE
+    llama_pos x = 0;
+    llama_pos y = 0;
+
+    // return true if the current 2D spatial position is greater than other
+    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
+        return (y > oy) || (y == oy && x > ox);
+    }
+
+    void reset() {
+        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
+
+        memset(this, 0, sizeof(*this));
+    }
+};
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -16,6 +34,7 @@ public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
             seq[i].reset();
         }
@@ -43,6 +62,7 @@ public:
 
     void resize(uint32_t n) {
         pos.resize(n);
+        ext.resize(n);
         shift.resize(n);
         seq.resize(n);
 
@@ -108,6 +128,7 @@ public:
             const auto idx = i + j;
 
             res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
             res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
@@ -126,6 +147,7 @@ public:
             const auto idx = idxs[j];
 
             res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
             res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
@@ -154,6 +176,7 @@ public:
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -184,6 +207,7 @@ public:
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -203,6 +227,7 @@ public:
         seq[i].reset();
 
         pos[i] = -1;
+        ext[i].reset();
         shift[i] = 0;
 
         used.erase(i);
@@ -221,6 +246,7 @@ public:
 
         if (seq[i].none()) {
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -250,6 +276,7 @@ public:
             seq[i].reset();
 
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -340,6 +367,13 @@ public:
         return pos[i];
     }
 
+    const llama_kv_cell_ext & ext_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return ext[i];
+    }
+
     // note: call only if the cell is not empty
     llama_pos get_shift(uint32_t i) const {
         assert(i < pos.size());
@@ -368,6 +402,11 @@ public:
         used.insert(i);
     }
 
+    void ext_set(uint32_t i, llama_kv_cell_ext p) {
+        assert(i < ext.size());
+        ext[i] = p;
+    }
+
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty
@@ -424,6 +463,9 @@ private:
 
     std::vector<llama_pos> pos;
 
+    // stores extra info per cell
+    std::vector<llama_kv_cell_ext> ext;
+
     // this array accumulates any applied shifts to the pos array since the last reset_shift() call
     // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
     //
@@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // models like Mamba or RWKV can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased at the end
+    // of the sequence because their state isn't preserved for previous tokens
     if (seq_id >= (int64_t) size) {
         // could be fatal
         return false;
@@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     int32_t & tail_id = cells[seq_id].tail;
     if (tail_id >= 0) {
         const auto & cell = cells[tail_id];
-        // partial intersection is invalid
-        if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+        // partial intersection is invalid if it includes the final pos
+        if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
             //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
             return false;
         }
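The revised condition above narrows when a recurrent-memory removal is rejected: per the updated comment, erasing [p0, p1) is only invalid when it clips the end of the sequence, since the recurrent state only represents everything up to the tail position. A standalone sketch of the check with assumed values, not the llama.cpp implementation:

// Sketch only: tail_pos and the ranges below are made-up example values.
#include <cassert>
#include <cstdint>

using llama_pos = int32_t;

// true if removing [p0, p1) would be an invalid partial erase of the tail state
static bool invalid_partial_erase(llama_pos p0, llama_pos p1, llama_pos tail_pos) {
    return 0 < p0 && p0 <= tail_pos && p1 > tail_pos;
}

int main() {
    const llama_pos tail_pos = 10; // assumed final position held by the tail cell

    assert( invalid_partial_erase(5, 11, tail_pos)); // clips the end of the state: rejected
    assert(!invalid_partial_erase(0, 11, tail_pos)); // removes the whole sequence: allowed
    assert(!invalid_partial_erase(5,  8, tail_pos)); // ends before the final pos: no longer rejected

    return 0;
}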