@fugood/llama.node 1.0.3 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. package/lib/binding.ts +1 -0
  2. package/package.json +14 -14
  3. package/src/LlamaCompletionWorker.cpp +24 -4
  4. package/src/LlamaCompletionWorker.h +7 -1
  5. package/src/LlamaContext.cpp +2 -1
  6. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  7. package/src/llama.cpp/common/arg.cpp +37 -0
  8. package/src/llama.cpp/common/common.cpp +22 -6
  9. package/src/llama.cpp/common/common.h +14 -1
  10. package/src/llama.cpp/ggml/CMakeLists.txt +3 -0
  11. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  12. package/src/llama.cpp/ggml/include/ggml.h +13 -0
  13. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  14. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +23 -8
  16. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +39 -0
  18. package/src/llama.cpp/include/llama.h +13 -48
  19. package/src/llama.cpp/src/llama-arch.cpp +222 -15
  20. package/src/llama.cpp/src/llama-arch.h +16 -1
  21. package/src/llama.cpp/src/llama-batch.cpp +76 -70
  22. package/src/llama.cpp/src/llama-batch.h +24 -18
  23. package/src/llama.cpp/src/llama-chat.cpp +44 -1
  24. package/src/llama.cpp/src/llama-chat.h +2 -0
  25. package/src/llama.cpp/src/llama-context.cpp +134 -95
  26. package/src/llama.cpp/src/llama-context.h +13 -16
  27. package/src/llama.cpp/src/llama-cparams.h +3 -2
  28. package/src/llama.cpp/src/llama-graph.cpp +239 -154
  29. package/src/llama.cpp/src/llama-graph.h +162 -126
  30. package/src/llama.cpp/src/llama-hparams.cpp +45 -0
  31. package/src/llama.cpp/src/llama-hparams.h +11 -1
  32. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  34. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  35. package/src/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  36. package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  37. package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -9
  38. package/src/llama.cpp/src/llama-model.cpp +2309 -665
  39. package/src/llama.cpp/src/llama-model.h +18 -4
  40. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  41. package/src/llama.cpp/src/llama-vocab.cpp +368 -9
  42. package/src/llama.cpp/src/llama-vocab.h +43 -0
  43. package/src/llama.cpp/src/unicode.cpp +207 -0
  44. package/src/llama.cpp/src/unicode.h +2 -0
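
Most of this release is a sync with upstream llama.cpp; the llama-graph.cpp hunks below add a `can_reuse` check to each graph input and to `llm_graph_result`, so an already-built compute graph can be reused when a new ubatch has compatible shapes. The sketch below only illustrates that pattern and is not code from the package; names such as `graph_params`, `input_pos` and `graph_result` are simplified stand-ins for the `llm_graph_*` types that appear in the diff.

    #include <cstdio>
    #include <memory>
    #include <vector>

    struct graph_params {
        int n_tokens  = 0;   // stand-in for params.ubatch.n_tokens
        int n_outputs = 0;   // stand-in for params.n_outputs
    };

    // each graph input can (re)fill its tensors and report whether it is
    // shape-compatible with a new set of parameters
    struct graph_input {
        virtual ~graph_input() = default;
        virtual void set_input(const graph_params & params) = 0;
        virtual bool can_reuse(const graph_params & params) const = 0;
    };

    struct input_pos : graph_input {
        int n_tokens = 0; // shape captured when the graph was built

        void set_input(const graph_params & params) override { n_tokens = params.n_tokens; }
        bool can_reuse(const graph_params & params) const override {
            return n_tokens == params.n_tokens; // same idea as llm_graph_input_pos::can_reuse
        }
    };

    struct input_out_ids : graph_input {
        int n_outputs = 0;

        void set_input(const graph_params & params) override { n_outputs = params.n_outputs; }
        bool can_reuse(const graph_params & params) const override {
            return n_outputs == params.n_outputs; // same idea as llm_graph_input_out_ids::can_reuse
        }
    };

    struct graph_result {
        std::vector<std::unique_ptr<graph_input>> inputs;

        // reuse the previous graph only if every input is compatible,
        // mirroring how llm_graph_result::can_reuse folds the per-input checks
        bool can_reuse(const graph_params & params) const {
            bool res = true;
            for (const auto & input : inputs) {
                res = res && input->can_reuse(params);
            }
            return res;
        }
    };

    int main() {
        graph_result res;
        res.inputs.push_back(std::make_unique<input_pos>());
        res.inputs.push_back(std::make_unique<input_out_ids>());

        graph_params built;
        built.n_tokens  = 32;
        built.n_outputs = 1;

        for (auto & input : res.inputs) {
            input->set_input(built); // fill inputs for the graph that was just built
        }

        graph_params same  = built; // identical shapes -> graph can be reused
        graph_params grown = built;
        grown.n_tokens = 64;        // batch grew -> graph must be rebuilt

        std::printf("reuse with same shapes:  %d\n", (int) res.can_reuse(same));  // prints 1
        std::printf("reuse with larger batch: %d\n", (int) res.can_reuse(grown)); // prints 0
        return 0;
    }

Reuse is all-or-nothing: one incompatible input (here, a larger token count) forces a rebuild, which is exactly how the aggregated check in `llm_graph_result::can_reuse` behaves in the hunks below.
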
package/src/llama.cpp/src/llama-graph.cpp

@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
+
+    return res;
+}
+
 void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= pos->ne[0] == params.ubatch.n_tokens;
+
+    return res;
+}
+
 void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && attn_scale) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
     int32_t * data = (int32_t *) pos_bucket->data;
 
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= n_outputs == params.n_outputs;
+
+    return res;
+}
+
 void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
+bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= mctx->get_supports_set_rows(); // TODO: tmp
+
+    return res;
+}
+
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
+bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
+    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
+    res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
+
+    return res;
+}
+
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     GGML_ASSERT(cross_kq_mask);
 
@@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
     float * data = (float *) cross_kq_mask->data;
 
@@ -336,29 +403,93 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
-    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
+}
 
-    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+//
+// llm_graph_result
+//
 
-    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
+    reset();
 
-    if (s_copy) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
-        int32_t * data = (int32_t *) s_copy->data;
+    const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
+    debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
+}
 
-        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mctx->get_recr()->s_copy(i);
+int64_t llm_graph_result::get_max_nodes() const {
+    return max_nodes;
+}
+
+void llm_graph_result::reset() {
+    t_tokens = nullptr;
+    t_logits = nullptr;
+    t_embd = nullptr;
+    t_embd_pooled = nullptr;
+
+    params = {};
+
+    inputs.clear();
+
+    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    ggml_init_params params = {
+        /*.mem_size =*/ buf_compute_meta.size(),
+        /*.mem_buffer =*/ buf_compute_meta.data(),
+        /*.no_alloc =*/ true,
+    };
+
+    ctx_compute.reset(ggml_init(params));
+
+    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
+}
+
+void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
+    for (auto & input : inputs) {
+        input->set_input(ubatch);
+    }
+}
+
+bool llm_graph_result::can_reuse(const llm_graph_params & params) {
+    if (!this->params.allow_reuse(params)) {
+        if (debug > 1) {
+            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
         }
+
+        return false;
     }
+
+    if (debug > 1) {
+        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
+    }
+
+    bool res = true;
+
+    for (auto & input : inputs) {
+        const bool cur = input->can_reuse(params);
+
+        if (debug > 1) {
+            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
+        }
+
+        res = res && cur;
+    }
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
+    }
+
+    return res;
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-    GGML_ASSERT(one && ggml_nelements(one) == 1);
-    float f_one = 1.0f;
-    ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
+    inputs.emplace_back(std::move(input));
+    return inputs.back().get();
+}
+
+void llm_graph_result::set_params(const llm_graph_params & params) {
+    this->params = params;
 }
 
 //
@@ -395,7 +526,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_ctx_orig (cparams.n_ctx_orig_yarn),
     pooling_type (cparams.pooling_type),
     rope_type (hparams.rope_type),
-    ctx0 (params.ctx),
     sched (params.sched),
     backend_cpu (params.backend_cpu),
     cvec (params.cvec),
@@ -403,7 +533,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     mctx (params.mctx),
     cross (params.cross),
     cb_func (params.cb),
-    res (std::make_unique<llm_graph_result>()) {
+    res (params.res),
+    ctx0 (res->get_ctx()),
+    gf (res->get_gf()) {
+    res->set_params(params);
 }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -774,20 +907,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(cur, "ffn_moe_weighted", il);
     }
 
+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
+    }
+
     // aggregate experts
-    ggml_tensor * moe_out = nullptr;
-    for (int i = 0; i < n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
+    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+    //       to avoid potentially a large number of add nodes during warmup
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
+    ggml_tensor * moe_out = cur_experts[0];
 
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx0, moe_out, cur_expert);
-        }
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
     }
 
-    if (n_expert_used == 1) {
+    if (hparams.n_expert_used == 1) {
         // avoid returning a non-contiguous tensor
         moe_out = ggml_cont(ctx0, moe_out);
     }
@@ -992,37 +1133,7 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
-llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
-
-    {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
-
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
-
-        inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
-        inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
-
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
-
-        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
-        ggml_set_input(inp->s_copy);
-    }
-
-    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn_mha(
-        ggml_cgraph * gf,
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -1032,13 +1143,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
        float kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    // split the batch into streams if needed
+    const auto n_stream = k->ne[3];
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_tokens = q->ne[1];
-    const auto n_head = q->ne[2];
-    const auto n_kv = k->ne[1];
+    const auto n_kv = k->ne[1];
 
     ggml_tensor * cur;
 
@@ -1080,7 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1125,7 +1239,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1152,7 +1267,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
        ggml_tensor * q_cur,
@@ -1172,11 +1286,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(!ubatch.equal_seqs());
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1194,31 +1312,44 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+        ggml_context * ctx0,
+        const llama_ubatch & ubatch,
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_unified_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1234,7 +1365,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+    const auto * mctx_cur = inp->mctx;
 
     // store to KV cache
     {
@@ -1251,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1271,7 +1402,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1293,7 +1423,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, v_cur);
     }
 
-    const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+    const auto * mctx_iswa = inp->mctx;
 
     const bool is_swa = hparams.is_swa(il);
 
@@ -1318,7 +1448,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1351,7 +1481,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1373,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1391,71 +1520,23 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_mem_hybrid * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
-
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
-
-    // store to KV cache
-    {
-        const auto & k_idxs = inp->get_k_idxs();
-        const auto & v_idxs = inp->get_v_idxs();
-
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
-    }
-
-    const auto & kq_mask = inp->get_kq_mask();
-
-    ggml_tensor * q = q_cur;
-    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
-    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
-    cb(cur, "kqv_out", il);
-
-    if (wo) {
-        cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
-    return cur;
-}
-
+// TODO: maybe separate the inner implementation into a separate function
+//       like with the non-sliding window equivalent
+//       once sliding-window hybrid caches are a thing.
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1469,7 +1550,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1479,7 +1560,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         int32_t state_size,
@@ -1513,8 +1593,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+        ggml_context * ctx0,
+        const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
@@ -1523,38 +1604,32 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
-    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+    return inp;
 }
 
-ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        int32_t state_size,
-        int32_t n_seqs,
-        const llm_graph_get_rows_fn & get_state_rows) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_mem_hybrid * inp,
-        ggml_cgraph * gf,
+        llm_graph_input_rs * inp,
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
         const llm_graph_get_rows_fn & get_state_rows) const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+    const auto * kv_state = inp->mctx;
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         const llama_ubatch & ubatch,
-            int il) const {
+        int il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
@@ -1564,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
-            inp, gf, token_shift_all,
+            inp, token_shift_all,
             hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1592,8 +1667,18 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
-        ggml_cgraph * gf,
         ggml_tensor * cls,
         ggml_tensor * cls_b,
         ggml_tensor * cls_out,