@fugood/llama.node 1.3.0 → 1.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +8 -8
- package/src/llama.cpp/common/CMakeLists.txt +2 -0
- package/src/llama.cpp/common/arg.cpp +44 -999
- package/src/llama.cpp/common/arg.h +2 -2
- package/src/llama.cpp/common/chat.cpp +17 -2
- package/src/llama.cpp/common/common.cpp +33 -0
- package/src/llama.cpp/common/common.h +15 -1
- package/src/llama.cpp/common/download.cpp +1054 -0
- package/src/llama.cpp/common/download.h +55 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/include/ggml.h +2 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +29 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +21 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +172 -75
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +0 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +82 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +25 -25
- package/src/llama.cpp/include/llama.h +7 -3
- package/src/llama.cpp/src/CMakeLists.txt +95 -0
- package/src/llama.cpp/src/llama-arch.cpp +108 -0
- package/src/llama.cpp/src/llama-arch.h +11 -0
- package/src/llama.cpp/src/llama-batch.cpp +63 -31
- package/src/llama.cpp/src/llama-batch.h +12 -1
- package/src/llama.cpp/src/llama-chat.cpp +32 -0
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +36 -13
- package/src/llama.cpp/src/llama-context.h +5 -5
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +3 -3
- package/src/llama.cpp/src/llama-hparams.cpp +11 -1
- package/src/llama.cpp/src/llama-hparams.h +6 -0
- package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +3 -1
- package/src/llama.cpp/src/llama-kv-cache.cpp +33 -1
- package/src/llama.cpp/src/llama-kv-cells.h +44 -2
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +4 -3
- package/src/llama.cpp/src/llama-model.cpp +320 -13171
- package/src/llama.cpp/src/llama-model.h +8 -0
- package/src/llama.cpp/src/llama-quant.cpp +1 -1
- package/src/llama.cpp/src/llama-vocab.cpp +5 -0
- package/src/llama.cpp/src/llama-vocab.h +1 -0
- package/src/llama.cpp/src/models/apertus.cpp +125 -0
- package/src/llama.cpp/src/models/arcee.cpp +135 -0
- package/src/llama.cpp/src/models/arctic.cpp +138 -0
- package/src/llama.cpp/src/models/arwkv7.cpp +86 -0
- package/src/llama.cpp/src/models/baichuan.cpp +122 -0
- package/src/llama.cpp/src/models/bailingmoe.cpp +144 -0
- package/src/llama.cpp/src/models/bailingmoe2.cpp +135 -0
- package/src/llama.cpp/src/models/bert.cpp +176 -0
- package/src/llama.cpp/src/models/bitnet.cpp +160 -0
- package/src/llama.cpp/src/models/bloom.cpp +101 -0
- package/src/llama.cpp/src/models/chameleon.cpp +178 -0
- package/src/llama.cpp/src/models/chatglm.cpp +132 -0
- package/src/llama.cpp/src/models/codeshell.cpp +111 -0
- package/src/llama.cpp/src/models/cogvlm.cpp +100 -0
- package/src/llama.cpp/src/models/cohere2-iswa.cpp +131 -0
- package/src/llama.cpp/src/models/command-r.cpp +122 -0
- package/src/llama.cpp/src/models/dbrx.cpp +123 -0
- package/src/llama.cpp/src/models/deci.cpp +135 -0
- package/src/llama.cpp/src/models/deepseek.cpp +144 -0
- package/src/llama.cpp/src/models/deepseek2.cpp +236 -0
- package/src/llama.cpp/src/models/dots1.cpp +134 -0
- package/src/llama.cpp/src/models/dream.cpp +105 -0
- package/src/llama.cpp/src/models/ernie4-5-moe.cpp +150 -0
- package/src/llama.cpp/src/models/ernie4-5.cpp +110 -0
- package/src/llama.cpp/src/models/exaone.cpp +114 -0
- package/src/llama.cpp/src/models/exaone4.cpp +123 -0
- package/src/llama.cpp/src/models/falcon-h1.cpp +113 -0
- package/src/llama.cpp/src/models/falcon.cpp +120 -0
- package/src/llama.cpp/src/models/gemma-embedding.cpp +120 -0
- package/src/llama.cpp/src/models/gemma.cpp +112 -0
- package/src/llama.cpp/src/models/gemma2-iswa.cpp +125 -0
- package/src/llama.cpp/src/models/gemma3-iswa.cpp +131 -0
- package/src/llama.cpp/src/models/gemma3n-iswa.cpp +377 -0
- package/src/llama.cpp/src/models/glm4-moe.cpp +153 -0
- package/src/llama.cpp/src/models/glm4.cpp +127 -0
- package/src/llama.cpp/src/models/gpt2.cpp +105 -0
- package/src/llama.cpp/src/models/gptneox.cpp +144 -0
- package/src/llama.cpp/src/models/granite-hybrid.cpp +196 -0
- package/src/llama.cpp/src/models/granite.cpp +211 -0
- package/src/llama.cpp/src/models/graph-context-mamba.cpp +283 -0
- package/src/llama.cpp/src/models/grok.cpp +159 -0
- package/src/llama.cpp/src/models/grovemoe.cpp +141 -0
- package/src/llama.cpp/src/models/hunyuan-dense.cpp +132 -0
- package/src/llama.cpp/src/models/hunyuan-moe.cpp +154 -0
- package/src/llama.cpp/src/models/internlm2.cpp +120 -0
- package/src/llama.cpp/src/models/jais.cpp +86 -0
- package/src/llama.cpp/src/models/jamba.cpp +106 -0
- package/src/llama.cpp/src/models/lfm2.cpp +173 -0
- package/src/llama.cpp/src/models/llada-moe.cpp +122 -0
- package/src/llama.cpp/src/models/llada.cpp +99 -0
- package/src/llama.cpp/src/models/llama-iswa.cpp +174 -0
- package/src/llama.cpp/src/models/llama.cpp +155 -0
- package/src/llama.cpp/src/models/mamba.cpp +55 -0
- package/src/llama.cpp/src/models/minicpm3.cpp +199 -0
- package/src/llama.cpp/src/models/minimax-m2.cpp +124 -0
- package/src/llama.cpp/src/models/models.h +481 -0
- package/src/llama.cpp/src/models/mpt.cpp +126 -0
- package/src/llama.cpp/src/models/nemotron-h.cpp +121 -0
- package/src/llama.cpp/src/models/nemotron.cpp +122 -0
- package/src/llama.cpp/src/models/neo-bert.cpp +104 -0
- package/src/llama.cpp/src/models/olmo.cpp +121 -0
- package/src/llama.cpp/src/models/olmo2.cpp +150 -0
- package/src/llama.cpp/src/models/olmoe.cpp +124 -0
- package/src/llama.cpp/src/models/openai-moe-iswa.cpp +124 -0
- package/src/llama.cpp/src/models/openelm.cpp +124 -0
- package/src/llama.cpp/src/models/orion.cpp +123 -0
- package/src/llama.cpp/src/models/pangu-embedded.cpp +121 -0
- package/src/llama.cpp/src/models/phi2.cpp +121 -0
- package/src/llama.cpp/src/models/phi3.cpp +152 -0
- package/src/llama.cpp/src/models/plamo.cpp +110 -0
- package/src/llama.cpp/src/models/plamo2.cpp +316 -0
- package/src/llama.cpp/src/models/plm.cpp +168 -0
- package/src/llama.cpp/src/models/qwen.cpp +108 -0
- package/src/llama.cpp/src/models/qwen2.cpp +117 -0
- package/src/llama.cpp/src/models/qwen2moe.cpp +151 -0
- package/src/llama.cpp/src/models/qwen2vl.cpp +117 -0
- package/src/llama.cpp/src/models/qwen3.cpp +117 -0
- package/src/llama.cpp/src/models/qwen3moe.cpp +124 -0
- package/src/llama.cpp/src/models/qwen3vl-moe.cpp +149 -0
- package/src/llama.cpp/src/models/qwen3vl.cpp +141 -0
- package/src/llama.cpp/src/models/refact.cpp +94 -0
- package/src/llama.cpp/src/models/rwkv6-base.cpp +162 -0
- package/src/llama.cpp/src/models/rwkv6.cpp +94 -0
- package/src/llama.cpp/src/models/rwkv6qwen2.cpp +86 -0
- package/src/llama.cpp/src/models/rwkv7-base.cpp +135 -0
- package/src/llama.cpp/src/models/rwkv7.cpp +90 -0
- package/src/llama.cpp/src/models/seed-oss.cpp +124 -0
- package/src/llama.cpp/src/models/smallthinker.cpp +120 -0
- package/src/llama.cpp/src/models/smollm3.cpp +128 -0
- package/src/llama.cpp/src/models/stablelm.cpp +146 -0
- package/src/llama.cpp/src/models/starcoder.cpp +100 -0
- package/src/llama.cpp/src/models/starcoder2.cpp +121 -0
- package/src/llama.cpp/src/models/t5-dec.cpp +166 -0
- package/src/llama.cpp/src/models/t5-enc.cpp +96 -0
- package/src/llama.cpp/src/models/wavtokenizer-dec.cpp +149 -0
- package/src/llama.cpp/src/models/xverse.cpp +108 -0
package/src/llama.cpp/src/llama-graph.cpp:

@@ -1142,7 +1142,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_inp();
 
     auto inp = std::make_unique<llm_graph_input_embd>();
 
@@ -1279,7 +1279,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     // return cur;
     //}
 
-    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
+    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
     const auto n_enc  = !cross->v_embd.empty() ? cross->n_enc  : hparams.n_ctx_train;
 
     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
@@ -2035,7 +2035,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
 
     if (bidirectional) {
         relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
+        relative_position = std::abs(relative_position);
     } else {
         relative_position = -std::min<int32_t>(relative_position, 0);
     }
package/src/llama.cpp/src/llama-hparams.cpp:

@@ -60,6 +60,16 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
     return n_head/n_head_kv;
 }
 
+uint32_t llama_hparams::n_embd_inp() const {
+    uint32_t n_embd_inp = n_embd;
+
+    if (n_deepstack_layers > 0) {
+        n_embd_inp += n_embd * n_deepstack_layers;
+    }
+
+    return n_embd_inp;
+}
+
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     const uint32_t n_head_kv = this->n_head_kv(il);
 
@@ -148,7 +158,7 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
 }
 
 uint32_t llama_hparams::n_pos_per_embd() const {
-    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+    return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
 }
 
 bool llama_hparams::is_swa(uint32_t il) const {
package/src/llama.cpp/src/llama-hparams.h:

@@ -183,6 +183,9 @@ struct llama_hparams {
     std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
     std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
 
+    // qwen3vl deepstack
+    uint32_t n_deepstack_layers = 0;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@@ -224,6 +227,9 @@ struct llama_hparams {
 
     uint32_t n_gqa(uint32_t il = 0) const;
 
+    // dimension of main + auxiliary input embeddings
+    uint32_t n_embd_inp() const;
+
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
 
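For context on the n_embd_inp() accessor introduced above: when a model carries qwen3vl deepstack layers, the input embedding row is the main embedding plus one auxiliary embedding per deepstack layer. A minimal standalone sketch of that arithmetic, using made-up sizes that are not taken from this package:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_embd             = 2048; // hypothetical main embedding width
        const uint32_t n_deepstack_layers = 3;    // hypothetical deepstack depth

        // mirrors the llama_hparams::n_embd_inp() hunk above:
        // main embedding plus one auxiliary embedding per deepstack layer
        const uint32_t n_embd_inp = n_embd + n_embd * n_deepstack_layers;

        std::printf("input embedding width: %u\n", n_embd_inp); // prints 8192
        return 0;
    }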
package/src/llama.cpp/src/llama-kv-cache-iswa.cpp:

@@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
 
     const uint32_t size_base = kv_size;
 
-
+    // note: the SWA cache is always padded to 256 for performance
+    // https://github.com/ggml-org/llama.cpp/issues/17037
+    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
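The GGML_PAD(..., 256) call in the hunk above rounds the SWA cache size up to the next multiple of 256. A small standalone sketch of the same computation, with a local helper mirroring ggml's rounding rule and hypothetical sizes that are not taken from this package:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // same rounding rule as ggml's GGML_PAD(x, n): round x up to a multiple of n
    static uint32_t pad_to(uint32_t x, uint32_t n) {
        return (x + n - 1) & ~(n - 1);
    }

    int main() {
        const uint32_t size_base = 4096; // hypothetical base (non-SWA) cache size
        const uint32_t n_swa     = 750;  // hypothetical sliding-window size
        const uint32_t n_seq_max = 1;
        const uint32_t n_ubatch  = 512;

        const uint32_t size_swa = pad_to(std::min(size_base, n_swa*n_seq_max + n_ubatch), 256);

        std::printf("size_swa = %u\n", size_swa); // min(4096, 1262) = 1262, padded up to 1280
        return 0;
    }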
package/src/llama.cpp/src/llama-kv-cache.cpp:

@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
                 llama_pos pos   = v_cells[s0].pos_get(i);
                 llama_pos shift = v_cells[s0].get_shift(i);
 
+                llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
+
                 if (shift != 0) {
                     pos -= shift;
                     assert(pos >= 0);
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
                 if (shift != 0) {
                     v_cells[s1].pos_add(i, shift);
                 }
+
+                v_cells[s1].ext_set(i, ext);
             }
         }
 
@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
 
 void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
     auto & head  = v_heads[seq_to_stream[seq_id]];
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
 
 void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
 
@@ -900,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         cells.pos_set(idx, ubatch.pos[i]);
 
+        if (ubatch.is_pos_2d()) {
+            llama_kv_cell_ext ext {
+                /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
+                /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+            };
+            cells.ext_set(idx, ext);
+        }
+
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
             cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
@@ -1247,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
 
             const llama_pos p1 = ubatch->pos[i];
 
+            // for M-RoPE
+            const bool is_2d = ubatch->is_pos_2d();
+            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
+
             const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
             for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1266,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                     continue;
                 }
 
+                // M-RoPE causal mask
+                if (causal_attn && is_2d && p0 == p1) {
+                    const auto & p0_ext = cells.ext_get(j);
+                    if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                        continue;
+                    }
+                }
+
                 // apply SWA if any
                 if (is_masked_swa(p0, p1)) {
                     continue;
@@ -1348,7 +1375,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
     const auto & yarn_beta_slow = cparams.yarn_beta_slow;
 
     const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
                              // @ngxson : this is a workaround
                              // for M-RoPE, we want to rotate the whole vector when doing KV shift
                              // a normal RoPE should work, we just need to use the correct ordering
@@ -1559,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
             io.write(&pos,      sizeof(pos));
             io.write(&n_seq_id, sizeof(n_seq_id));
 
+            // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
+            // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+
             for (const auto & seq_id : seq_ids) {
                 io.write(&seq_id, sizeof(seq_id));
             }
@@ -1704,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             return false;
         }
 
+        // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
+        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);
 
         const auto head_cur = sinfo.head();
package/src/llama.cpp/src/llama-kv-cells.h:

@@ -5,9 +5,27 @@
 
 #include <bitset>
 #include <cassert>
-#include <vector>
-#include <set>
+#include <cstring>
 #include <map>
+#include <set>
+#include <vector>
+
+struct llama_kv_cell_ext {
+    // 2D spatial positions, typically used for M-RoPE
+    llama_pos x = 0;
+    llama_pos y = 0;
+
+    // return true if the current 2D spatial position is greater than other
+    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
+        return (y > oy) || (y == oy && x > ox);
+    }
+
+    void reset() {
+        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
+
+        memset(this, 0, sizeof(*this));
+    }
+};
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -16,6 +34,7 @@ public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
             seq[i].reset();
         }
@@ -43,6 +62,7 @@ public:
 
     void resize(uint32_t n) {
         pos.resize(n);
+        ext.resize(n);
         shift.resize(n);
         seq.resize(n);
 
@@ -108,6 +128,7 @@ public:
             const auto idx = i + j;
 
             res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
             res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
@@ -126,6 +147,7 @@ public:
             const auto idx = idxs[j];
 
             res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
             res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
@@ -154,6 +176,7 @@ public:
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -184,6 +207,7 @@ public:
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -203,6 +227,7 @@ public:
         seq[i].reset();
 
         pos[i] = -1;
+        ext[i].reset();
         shift[i] = 0;
 
         used.erase(i);
@@ -221,6 +246,7 @@ public:
 
         if (seq[i].none()) {
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -250,6 +276,7 @@ public:
         seq[i].reset();
 
         pos[i] = -1;
+        ext[i].reset();
         shift[i] = 0;
 
         used.erase(i);
@@ -340,6 +367,13 @@ public:
         return pos[i];
     }
 
+    const llama_kv_cell_ext & ext_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return ext[i];
+    }
+
     // note: call only if the cell is not empty
     llama_pos get_shift(uint32_t i) const {
         assert(i < pos.size());
@@ -368,6 +402,11 @@ public:
         used.insert(i);
     }
 
+    void ext_set(uint32_t i, llama_kv_cell_ext p) {
+        assert(i < ext.size());
+        ext[i] = p;
+    }
+
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty
@@ -424,6 +463,9 @@ private:
 
     std::vector<llama_pos> pos;
 
+    // stores extra info per cell
+    std::vector<llama_kv_cell_ext> ext;
+
     // this array accumulates any applied shifts to the pos array since the last reset_shift() call
    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
    //
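The is_2d_gt() comparison added above is a lexicographic "greater than" on the 2D position (row y first, then column x); the M-RoPE causal mask in llama-kv-cache.cpp uses it to decide whether a cached cell lies "after" the query position when the 1D positions are equal. A standalone sketch of that ordering with a local copy of the comparison and made-up coordinates (not values from this package):

    #include <cstdint>
    #include <cstdio>

    using llama_pos = int32_t;

    // local copy of llama_kv_cell_ext::is_2d_gt() from the hunk above:
    // (x, y) is "greater" than (ox, oy) if its row is later, or same row and later column
    static bool is_2d_gt(llama_pos x, llama_pos y, llama_pos ox, llama_pos oy) {
        return (y > oy) || (y == oy && x > ox);
    }

    int main() {
        // a cached cell at row 3, column 5, compared against two query positions
        std::printf("%d\n", is_2d_gt(5, 3, /*ox=*/7, /*oy=*/2)); // 1: later row wins regardless of column
        std::printf("%d\n", is_2d_gt(5, 3, /*ox=*/7, /*oy=*/3)); // 0: same row, smaller column
        return 0;
    }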
package/src/llama.cpp/src/llama-memory-recurrent.cpp:

@@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // models like Mamba or RWKV can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased at the end
+    // of the sequence because their state isn't preserved for previous tokens
     if (seq_id >= (int64_t) size) {
         // could be fatal
         return false;
@@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
             const auto & cell = cells[tail_id];
-            // partial intersection is invalid
-            if (
+            // partial intersection is invalid if it includes the final pos
+            if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
                 //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                 return false;
             }