@fugood/llama.node 1.0.1 → 1.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +12 -12
- package/src/llama.cpp/CMakeLists.txt +0 -1
- package/src/llama.cpp/common/arg.cpp +17 -0
- package/src/llama.cpp/common/chat.cpp +37 -20
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.h +4 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +7 -2
- package/src/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -0
- package/src/llama.cpp/ggml/include/ggml.h +181 -10
- package/src/llama.cpp/ggml/src/CMakeLists.txt +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +38 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1297 -211
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +103 -9
- package/src/llama.cpp/include/llama.h +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +108 -2
- package/src/llama.cpp/src/llama-arch.h +7 -0
- package/src/llama.cpp/src/llama-batch.cpp +27 -1
- package/src/llama.cpp/src/llama-batch.h +8 -1
- package/src/llama.cpp/src/llama-chat.cpp +15 -0
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +95 -81
- package/src/llama.cpp/src/llama-graph.h +43 -16
- package/src/llama.cpp/src/llama-hparams.cpp +2 -1
- package/src/llama.cpp/src/llama-hparams.h +1 -0
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
- package/src/llama.cpp/src/llama-kv-cache-unified.h +62 -24
- package/src/llama.cpp/src/llama-kv-cells.h +62 -10
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
- package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +34 -16
- package/src/llama.cpp/src/llama-memory.cpp +17 -0
- package/src/llama.cpp/src/llama-memory.h +3 -0
- package/src/llama.cpp/src/llama-model.cpp +1374 -210
- package/src/llama.cpp/src/llama-model.h +3 -0
- package/src/llama.cpp/src/llama-vocab.cpp +8 -1
- package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-
-
-
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-
-
-    }
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
-
-
-
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-
-
-
+    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -560,12 +565,20 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_swiglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_swiglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_silu", il);
             } break;
         case LLM_FFN_GELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_geglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_geglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_gelu", il);
                 if (act_scales != NULL) {
@@ -574,7 +587,11 @@ ggml_tensor * llm_graph_context::build_ffn(
                 }
             } break;
         case LLM_FFN_RELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_reglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_reglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_relu(ctx0, cur);
                 cb(cur, "ffn_relu", il);
             } break;
@@ -588,32 +605,19 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
         case LLM_FFN_SWIGLU:
             {
-
-
-                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_silu(ctx0, x0);
-                cb(cur, "ffn_silu", il);
-
-                cur = ggml_mul(ctx0, x0, x1);
-                cb(cur, "ffn_mul", il);
+                cur = ggml_swiglu(ctx0, cur);
+                cb(cur, "ffn_swiglu", il);
             } break;
         case LLM_FFN_GEGLU:
             {
-
-                int64_t split_point = cur->ne[0] / 2;
-                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_gelu(ctx0, x0);
-                cb(x0, "ffn_gelu", il);
-
-                cur = ggml_mul(ctx0, x0, x1);
+                cur = ggml_geglu(ctx0, cur);
                 cb(cur, "ffn_geglu", il);
             } break;
+        case LLM_FFN_REGLU:
+            {
+                cur = ggml_reglu(ctx0, cur);
+                cb(cur, "ffn_reglu", il);
+            } break;
     }
 
     if (gate && type_gate == LLM_FFN_PAR) {
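For context, the removed SWIGLU branch above computed the gated activation by hand, while the new code calls the fused `ggml_swiglu` op. A minimal sketch of the equivalence, reconstructed only from the removed lines above (the helper name `swiglu_manual` is illustrative, not part of llama.cpp):

```cpp
#include "ggml.h"

// Sketch: what cur = ggml_swiglu(ctx0, cur) replaces, per the removed lines.
// The up-projection is split in half along dim 0, SiLU is applied to the
// first half, and the two halves are multiplied elementwise.
static ggml_tensor * swiglu_manual(ggml_context * ctx0, ggml_tensor * cur) {
    const int64_t split_point = cur->ne[0] / 2;
    ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
    ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
    x0 = ggml_silu(ctx0, x0);      // activate the gate half
    return ggml_mul(ctx0, x0, x1); // gated product
}
```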
@@ -743,12 +747,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate_exps) {
+                cur = ggml_swiglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_swiglu", il);
+            } else {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
-            {
+            if (gate_exps) {
+                cur = ggml_geglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_geglu", il);
+            } else {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
@@ -756,11 +766,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             GGML_ABORT("fatal error");
     }
 
-    if (gate_exps) {
-        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
-        cb(cur, "ffn_moe_gate_par", il);
-    }
-
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
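The explicit gate multiply that used to run after the switch is now redundant: when `gate_exps` is set, the gating is folded into `ggml_swiglu_split` / `ggml_geglu_split` in the cases above. A sketch of the equivalence implied by these two hunks (the helper name `moe_gated_silu` is illustrative only):

```cpp
#include "ggml.h"

// Sketch only: the gated SiLU path of build_moe_ffn when gate_exps is set.
static ggml_tensor * moe_gated_silu(ggml_context * ctx0, ggml_tensor * cur, ggml_tensor * up) {
    // before: cur = ggml_silu(ctx0, cur); cur = ggml_mul(ctx0, cur, up);
    // after:  one fused gated-activation op, no separate gate multiply
    return ggml_swiglu_split(ctx0, cur, up);
}
```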
@@ -997,8 +1002,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
-    inp->
-
+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1135,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask =
-    //cb(inp_kq_mask, "KQ_mask", -1);
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->kq_mask);
 
     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1198,8 +1204,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
-    inp->
-
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1230,8 +1238,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-
-
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
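Condensing the store-to-KV hunks in this file: the graph inputs now carry explicit per-token index tensors (`self_k_idxs` / `self_v_idxs`, I64 of length n_batch, built by `build_input_k_idxs` / `build_input_v_idxs`), and `cpy_k` / `cpy_v` receive them when writing the current K/V into the cache. A sketch of the recurring pattern as it appears across the `build_attn` variants (condensed from the diff; `mctx_cur`, `inp`, `ctx0`, `gf`, `k_cur`, `v_cur`, `il` come from the surrounding function):

```cpp
// Sketch only: new store-to-KV pattern, condensed from the hunks above.
const auto & k_idxs = inp->get_k_idxs();   // I64 [n_batch] index tensor for the K write
const auto & v_idxs = inp->get_v_idxs();   // I64 [n_batch] index tensor for the V write

ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
```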
@@ -1290,11 +1301,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1326,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
-    inp->cross_kq_mask =
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->cross_kq_mask);
 
     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1398,8 +1413,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-
-
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1434,8 +1452,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->
-
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1466,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->
-
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1466,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_rs(
         uint32_t kv_head,
         uint32_t kv_size,
         int32_t rs_zero,
-
+        const llm_graph_get_rows_fn & get_state_rows) const {
 
     ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
@@ -1475,19 +1497,11 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-
-
-
-
-
-        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-        ggml_build_forward_expand(gf, output_states);
-    } else {
-        // FIXME: make the gathering operation happen before the copy below
-        // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-        output_states = states;
-    }
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // {state_size, kv_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    ggml_build_forward_expand(gf, output_states);
 
     // copy extra states which won't be changed further (between n_seqs and n_kv)
     ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
@@ -1518,10 +1532,10 @@ ggml_tensor * llm_graph_context::build_rs(
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
-
-    const auto *
+        const llm_graph_get_rows_fn & get_state_rows) const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs,
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1530,10 +1544,10 @@ ggml_tensor * llm_graph_context::build_rs(
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
-
-    const auto *
+        const llm_graph_get_rows_fn & get_state_rows) const {
+    const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs,
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -38,6 +38,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
+    LLM_FFN_REGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -227,8 +228,8 @@ public:
 
     ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
 
-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch]
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch]
+    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -248,10 +249,16 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    ggml_tensor *
-    ggml_tensor *
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -273,13 +280,23 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor *
-    ggml_tensor *
-    ggml_tensor *
-    ggml_tensor *
+    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -296,8 +313,8 @@ public:
 
     ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
 
-    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch]
-    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
 
     const llama_cross * cross = nullptr;
 };
@@ -318,10 +335,16 @@ public:
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    ggml_tensor *
-    ggml_tensor *
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -335,7 +358,7 @@ public:
     llm_graph_input_one() {}
     virtual ~llm_graph_input_one() = default;
 
-    void set_input(const llama_ubatch *) override;
+    void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * one = nullptr; // F32
 };
@@ -423,6 +446,9 @@ struct llm_graph_params {
     const llm_graph_cb & cb;
 };
 
+// used in build_rs to properly order writes and avoid unnecessary copies
+using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
+
 struct llm_graph_context {
     const llm_arch arch;
 
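This `llm_graph_get_rows_fn` hook is what replaces the removed `avoid_copies` branch in `build_rs` (note the deleted FIXME about passing "an optional lambda function ... instead of avoid_copies"): callers can now supply their own gather, with `ggml_get_rows` as the default argument in the declarations below. A minimal sketch of a callable matching this signature; the lambda itself is illustrative, not code from the package:

```cpp
#include <functional>
#include "ggml.h"

using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;

// Default behaviour: a plain row gather, exactly what build_rs falls back to.
// A custom callable could add extra ops around the gather before returning.
llm_graph_get_rows_fn get_state_rows = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
    return ggml_get_rows(ctx, states, ids);
};
```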
@@ -475,6 +501,7 @@ struct llm_graph_context {
     std::unique_ptr<llm_graph_result> res;
 
     llm_graph_context(const llm_graph_params & params);
+    virtual ~llm_graph_context() = default;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -661,7 +688,7 @@ struct llm_graph_context {
             uint32_t kv_head,
             uint32_t kv_size,
             int32_t rs_zero,
-
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     llm_graph_input_rs * build_rs_inp() const;
 
@@ -671,7 +698,7 @@ struct llm_graph_context {
             ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
-
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rs(
             llm_graph_input_mem_hybrid * inp,
@@ -679,7 +706,7 @@ struct llm_graph_context {
             ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
-
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             llm_graph_input_rs * inp,
@@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_r() const {
 
     // TODO: maybe support other convolution strides than 1
     // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-
+    // Corresponds to Mamba's conv_states size
+    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
 }
 
 uint32_t llama_hparams::n_embd_s() const {
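As a worked example with hypothetical Mamba-style hyperparameters (illustrative values, not taken from this package): ssm_d_conv = 4, ssm_d_inner = 1536, ssm_n_group = 1, ssm_d_state = 16 gives (4 - 1) * (1536 + 2*1*16) = 3 * 1568 = 4704 elements of rolling convolution state per sequence, while ssm_d_conv = 0 makes the expression evaluate to 0, i.e. no conv state is allocated.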
@@ -113,20 +113,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-
-
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
             break;
         }
 
-        auto
-        if (
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            break;
+        }
+
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
@@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch);
+            auto ubatch = balloc.split_equal(n_ubatch, false);
 
             if (ubatch.n_tokens == 0) {
                 break;
@@ -144,20 +149,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-
-
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
             break;
         }
 
-        auto
-        if (
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            break;
+        }
+
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +230,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-
-
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
@@ -246,7 +256,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
 }
 
 bool llama_kv_cache_unified_iswa_context::apply() {
-    assert(status
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 
@@ -74,6 +74,8 @@ private:
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
@@ -90,8 +92,8 @@ public:
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-
-
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();