@fugood/llama.node 1.0.2 → 1.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/src/llama.cpp/CMakeLists.txt +0 -1
- package/src/llama.cpp/common/arg.cpp +7 -0
- package/src/llama.cpp/common/common.h +1 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +7 -2
- package/src/llama.cpp/ggml/include/ggml.h +91 -10
- package/src/llama.cpp/ggml/src/CMakeLists.txt +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +726 -155
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +9 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +49 -9
- package/src/llama.cpp/include/llama.h +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +90 -2
- package/src/llama.cpp/src/llama-arch.h +6 -0
- package/src/llama.cpp/src/llama-batch.cpp +27 -1
- package/src/llama.cpp/src/llama-batch.h +8 -1
- package/src/llama.cpp/src/llama-chat.cpp +15 -0
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +64 -50
- package/src/llama.cpp/src/llama-graph.h +41 -16
- package/src/llama.cpp/src/llama-hparams.cpp +2 -1
- package/src/llama.cpp/src/llama-hparams.h +1 -0
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
- package/src/llama.cpp/src/llama-kv-cache-unified.h +62 -24
- package/src/llama.cpp/src/llama-kv-cells.h +62 -10
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
- package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +15 -2
- package/src/llama.cpp/src/llama-memory.cpp +17 -0
- package/src/llama.cpp/src/llama-memory.h +3 -0
- package/src/llama.cpp/src/llama-model.cpp +1234 -248
- package/src/llama.cpp/src/llama-model.h +2 -0
- package/src/llama.cpp/src/llama-vocab.cpp +8 -1
- package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50

package/src/llama.cpp/src/llama-graph.cpp

@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-
-
-
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-
-
-    }
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
-
-
-
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-
-
-
+    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -997,8 +1002,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
-    inp->
-
+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1135,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask =
-    //cb(inp_kq_mask, "KQ_mask", -1);
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->kq_mask);
 
     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1198,8 +1204,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
-    inp->
-
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1230,8 +1238,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-
-
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1290,11 +1301,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
    }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1326,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
-    inp->cross_kq_mask =
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->cross_kq_mask);
 
     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1398,8 +1413,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-
-
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1434,8 +1452,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->
-
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1466,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->
-
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1466,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_rs(
         uint32_t kv_head,
         uint32_t kv_size,
         int32_t rs_zero,
-
+        const llm_graph_get_rows_fn & get_state_rows) const {
 
     ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
@@ -1475,19 +1497,11 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-
-
-
-
-
-        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-        ggml_build_forward_expand(gf, output_states);
-    } else {
-        // FIXME: make the gathering operation happen before the copy below
-        // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-        output_states = states;
-    }
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // {state_size, kv_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    ggml_build_forward_expand(gf, output_states);
 
     // copy extra states which won't be changed further (between n_seqs and n_kv)
     ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
@@ -1518,10 +1532,10 @@ ggml_tensor * llm_graph_context::build_rs(
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
-
-    const auto *
+        const llm_graph_get_rows_fn & get_state_rows) const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs,
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1530,10 +1544,10 @@ ggml_tensor * llm_graph_context::build_rs(
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
-
-    const auto *
+        const llm_graph_get_rows_fn & get_state_rows) const {
+    const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs,
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(

package/src/llama.cpp/src/llama-graph.h

@@ -228,8 +228,8 @@ public:
 
     ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
 
-    ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
-    ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
+    ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+    ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -249,10 +249,16 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    ggml_tensor *
-    ggml_tensor *
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -274,13 +280,23 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor *
-    ggml_tensor *
-    ggml_tensor *
-    ggml_tensor *
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -297,8 +313,8 @@ public:
 
     ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
 
-    ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
-    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+    ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
 
     const llama_cross * cross = nullptr;
 };
@@ -319,10 +335,16 @@ public:
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    ggml_tensor *
-    ggml_tensor *
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -336,7 +358,7 @@ public:
     llm_graph_input_one() {}
    virtual ~llm_graph_input_one() = default;
 
-    void set_input(const llama_ubatch *) override;
+    void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * one = nullptr; // F32
 };
@@ -424,6 +446,9 @@ struct llm_graph_params {
     const llm_graph_cb & cb;
 };
 
+// used in build_rs to properly order writes and avoid unnecessary copies
+using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
+
 struct llm_graph_context {
     const llm_arch arch;
 
@@ -663,7 +688,7 @@ struct llm_graph_context {
             uint32_t kv_head,
             uint32_t kv_size,
             int32_t rs_zero,
-
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     llm_graph_input_rs * build_rs_inp() const;
 
@@ -673,7 +698,7 @@ struct llm_graph_context {
             ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
-
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rs(
             llm_graph_input_mem_hybrid * inp,
@@ -681,7 +706,7 @@ struct llm_graph_context {
             ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
-
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             llm_graph_input_rs * inp,
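
Note: the llama-graph.h hunk above introduces llm_graph_get_rows_fn, a std::function hook that build_rs now takes (defaulting to ggml_get_rows) so callers can control how recurrent-state rows are gathered; per the FIXME removed from llama-graph.cpp, this replaces the old avoid_copies approach and lets the gather happen before the state copy. Below is a minimal compile-only sketch of a callback matching that signature; the traced_get_rows wrapper and its tensor name are hypothetical illustrations, not part of this package.

    #include <functional>
    #include "ggml.h"

    // mirrors the alias added in llama-graph.h
    using llm_graph_get_rows_fn =
        std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;

    // the default used by build_rs is plain ggml_get_rows
    static const llm_graph_get_rows_fn default_get_rows = ggml_get_rows;

    // a caller-supplied variant could wrap the gather, e.g. to tag the node for debugging
    static const llm_graph_get_rows_fn traced_get_rows =
        [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * rows = ggml_get_rows(ctx, states, ids); // gather the selected state rows
            ggml_set_name(rows, "rs_gathered_rows");              // label the graph node
            return rows;
        };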

package/src/llama.cpp/src/llama-hparams.cpp

@@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_r() const {
 
     // TODO: maybe support other convolution strides than 1
     // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-
+    // Corresponds to Mamba's conv_states size
+    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
 }
 
 uint32_t llama_hparams::n_embd_s() const {
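
The rewritten llama_hparams::n_embd_r() above sizes the recurrent conv state as (ssm_d_conv - 1) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state). A standalone arithmetic check with hypothetical Mamba-style hyperparameters (illustrative values only, not taken from this package):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical hyperparameters, chosen only to exercise the formula
        const uint32_t ssm_d_conv  = 4;
        const uint32_t ssm_d_inner = 1536;
        const uint32_t ssm_n_group = 1;
        const uint32_t ssm_d_state = 16;

        // same expression as the new llama_hparams::n_embd_r() body
        const uint32_t n_embd_r =
            (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);

        printf("%u\n", n_embd_r); // (4 - 1) * (1536 + 2*1*16) = 3 * 1568 = 4704
        return 0;
    }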

package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp

@@ -113,20 +113,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-
-
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
             break;
         }
 
-        auto
-        if (
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            break;
+        }
+
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-            this, std::move(
+            this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
@@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch);
+            auto ubatch = balloc.split_equal(n_ubatch, false);
 
             if (ubatch.n_tokens == 0) {
                 break;
@@ -144,20 +149,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-
-
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
             break;
         }
 
-        auto
-        if (
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            break;
+        }
+
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-            this, std::move(
+            this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +230,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-
-
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
@@ -246,7 +256,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
 }
 
 bool llama_kv_cache_unified_iswa_context::apply() {
-    assert(status
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 

package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h

@@ -74,6 +74,8 @@ private:
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
@@ -90,8 +92,8 @@ public:
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-
-
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();