@fugood/llama.node 1.0.2 → 1.0.3

This diff shows the contents of the publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between these versions as they appear in the public registry.
Files changed (39)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/arg.cpp +7 -0
  4. package/src/llama.cpp/common/common.h +1 -0
  5. package/src/llama.cpp/ggml/CMakeLists.txt +7 -2
  6. package/src/llama.cpp/ggml/include/ggml.h +91 -10
  7. package/src/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  8. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  9. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  10. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +726 -155
  11. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  12. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +9 -9
  14. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +49 -9
  15. package/src/llama.cpp/include/llama.h +1 -0
  16. package/src/llama.cpp/src/llama-arch.cpp +90 -2
  17. package/src/llama.cpp/src/llama-arch.h +6 -0
  18. package/src/llama.cpp/src/llama-batch.cpp +27 -1
  19. package/src/llama.cpp/src/llama-batch.h +8 -1
  20. package/src/llama.cpp/src/llama-chat.cpp +15 -0
  21. package/src/llama.cpp/src/llama-chat.h +1 -0
  22. package/src/llama.cpp/src/llama-graph.cpp +64 -50
  23. package/src/llama.cpp/src/llama-graph.h +41 -16
  24. package/src/llama.cpp/src/llama-hparams.cpp +2 -1
  25. package/src/llama.cpp/src/llama-hparams.h +1 -0
  26. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  27. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  28. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  29. package/src/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  30. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  31. package/src/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  32. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  33. package/src/llama.cpp/src/llama-memory-recurrent.cpp +15 -2
  34. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  35. package/src/llama.cpp/src/llama-memory.h +3 -0
  36. package/src/llama.cpp/src/llama-model.cpp +1234 -248
  37. package/src/llama.cpp/src/llama-model.h +2 -0
  38. package/src/llama.cpp/src/llama-vocab.cpp +8 -1
  39. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50

package/src/llama.cpp/src/llama-graph.cpp

@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  }
 
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
- }
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
+ mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  }
 
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
- }
+ mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+ mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
- if (self_kq_mask_swa) {
- mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
- }
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+ mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+ mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
  }
 
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  }
 
  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
- }
+ mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+ mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+ mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
  const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
  }
  }
 
- void llm_graph_input_one::set_input(const llama_ubatch *) {
+ void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+ GGML_UNUSED(ubatch);
  GGML_ASSERT(one && ggml_nelements(one) == 1);
  float f_one = 1.0f;
  ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -997,8 +1002,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
  const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
+ inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+ inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  ggml_set_input(inp->self_kq_mask);
 
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1135,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
- inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp_kq_mask, "KQ_mask", -1);
+ inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  ggml_set_input(inp->kq_mask);
 
  inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1198,8 +1204,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
  const auto n_kv = mctx_cur->get_n_kv();
 
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+ inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  ggml_set_input(inp->self_kq_mask);
 
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1230,8 +1238,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
  // store to KV cache
  {
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+ const auto & k_idxs = inp->get_k_idxs();
+ const auto & v_idxs = inp->get_v_idxs();
+
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
  }
 
  const auto & kq_mask = inp->get_kq_mask();
@@ -1290,11 +1301,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
  // optionally store to KV cache
  if (k_cur) {
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
  }
 
  if (v_cur) {
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
  }
 
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1326,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
- inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  ggml_set_input(inp->cross_kq_mask);
 
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1398,8 +1413,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
  // store to KV cache
  {
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+ const auto & k_idxs = inp->get_k_idxs();
+ const auto & v_idxs = inp->get_v_idxs();
+
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
  }
 
  const auto & kq_mask = inp->get_kq_mask();
@@ -1434,8 +1452,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
  {
  const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
+ inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+ inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  ggml_set_input(inp->self_kq_mask);
 
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1466,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
  const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+ inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+ inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
  ggml_set_input(inp->self_kq_mask_swa);
 
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1466,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_rs(
  uint32_t kv_head,
  uint32_t kv_size,
  int32_t rs_zero,
- bool avoid_copies) const {
+ const llm_graph_get_rows_fn & get_state_rows) const {
 
  ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
@@ -1475,19 +1497,11 @@ ggml_tensor * llm_graph_context::build_rs(
  ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
  ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
- ggml_tensor * output_states;
-
- if (!avoid_copies) {
- // copy states
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
- // {state_size, kv_size} -> {state_size, n_seqs}
- output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
- ggml_build_forward_expand(gf, output_states);
- } else {
- // FIXME: make the gathering operation happen before the copy below
- // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
- output_states = states;
- }
+ // copy states
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+ // {state_size, kv_size} -> {state_size, n_seqs}
+ ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+ ggml_build_forward_expand(gf, output_states);
 
  // copy extra states which won't be changed further (between n_seqs and n_kv)
  ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
@@ -1518,10 +1532,10 @@ ggml_tensor * llm_graph_context::build_rs(
  ggml_tensor * s,
  int32_t state_size,
  int32_t n_seqs,
- bool avoid_copies) const {
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+ const llm_graph_get_rows_fn & get_state_rows) const {
+ const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);
 
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
  }
 
  ggml_tensor * llm_graph_context::build_rs(
@@ -1530,10 +1544,10 @@ ggml_tensor * llm_graph_context::build_rs(
  ggml_tensor * s,
  int32_t state_size,
  int32_t n_seqs,
- bool avoid_copies) const {
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+ const llm_graph_get_rows_fn & get_state_rows) const {
+ const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
  }
 
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
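
Note on the KQ mask hunks above: the masks move from ggml_new_tensor_2d to ggml_new_tensor_4d with trailing dimensions of 1. In ggml every tensor already carries four dimensions, so the change is shape-preserving and only makes the broadcast dimensions explicit. A minimal standalone sketch (hypothetical sizes, not code from this package) illustrating the equivalence:

// Standalone illustration with made-up sizes: a 2-D ggml tensor and the explicit
// 4-D form with trailing dims of 1 have identical shapes.
#include "ggml.h"

int main() {
    ggml_init_params params = { /*mem_size =*/ 1024*1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ true };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_kv = 32, n_tokens = 8; // hypothetical values

    ggml_tensor * mask_2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, n_tokens);
    ggml_tensor * mask_4d = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens, 1, 1);

    GGML_ASSERT(ggml_are_same_shape(mask_2d, mask_4d)); // ne = {32, 8, 1, 1} in both cases

    ggml_free(ctx);
    return 0;
}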

package/src/llama.cpp/src/llama-graph.h

@@ -228,8 +228,8 @@ public:
 
  ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
 
- ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
- ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
+ ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+ ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
 
  const llama_hparams & hparams;
  const llama_cparams & cparams;
@@ -249,10 +249,16 @@ public:
 
  void set_input(const llama_ubatch * ubatch) override;
 
+ ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+ ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
  const llama_hparams & hparams;
  const llama_cparams & cparams;
@@ -274,13 +280,23 @@ public:
 
  void set_input(const llama_ubatch * ubatch) override;
 
+ ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+ ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+ ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+ ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
- ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+ ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
  const llama_hparams & hparams;
  const llama_cparams & cparams;
@@ -297,8 +313,8 @@ public:
 
  ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
 
- ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
- ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+ ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+ ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
 
  const llama_cross * cross = nullptr;
  };
@@ -319,10 +335,16 @@ public:
 
  ggml_tensor * s_copy; // I32 [kv_size]
 
+ ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+ ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
  const llama_hparams & hparams;
  const llama_cparams & cparams;
@@ -336,7 +358,7 @@ public:
  llm_graph_input_one() {}
  virtual ~llm_graph_input_one() = default;
 
- void set_input(const llama_ubatch *) override;
+ void set_input(const llama_ubatch * ubatch) override;
 
  ggml_tensor * one = nullptr; // F32
  };
@@ -424,6 +446,9 @@ struct llm_graph_params {
  const llm_graph_cb & cb;
  };
 
+ // used in build_rs to properly order writes and avoid unnecessary copies
+ using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
+
  struct llm_graph_context {
  const llm_arch arch;
 
@@ -663,7 +688,7 @@ struct llm_graph_context {
  uint32_t kv_head,
  uint32_t kv_size,
  int32_t rs_zero,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
  llm_graph_input_rs * build_rs_inp() const;
 
@@ -673,7 +698,7 @@ struct llm_graph_context {
  ggml_tensor * s,
  int32_t state_size,
  int32_t n_seqs,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
  ggml_tensor * build_rs(
  llm_graph_input_mem_hybrid * inp,
@@ -681,7 +706,7 @@ struct llm_graph_context {
  ggml_tensor * s,
  int32_t state_size,
  int32_t n_seqs,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
  ggml_tensor * build_rwkv_token_shift_load(
  llm_graph_input_rs * inp,
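
Note on the build_rs hunks above: the old avoid_copies flag is replaced by an optional llm_graph_get_rows_fn callback defaulting to ggml_get_rows, so callers can customize how recurrent states are gathered. Below is a standalone sketch of the same callback-with-default pattern; the helper gather_states and all sizes are hypothetical and not part of llama.cpp:

// Standalone sketch: a row-gather step whose implementation can be overridden,
// with plain ggml_get_rows as the default, mirroring the new build_rs() parameter.
#include <functional>
#include "ggml.h"

using get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;

// hypothetical helper, for illustration only
static ggml_tensor * gather_states(ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids,
                                   const get_rows_fn & get_state_rows = ggml_get_rows) {
    // {state_size, kv_size} -> {state_size, n_seqs}, as in the default path of build_rs()
    return get_state_rows(ctx, states, ids);
}

int main() {
    ggml_init_params params = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8); // {state_size, kv_size}
    ggml_tensor * ids    = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);     // rows to gather

    ggml_tensor * out_default = gather_states(ctx, states, ids);          // uses ggml_get_rows
    ggml_tensor * out_custom  = gather_states(ctx, states, ids,
        [](ggml_context * c, ggml_tensor * s, ggml_tensor * i) {
            // a caller could view/reshape s here before gathering
            return ggml_get_rows(c, s, i);
        });

    GGML_ASSERT(out_default->ne[1] == 2 && out_custom->ne[1] == 2);

    ggml_free(ctx);
    return 0;
}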

package/src/llama.cpp/src/llama-hparams.cpp

@@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_r() const {
 
  // TODO: maybe support other convolution strides than 1
  // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
- return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+ // Corresponds to Mamba's conv_states size
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
  }
 
  uint32_t llama_hparams::n_embd_s() const {
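
Note on the hunk above: n_embd_r() now reserves per-layer conv-state room of (ssm_d_conv - 1) rows of (ssm_d_inner + 2 * ssm_n_group * ssm_d_state) each. A quick worked example with made-up hyperparameters (not taken from any real model):

// Purely illustrative numbers, only to show the arithmetic of the new formula.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t ssm_d_conv  = 4;    // assumed
    const uint32_t ssm_d_inner = 8192; // assumed
    const uint32_t ssm_n_group = 8;    // assumed
    const uint32_t ssm_d_state = 128;  // assumed

    const uint32_t n_embd_r = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0)
                            * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);

    // (4 - 1) * (8192 + 2*8*128) = 3 * 10240 = 30720 conv-state elements per layer
    std::printf("n_embd_r = %u\n", n_embd_r);
    return 0;
}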

package/src/llama.cpp/src/llama-hparams.h

@@ -114,6 +114,7 @@ struct llama_hparams {
  uint32_t ssm_d_inner = 0;
  uint32_t ssm_d_state = 0;
  uint32_t ssm_dt_rank = 0;
+ uint32_t ssm_n_group = 0;
 
  // for hybrid state space models
  std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;

package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp

@@ -113,20 +113,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
  ubatches.push_back(std::move(ubatch)); // NOLINT
  }
 
- auto heads_base = kv_base->prepare(ubatches);
- if (heads_base.empty()) {
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
  break;
  }
 
- auto heads_swa = kv_swa->prepare(ubatches);
- if (heads_swa.empty()) {
+ auto sinfos_base = kv_base->prepare(ubatches);
+ if (sinfos_base.empty()) {
  break;
  }
 
- assert(heads_base.size() == heads_swa.size());
+ auto sinfos_swa = kv_swa->prepare(ubatches);
+ if (sinfos_swa.empty()) {
+ break;
+ }
+
+ assert(sinfos_base.size() == sinfos_swa.size());
 
  return std::make_unique<llama_kv_cache_unified_iswa_context>(
- this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+ this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
  } while (false);
 
  // if it fails, try equal split
@@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
  std::vector<llama_ubatch> ubatches;
  while (true) {
- auto ubatch = balloc.split_equal(n_ubatch);
+ auto ubatch = balloc.split_equal(n_ubatch, false);
 
  if (ubatch.n_tokens == 0) {
  break;
@@ -144,20 +149,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
  ubatches.push_back(std::move(ubatch)); // NOLINT
  }
 
- auto heads_base = kv_base->prepare(ubatches);
- if (heads_base.empty()) {
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
  break;
  }
 
- auto heads_swa = kv_swa->prepare(ubatches);
- if (heads_swa.empty()) {
+ auto sinfos_base = kv_base->prepare(ubatches);
+ if (sinfos_base.empty()) {
  break;
  }
 
- assert(heads_base.size() == heads_swa.size());
+ auto sinfos_swa = kv_swa->prepare(ubatches);
+ if (sinfos_swa.empty()) {
+ break;
+ }
+
+ assert(sinfos_base.size() == sinfos_swa.size());
 
  return std::make_unique<llama_kv_cache_unified_iswa_context>(
- this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+ this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
  } while (false);
 
  // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +230,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
  llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
  llama_kv_cache_unified_iswa * kv,
- std::vector<uint32_t> heads_base,
- std::vector<uint32_t> heads_swa,
+ slot_info_vec_t sinfos_base,
+ slot_info_vec_t sinfos_swa,
  std::vector<llama_ubatch> ubatches) :
  ubatches(std::move(ubatches)),
  // note: here we copy the ubatches. not sure if this is ideal
- ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
- ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+ ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+ ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
  status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
  }
 
@@ -246,7 +256,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
  }
 
  bool llama_kv_cache_unified_iswa_context::apply() {
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+ assert(!llama_memory_status_is_fail(status));
 
  bool res = true;
 

package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h

@@ -74,6 +74,8 @@ private:
 
  class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
  public:
+ using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
  // used for errors
  llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
@@ -90,8 +92,8 @@ public:
  // used to create a batch processing context from a batch
  llama_kv_cache_unified_iswa_context(
  llama_kv_cache_unified_iswa * kv,
- std::vector<uint32_t> heads_base,
- std::vector<uint32_t> heads_swa,
+ slot_info_vec_t sinfos_base,
+ slot_info_vec_t sinfos_swa,
  std::vector<llama_ubatch> ubatches);
 
  virtual ~llama_kv_cache_unified_iswa_context();