@fugood/llama.node 1.0.2 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  4. package/src/llama.cpp/common/arg.cpp +44 -0
  5. package/src/llama.cpp/common/common.cpp +22 -6
  6. package/src/llama.cpp/common/common.h +15 -1
  7. package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
  8. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  9. package/src/llama.cpp/ggml/include/ggml.h +104 -10
  10. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  11. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  12. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
  19. package/src/llama.cpp/include/llama.h +13 -47
  20. package/src/llama.cpp/src/llama-arch.cpp +298 -3
  21. package/src/llama.cpp/src/llama-arch.h +22 -1
  22. package/src/llama.cpp/src/llama-batch.cpp +103 -71
  23. package/src/llama.cpp/src/llama-batch.h +31 -18
  24. package/src/llama.cpp/src/llama-chat.cpp +59 -1
  25. package/src/llama.cpp/src/llama-chat.h +3 -0
  26. package/src/llama.cpp/src/llama-context.cpp +134 -95
  27. package/src/llama.cpp/src/llama-context.h +13 -16
  28. package/src/llama.cpp/src/llama-cparams.h +3 -2
  29. package/src/llama.cpp/src/llama-graph.cpp +279 -180
  30. package/src/llama.cpp/src/llama-graph.h +183 -122
  31. package/src/llama.cpp/src/llama-hparams.cpp +47 -1
  32. package/src/llama.cpp/src/llama-hparams.h +12 -1
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  34. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  35. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  36. package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  37. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  38. package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  39. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  40. package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
  41. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  42. package/src/llama.cpp/src/llama-memory.h +3 -0
  43. package/src/llama.cpp/src/llama-model.cpp +3373 -743
  44. package/src/llama.cpp/src/llama-model.h +20 -4
  45. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  46. package/src/llama.cpp/src/llama-vocab.cpp +376 -10
  47. package/src/llama.cpp/src/llama-vocab.h +43 -0
  48. package/src/llama.cpp/src/unicode.cpp +207 -0
  49. package/src/llama.cpp/src/unicode.h +2 -0
  50. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
package/src/llama.cpp/src/llama-graph.cpp
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
      }
  }

+ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
+     bool res = true;
+
+     res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
+     res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
+
+     return res;
+ }
+
  void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
      if (ubatch->pos && pos) {
          const int64_t n_tokens = ubatch->n_tokens;
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
      }
  }

+ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
+     bool res = true;
+
+     res &= pos->ne[0] == params.ubatch.n_tokens;
+
+     return res;
+ }
+
  void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
      if (ubatch->pos && attn_scale) {
          const int64_t n_tokens = ubatch->n_tokens;
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
          const int64_t n_tokens = ubatch->n_tokens;

          GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-         GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+         GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

          int32_t * data = (int32_t *) pos_bucket->data;

@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
      }
  }

+ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
+     bool res = true;
+
+     res &= n_outputs == params.n_outputs;
+
+     return res;
+ }
+
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
      if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
          const int64_t n_tokens = ubatch->n_tokens;
@@ -281,19 +306,64 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  }

  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-     if (self_kq_mask) {
-         mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-     }
+     mctx->set_input_k_idxs(self_k_idxs, ubatch);
+     mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ }
+
+ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
+     const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+
+     this->mctx = mctx;
+
+     bool res = true;
+
+     res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+     //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+     res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+     res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+     res &= mctx->get_supports_set_rows(); // TODO: tmp
+
+     return res;
  }

  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-     if (self_kq_mask) {
-         mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-     }
+     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);

-     if (self_kq_mask_swa) {
-         mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
-     }
+     mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+     mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+     mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+ }
+
+ bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
+     const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+
+     this->mctx = mctx;
+
+     bool res = true;
+
+     res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+     //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+     res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+     //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+     res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
+     res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+     res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
+     res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+     res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
+
+     return res;
  }

  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -303,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
          const int64_t n_tokens = ubatch->n_tokens;

          GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-         GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+         GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

          float * data = (float *) cross_kq_mask->data;

@@ -333,27 +403,93 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  }

  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-     if (self_kq_mask) {
-         mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+     inp_attn->set_input(ubatch);
+     inp_rs->set_input(ubatch);
+ }
+
+ //
+ // llm_graph_result
+ //
+
+ llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
+     reset();
+
+     const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
+     debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
+ }
+
+ int64_t llm_graph_result::get_max_nodes() const {
+     return max_nodes;
+ }
+
+ void llm_graph_result::reset() {
+     t_tokens = nullptr;
+     t_logits = nullptr;
+     t_embd = nullptr;
+     t_embd_pooled = nullptr;
+
+     params = {};
+
+     inputs.clear();
+
+     buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+     ggml_init_params params = {
+         /*.mem_size =*/ buf_compute_meta.size(),
+         /*.mem_buffer =*/ buf_compute_meta.data(),
+         /*.no_alloc =*/ true,
+     };
+
+     ctx_compute.reset(ggml_init(params));
+
+     gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
+ }
+
+ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
+     for (auto & input : inputs) {
+         input->set_input(ubatch);
+     }
+ }
+
+ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
+     if (!this->params.allow_reuse(params)) {
+         if (debug > 1) {
+             LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
+         }
+
+         return false;
      }

-     const int64_t n_rs = mctx->get_recr()->get_n_rs();
+     if (debug > 1) {
+         LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
+     }

-     if (s_copy) {
-         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
-         int32_t * data = (int32_t *) s_copy->data;
+     bool res = true;

-         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-         for (uint32_t i = 0; i < n_rs; ++i) {
-             data[i] = mctx->get_recr()->s_copy(i);
+     for (auto & input : inputs) {
+         const bool cur = input->can_reuse(params);
+
+         if (debug > 1) {
+             LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
          }
+
+         res = res && cur;
      }
+
+     if (debug > 0) {
+         LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
+     }
+
+     return res;
+ }
+
+ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
+     inputs.emplace_back(std::move(input));
+     return inputs.back().get();
  }

- void llm_graph_input_one::set_input(const llama_ubatch *) {
-     GGML_ASSERT(one && ggml_nelements(one) == 1);
-     float f_one = 1.0f;
-     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+ void llm_graph_result::set_params(const llm_graph_params & params) {
+     this->params = params;
  }

  //
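
The llm_graph_result changes above let a built graph be cached and reused when only the input data changes: each graph input reports whether its preallocated tensors still fit the new micro-batch, and llm_graph_result::can_reuse() ANDs those answers together after a cheap parameter check. The following self-contained sketch illustrates that flow; graph_params, graph_input and graph_result are hypothetical stand-ins, not the llama.cpp types.

    // Hypothetical sketch of the reuse flow enabled by llm_graph_result::can_reuse():
    // every input checks its allocated shapes, the result folds the answers together,
    // and the caller rebuilds the graph only when the combined check fails.
    #include <cstdio>
    #include <memory>
    #include <vector>

    struct graph_params { int n_tokens; int n_outputs; };

    struct graph_input {
        int n_tokens_alloc;                              // batch width the tensors were created for
        bool can_reuse(const graph_params & p) const {
            return n_tokens_alloc == p.n_tokens;
        }
    };

    struct graph_result {
        graph_params params{};
        std::vector<graph_input> inputs;

        bool can_reuse(const graph_params & p) const {
            bool res = params.n_outputs == p.n_outputs;  // cheap parameter check first
            for (const auto & inp : inputs) {
                res = res && inp.can_reuse(p);           // then ask every input
            }
            return res;
        }
    };

    int main() {
        auto res_prev = std::make_unique<graph_result>();
        res_prev->params = {512, 32};
        res_prev->inputs = { {512}, {512} };

        // reuse the cached graph and only refresh its input data, otherwise rebuild
        printf("%d\n", res_prev->can_reuse({512, 32})); // 1 -> reuse
        printf("%d\n", res_prev->can_reuse({256, 32})); // 0 -> rebuild
    }
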
@@ -390,7 +526,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
      n_ctx_orig (cparams.n_ctx_orig_yarn),
      pooling_type (cparams.pooling_type),
      rope_type (hparams.rope_type),
-     ctx0 (params.ctx),
      sched (params.sched),
      backend_cpu (params.backend_cpu),
      cvec (params.cvec),
@@ -398,7 +533,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
      mctx (params.mctx),
      cross (params.cross),
      cb_func (params.cb),
-     res (std::make_unique<llm_graph_result>()) {
+     res (params.res),
+     ctx0 (res->get_ctx()),
+     gf (res->get_gf()) {
+     res->set_params(params);
  }

  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -769,20 +907,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
          cb(cur, "ffn_moe_weighted", il);
      }

+     ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+     assert(n_expert_used > 0);
+
+     // order the views before the adds
+     for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+         cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+         ggml_build_forward_expand(gf, cur_experts[i]);
+     }
+
      // aggregate experts
-     ggml_tensor * moe_out = nullptr;
-     for (int i = 0; i < n_expert_used; ++i) {
-         ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                 experts->nb[2], i*experts->nb[1]);
+     // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+     // to avoid potentially a large number of add nodes during warmup
+     // ref: https://github.com/ggml-org/llama.cpp/pull/14753
+     ggml_tensor * moe_out = cur_experts[0];

-         if (i == 0) {
-             moe_out = cur_expert;
-         } else {
-             moe_out = ggml_add(ctx0, moe_out, cur_expert);
-         }
+     for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+         moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
      }

-     if (n_expert_used == 1) {
+     if (hparams.n_expert_used == 1) {
          // avoid returning a non-contiguous tensor
          moe_out = ggml_cont(ctx0, moe_out);
      }
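
The MoE aggregation above now loops over the fixed hparams.n_expert_used rather than the runtime n_expert_used; per the in-diff comment and the referenced PR, this keeps the number of ggml add nodes, and therefore the graph topology, identical between warmup and normal decoding, which matters once graphs are cached and reused. The toy sketch below (plain C++, not llama.cpp code) only makes the node-count arithmetic explicit.

    // Illustrative sketch: the number of "add" nodes created by the chained reduction
    // above depends only on the loop bound, so a fixed bound gives a fixed topology.
    #include <cstdio>

    static int count_add_nodes(int n_used) {
        int adds = 0;
        for (int i = 1; i < n_used; ++i) {
            ++adds;                   // moe_out = add(moe_out, cur_experts[i])
        }
        return adds;
    }

    int main() {
        // with a fixed bound the node count is the same during warmup and decoding
        printf("n_used=8 -> %d add nodes\n", count_add_nodes(8)); // 7
        printf("n_used=1 -> %d add nodes\n", count_add_nodes(1)); // 0
    }
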
@@ -987,35 +1133,7 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
      return pos_bias;
  }

- llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
-
-     {
-         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
-
-         const auto n_kv = inp->mctx->get_attn()->get_n_kv();
-
-         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-         //cb(inp->self_kq_mask, "KQ_mask", -1);
-         ggml_set_input(inp->self_kq_mask);
-
-         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-     }
-
-     {
-         const auto n_rs = mctx_cur->get_recr()->get_n_rs();
-
-         inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
-         ggml_set_input(inp->s_copy);
-     }
-
-     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
- }
-
  ggml_tensor * llm_graph_context::build_attn_mha(
-         ggml_cgraph * gf,
          ggml_tensor * q,
          ggml_tensor * k,
          ggml_tensor * v,
@@ -1025,13 +1143,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
          float kq_scale) const {
      const bool v_trans = v->nb[1] > v->nb[2];

+     // split the batch into streams if needed
+     const auto n_stream = k->ne[3];
+
+     q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
+
      q = ggml_permute(ctx0, q, 0, 2, 1, 3);
      k = ggml_permute(ctx0, k, 0, 2, 1, 3);
      v = ggml_permute(ctx0, v, 0, 2, 1, 3);

-     const auto n_tokens = q->ne[1];
-     const auto n_head = q->ne[2];
-     const auto n_kv = k->ne[1];
+     const auto n_kv = k->ne[1];

      ggml_tensor * cur;

@@ -1073,7 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
  #endif
          }

-         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
      } else {
          ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

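
build_attn_mha() now splits the batch into n_stream streams taken from the fourth dimension of K, reshaping Q so its token dimension is divided across the streams before the permutes, and later folds the streams back into a 2-D tensor. The arithmetic below is a worked example with hypothetical sizes; the exact dimension semantics are simplified, but it shows why the reshape requires the token count to divide evenly by n_stream.

    // Worked example (plain arithmetic, hypothetical sizes) of the stream split above:
    // a reshape only reinterprets the data, so n_tokens must be divisible by n_stream.
    #include <cstdio>

    int main() {
        const long n_embd_head = 128;
        const long n_head      = 32;
        const long n_tokens    = 64;   // total tokens in the ubatch
        const long n_stream    = 4;    // one stream per sequence when the KV cache is split

        // q after ggml_reshape_4d(..., ne2/n_stream, n_stream):
        long ne[4] = { n_embd_head, n_head, n_tokens / n_stream, n_stream };
        printf("q split:    [%ld, %ld, %ld, %ld]\n", ne[0], ne[1], ne[2], ne[3]);

        // after attention the streams are recombined into 2-D:
        // [ne0*ne1, ne2*ne3] = [n_embd_head*n_head, n_tokens]
        printf("recombined: [%ld, %ld]\n", ne[0]*ne[1], ne[2]*ne[3]);
    }
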
@@ -1118,7 +1239,8 @@

          cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-         cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+         // recombine streams
+         cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);

          if (!cparams.offload_kqv) {
              // all nodes between the KV store and the attention output are run on the CPU
@@ -1135,8 +1257,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
      auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

      // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-     inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-     //cb(inp_kq_mask, "KQ_mask", -1);
+     inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
      ggml_set_input(inp->kq_mask);

      inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1146,7 +1267,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con

  ggml_tensor * llm_graph_context::build_attn(
          llm_graph_input_attn_no_cache * inp,
-         ggml_cgraph * gf,
          ggml_tensor * wo,
          ggml_tensor * wo_b,
          ggml_tensor * q_cur,
@@ -1166,11 +1286,15 @@ ggml_tensor * llm_graph_context::build_attn(

      const auto & kq_mask = inp->get_kq_mask();

+     // [TAG_NO_CACHE_PAD]
+     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+     assert(!ubatch.equal_seqs());
+
      ggml_tensor * q = q_cur;
      ggml_tensor * k = k_cur;
      ggml_tensor * v = v_cur;

-     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
      cb(cur, "kqv_out", il);

      if (wo) {
@@ -1188,29 +1312,44 @@ ggml_tensor * llm_graph_context::build_attn(
      return cur;
  }

- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-     const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+         ggml_context * ctx0,
+         const llama_ubatch & ubatch,
+         const llama_hparams & hparams,
+         const llama_cparams & cparams,
+         const llama_kv_cache_unified_context * mctx_cur) {

      auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);

      {
          GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

-         const auto n_kv = mctx_cur->get_n_kv();
+         const auto n_kv = mctx_cur->get_n_kv();
+         const auto n_tokens = ubatch.n_tokens;
+         const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
+         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-         //cb(inp->self_kq_mask, "KQ_mask", -1);
+         inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
          ggml_set_input(inp->self_kq_mask);

          inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
      }

+     return inp;
+ }
+
+ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+     const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+     auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
      return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
  }

  ggml_tensor * llm_graph_context::build_attn(
          llm_graph_input_attn_kv_unified * inp,
-         ggml_cgraph * gf,
          ggml_tensor * wo,
          ggml_tensor * wo_b,
          ggml_tensor * q_cur,
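
The KQ masks above are allocated with their token dimension rounded up via GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD). Assuming GGML_PAD rounds its first argument up to a multiple of its (power-of-two) second argument, as the ggml headers define it, the padding can be reproduced locally as below; the pad value used here is only for illustration, the real constant lives in the llama.cpp sources.

    // Local re-implementation of the rounding GGML_PAD performs (assuming a
    // power-of-two pad value, as in ggml).
    #include <cstdint>
    #include <cstdio>

    static int64_t pad_up(int64_t x, int64_t n) {
        return (x + n - 1) & ~(n - 1);   // round x up to the next multiple of n
    }

    int main() {
        const int64_t kq_mask_pad = 64;  // hypothetical pad value for illustration
        printf("%lld\n", (long long) pad_up(1,  kq_mask_pad)); // 64
        printf("%lld\n", (long long) pad_up(64, kq_mask_pad)); // 64
        printf("%lld\n", (long long) pad_up(65, kq_mask_pad)); // 128
    }
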
@@ -1226,12 +1365,15 @@ ggml_tensor * llm_graph_context::build_attn(
      ggml_build_forward_expand(gf, k_cur);
      ggml_build_forward_expand(gf, v_cur);

-     const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+     const auto * mctx_cur = inp->mctx;

      // store to KV cache
      {
-         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+         const auto & k_idxs = inp->get_k_idxs();
+         const auto & v_idxs = inp->get_v_idxs();
+
+         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
      }

      const auto & kq_mask = inp->get_kq_mask();
@@ -1240,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
      ggml_tensor * k = mctx_cur->get_k(ctx0, il);
      ggml_tensor * v = mctx_cur->get_v(ctx0, il);

-     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
      cb(cur, "kqv_out", il);

      if (wo) {
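
With the changes above, K/V rows are written into the cache through per-ubatch index tensors (get_k_idxs()/get_v_idxs()) instead of an implicit contiguous [head, head + n) range, and can_reuse() gates on mctx->get_supports_set_rows(), which suggests a set-rows style scatter underneath. The snippet below is only a conceptual model of such a scatter in plain C++, not the ggml implementation.

    // Conceptual model of writing K/V rows into the cache through an index tensor:
    // row i of the incoming batch lands in cache row idxs[i].
    #include <cstdio>
    #include <vector>

    using row = std::vector<float>;

    static void set_rows(std::vector<row> & cache, const std::vector<row> & src,
                         const std::vector<int> & idxs) {
        for (size_t i = 0; i < src.size(); ++i) {
            cache[idxs[i]] = src[i];     // scatter: destination slots need not be contiguous
        }
    }

    int main() {
        std::vector<row> cache(8, row(4, 0.0f));       // tiny "KV cache" with 8 slots
        std::vector<row> k_cur = { row(4, 1.0f), row(4, 2.0f) };
        std::vector<int> k_idxs = { 5, 2 };            // per-token destination slots

        set_rows(cache, k_cur, k_idxs);
        printf("slot 2 = %.1f, slot 5 = %.1f\n", cache[2][0], cache[5][0]); // 2.0, 1.0
    }
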
@@ -1260,7 +1402,6 @@ ggml_tensor * llm_graph_context::build_attn(

  ggml_tensor * llm_graph_context::build_attn(
          llm_graph_input_attn_kv_unified_iswa * inp,
-         ggml_cgraph * gf,
          ggml_tensor * wo,
          ggml_tensor * wo_b,
          ggml_tensor * q_cur,
@@ -1282,7 +1423,7 @@ ggml_tensor * llm_graph_context::build_attn(
          ggml_build_forward_expand(gf, v_cur);
      }

-     const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+     const auto * mctx_iswa = inp->mctx;

      const bool is_swa = hparams.is_swa(il);

@@ -1290,11 +1431,15 @@ ggml_tensor * llm_graph_context::build_attn(

      // optionally store to KV cache
      if (k_cur) {
-         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+         const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
      }

      if (v_cur) {
-         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+         const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
      }

      const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1303,7 +1448,7 @@ ggml_tensor * llm_graph_context::build_attn(
      ggml_tensor * k = mctx_cur->get_k(ctx0, il);
      ggml_tensor * v = mctx_cur->get_v(ctx0, il);

-     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
      cb(cur, "kqv_out", il);

      if (wo) {
@@ -1326,7 +1471,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

      const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

-     inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+     inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
      ggml_set_input(inp->cross_kq_mask);

      inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1336,7 +1481,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

  ggml_tensor * llm_graph_context::build_attn(
          llm_graph_input_attn_cross * inp,
-         ggml_cgraph * gf,
          ggml_tensor * wo,
          ggml_tensor * wo_b,
          ggml_tensor * q_cur,
@@ -1358,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_attn(
      ggml_tensor * k = k_cur;
      ggml_tensor * v = v_cur;

-     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
      cb(cur, "kqv_out", il);

      if (wo) {
@@ -1376,66 +1520,23 @@ ggml_tensor * llm_graph_context::build_attn(
      return cur;
  }

- ggml_tensor * llm_graph_context::build_attn(
-         llm_graph_input_mem_hybrid * inp,
-         ggml_cgraph * gf,
-         ggml_tensor * wo,
-         ggml_tensor * wo_b,
-         ggml_tensor * q_cur,
-         ggml_tensor * k_cur,
-         ggml_tensor * v_cur,
-         ggml_tensor * kq_b,
-         ggml_tensor * v_mla,
-         float kq_scale,
-         int il) const {
-     // these nodes are added to the graph together so that they are not reordered
-     // by doing so, the number of splits in the graph is reduced
-     ggml_build_forward_expand(gf, q_cur);
-     ggml_build_forward_expand(gf, k_cur);
-     ggml_build_forward_expand(gf, v_cur);
-
-     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
-
-     // store to KV cache
-     {
-         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
-     }
-
-     const auto & kq_mask = inp->get_kq_mask();
-
-     ggml_tensor * q = q_cur;
-     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
-     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
-
-     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
-     cb(cur, "kqv_out", il);
-
-     if (wo) {
-         cur = build_lora_mm(wo, cur);
-         if (arch == LLM_ARCH_GLM4) {
-             // GLM4 seems to have numerical issues with half-precision accumulators
-             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-         }
-     }
-
-     if (wo_b) {
-         cur = ggml_add(ctx0, cur, wo_b);
-     }
-
-     return cur;
- }
-
+ // TODO: maybe separate the inner implementation into a separate function
+ // like with the non-sliding window equivalent
+ // once sliding-window hybrid caches are a thing.
  llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
      const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);

      auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

+     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
      {
          const auto n_kv = mctx_cur->get_base()->get_n_kv();

-         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-         //cb(inp->self_kq_mask, "KQ_mask", -1);
+         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+         inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
          ggml_set_input(inp->self_kq_mask);

          inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1547,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

          const auto n_kv = mctx_cur->get_swa()->get_n_kv();

-         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+         inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
          ggml_set_input(inp->self_kq_mask_swa);

          inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1457,7 +1560,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
  }

  ggml_tensor * llm_graph_context::build_rs(
-         ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
          int32_t state_size,
@@ -1466,7 +1568,7 @@ ggml_tensor * llm_graph_context::build_rs(
          uint32_t kv_head,
          uint32_t kv_size,
          int32_t rs_zero,
-         bool avoid_copies) const {
+         const llm_graph_get_rows_fn & get_state_rows) const {

      ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);

@@ -1475,19 +1577,11 @@ ggml_tensor * llm_graph_context::build_rs(
      ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
      ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

-     ggml_tensor * output_states;
-
-     if (!avoid_copies) {
-         // copy states
-         // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-         // {state_size, kv_size} -> {state_size, n_seqs}
-         output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-         ggml_build_forward_expand(gf, output_states);
-     } else {
-         // FIXME: make the gathering operation happen before the copy below
-         // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-         output_states = states;
-     }
+     // copy states
+     // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+     // {state_size, kv_size} -> {state_size, n_seqs}
+     ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+     ggml_build_forward_expand(gf, output_states);

      // copy extra states which won't be changed further (between n_seqs and n_kv)
      ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
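
build_rs() above replaces the avoid_copies flag with a get_state_rows callback, so callers control how the selected state rows are gathered while the default presumably stays a plain ggml_get_rows. The following self-contained sketch (hypothetical rows_t/get_rows_fn_t types, not the llama.cpp API) shows the same callback pattern with a default gather and a custom lambda.

    // Hypothetical sketch of the callback-based gather that replaces `avoid_copies`:
    // the caller can inject how rows are selected from the state buffer.
    #include <cstdio>
    #include <functional>
    #include <vector>

    using rows_t        = std::vector<std::vector<float>>;
    using get_rows_fn_t = std::function<rows_t(const rows_t & states, const std::vector<int> & ids)>;

    static rows_t default_get_rows(const rows_t & states, const std::vector<int> & ids) {
        rows_t out;
        for (int id : ids) {
            out.push_back(states[id]);              // copy the selected state rows
        }
        return out;
    }

    static rows_t build_rs(const rows_t & states, const std::vector<int> & ids,
                           const get_rows_fn_t & get_state_rows = default_get_rows) {
        return get_state_rows(states, ids);         // gathering strategy is now pluggable
    }

    int main() {
        rows_t states = { {1.0f}, {2.0f}, {3.0f} };

        rows_t a = build_rs(states, {2, 0});        // default gather
        rows_t b = build_rs(states, {2, 0}, [](const rows_t & s, const std::vector<int> & ids) {
            rows_t out = default_get_rows(s, ids);  // a custom hook could transform rows here
            for (auto & r : out) r[0] *= 10.0f;
            return out;
        });

        printf("%.1f %.1f\n", a[0][0], b[0][0]);    // 3.0 30.0
    }
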
@@ -1499,8 +1593,9 @@ ggml_tensor * llm_graph_context::build_rs(
      return output_states;
  }

- llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+         ggml_context * ctx0,
+         const llama_memory_recurrent_context * mctx_cur) {

      auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);

@@ -1509,38 +1604,32 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
      inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
      ggml_set_input(inp->s_copy);

-     return (llm_graph_input_rs *) res->add_input(std::move(inp));
+     return inp;
  }

- ggml_tensor * llm_graph_context::build_rs(
-         llm_graph_input_rs * inp,
-         ggml_cgraph * gf,
-         ggml_tensor * s,
-         int32_t state_size,
-         int32_t n_seqs,
-         bool avoid_copies) const {
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
      const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

-     return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+     auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
+     return (llm_graph_input_rs *) res->add_input(std::move(inp));
  }

  ggml_tensor * llm_graph_context::build_rs(
-         llm_graph_input_mem_hybrid * inp,
-         ggml_cgraph * gf,
+         llm_graph_input_rs * inp,
          ggml_tensor * s,
          int32_t state_size,
          int32_t n_seqs,
-         bool avoid_copies) const {
-     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+         const llm_graph_get_rows_fn & get_state_rows) const {
+     const auto * kv_state = inp->mctx;

-     return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+     return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
  }

  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          llm_graph_input_rs * inp,
-         ggml_cgraph * gf,
          const llama_ubatch & ubatch,
-         int il) const {
+         int il) const {
      const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

      const auto token_shift_count = hparams.token_shift_count;
@@ -1550,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
      ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

      ggml_tensor * token_shift = build_rs(
-             inp, gf, token_shift_all,
+             inp, token_shift_all,
              hparams.n_embd_r(), n_seqs);

      token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1578,8 +1667,18 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
      );
  }

+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+     auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+     auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+ }
+
  void llm_graph_context::build_pooling(
-         ggml_cgraph * gf,
          ggml_tensor * cls,
          ggml_tensor * cls_b,
          ggml_tensor * cls_out,