@fugood/llama.node 1.0.2 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  4. package/src/llama.cpp/common/arg.cpp +44 -0
  5. package/src/llama.cpp/common/common.cpp +22 -6
  6. package/src/llama.cpp/common/common.h +15 -1
  7. package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
  8. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  9. package/src/llama.cpp/ggml/include/ggml.h +104 -10
  10. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  11. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  12. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
  19. package/src/llama.cpp/include/llama.h +13 -47
  20. package/src/llama.cpp/src/llama-arch.cpp +298 -3
  21. package/src/llama.cpp/src/llama-arch.h +22 -1
  22. package/src/llama.cpp/src/llama-batch.cpp +103 -71
  23. package/src/llama.cpp/src/llama-batch.h +31 -18
  24. package/src/llama.cpp/src/llama-chat.cpp +59 -1
  25. package/src/llama.cpp/src/llama-chat.h +3 -0
  26. package/src/llama.cpp/src/llama-context.cpp +134 -95
  27. package/src/llama.cpp/src/llama-context.h +13 -16
  28. package/src/llama.cpp/src/llama-cparams.h +3 -2
  29. package/src/llama.cpp/src/llama-graph.cpp +279 -180
  30. package/src/llama.cpp/src/llama-graph.h +183 -122
  31. package/src/llama.cpp/src/llama-hparams.cpp +47 -1
  32. package/src/llama.cpp/src/llama-hparams.h +12 -1
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  34. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  35. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  36. package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  37. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  38. package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  39. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  40. package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
  41. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  42. package/src/llama.cpp/src/llama-memory.h +3 -0
  43. package/src/llama.cpp/src/llama-model.cpp +3373 -743
  44. package/src/llama.cpp/src/llama-model.h +20 -4
  45. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  46. package/src/llama.cpp/src/llama-vocab.cpp +376 -10
  47. package/src/llama.cpp/src/llama-vocab.h +43 -0
  48. package/src/llama.cpp/src/unicode.cpp +207 -0
  49. package/src/llama.cpp/src/unicode.h +2 -0
  50. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
package/src/llama.cpp/src/llama-context.cpp

@@ -98,10 +98,20 @@ llama_context::llama_context(
  LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
  cparams.n_batch = GGML_KQ_MASK_PAD;
  }
-
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
  cparams.op_offload = params.op_offload;
+ cparams.kv_unified = params.kv_unified;
+
+ {
+ const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+ const bool supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
+
+ if (!supports_set_rows && !cparams.kv_unified) {
+ LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
+ cparams.kv_unified = true;
+ }
+ }
 
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -112,6 +122,7 @@ llama_context::llama_context(
  LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
  LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
  LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+ LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
  LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
  LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
@@ -227,8 +238,8 @@ llama_context::llama_context(
 
  LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
- // buffer used to store the computation graph and the tensor meta data
- buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+ gf_res_prev.reset(new llm_graph_result(max_nodes));
+ gf_res_reserve.reset(new llm_graph_result(max_nodes));
 
  // TODO: move these checks to ggml_backend_sched
  // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -267,7 +278,7 @@ llama_context::llama_context(
 
  // reserve worst-case graph
  if (!hparams.vocab_only && memory) {
- const uint32_t n_seqs = cparams.n_seq_max;
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
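The new kv_unified flag propagates from llama_context_params into cparams, and the LLAMA_SET_ROWS environment variable gates the non-unified (per-sequence-stream) KV cache path. A minimal sketch of how a consumer of the bundled C API might opt in; llama_context_default_params() and llama_init_from_model() are taken from the public llama.h, the concrete values are illustrative:

    // sketch: opting into the unified KV cache from the C API (assumes this
    // package's bundled llama.h; adjust if the entry points differ)
    #include "llama.h"

    static llama_context * make_ctx(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx      = 8192;
        cparams.n_seq_max  = 4;
        cparams.kv_unified = true; // new field in this release

        // with kv_unified = false, the non-unified path additionally requires
        // LLAMA_SET_ROWS=1 in the environment, otherwise it is forced back to
        // unified (see the warning added in the hunk above)
        return llama_init_from_model(model, cparams);
    }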
@@ -300,7 +311,7 @@ llama_context::llama_context(
 
  // reserve with tg graph to get the number of splits and nodes
  {
- auto * gf = graph_reserve(1, 1, 1, mctx.get());
+ auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
  if (!gf) {
  throw std::runtime_error("failed to allocate compute tg buffers");
  }
@@ -311,6 +322,10 @@ llama_context::llama_context(
 
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
  {
+ // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+ //
+ // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+ //
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
  if (!gf) {
  throw std::runtime_error("failed to allocate compute pp buffers");
@@ -388,10 +403,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
  return sched.get();
  }
 
- ggml_context * llama_context::get_ctx_compute() const {
- return ctx_compute.get();
- }
-
  uint32_t llama_context::n_ctx() const {
  return cparams.n_ctx;
  }
@@ -463,6 +474,11 @@ bool llama_context::kv_self_update(bool optimize) {
  }
  }
 
+ // reset the previous graph result to make sure that it won't be reused
+ // TODO: change the mctx->apply() to return information if a graph reserve is needed
+ // reset the graph result only if the memory module did reset the scheduler
+ gf_res_prev->reset();
+
  if (!mctx->apply()) {
  LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
  }
@@ -475,7 +491,7 @@ bool llama_context::kv_self_update(bool optimize) {
  throw std::runtime_error("failed to initialize memory context");
  }
 
- const uint32_t n_seqs = cparams.n_seq_max;
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -678,38 +694,59 @@ bool llama_context::apply_adapter_cvec(
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
  }
 
- llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
  if (mctx && !mctx->apply()) {
  LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
  ret = GGML_STATUS_FAILED;
  return nullptr;
  }
 
- auto * gf = graph_init();
- if (!gf) {
- LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
- ret = GGML_STATUS_FAILED;
- return nullptr;
- }
+ auto * res = gf_res_prev.get();
+ auto * gf = res->get_gf();
 
- auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
- if (!res) {
- LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
- ret = GGML_STATUS_FAILED;
- return nullptr;
- }
+ // the new graph parameters
+ // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
+ const auto gparams = graph_params(res, ubatch, mctx, gtype);
 
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+ if (res->can_reuse(gparams)) {
+ //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
 
- if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
- LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
- ret = GGML_STATUS_ALLOC_FAILED;
- return nullptr;
+ n_reused++;
+ } else {
+ res->reset();
+
+ ggml_backend_sched_reset(sched.get());
+ ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+ //const auto t_start_us = ggml_time_us();
+
+ gf = model.build_graph(gparams);
+
+ //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+ if (!gf) {
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+ ret = GGML_STATUS_FAILED;
+ return nullptr;
+ }
+
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+ ret = GGML_STATUS_ALLOC_FAILED;
+ return nullptr;
+ }
  }
 
- res->set_inputs(&ubatch);
+ // set the input data for the input tensors
+ {
+ //const auto t_start_us = ggml_time_us();
+
+ res->set_inputs(&ubatch);
+
+ //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+ }
 
- const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+ const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
  if (status != GGML_STATUS_SUCCESS) {
  LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
  ret = status;
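The rewritten process_ubatch() keeps the previous llm_graph_result alive and rebuilds/re-allocates only when the parameters that fix the graph topology change; otherwise it just refreshes the input tensors and recomputes. A condensed, self-contained sketch of the same caching pattern (GraphParams/GraphResult are simplified stand-ins, not the actual llama.cpp types):

    // simplified stand-ins showing the reuse pattern introduced above:
    // rebuild only when the topology-determining parameters change,
    // otherwise bump the reuse counter and go straight to compute
    struct GraphParams {
        int  n_tokens;
        int  n_seqs;
        bool causal;
        bool operator==(const GraphParams & other) const {
            return n_tokens == other.n_tokens && n_seqs == other.n_seqs && causal == other.causal;
        }
    };

    struct GraphResult {
        GraphParams params{};
        bool        built = false;
        bool can_reuse(const GraphParams & p) const { return built && params == p; }
        void rebuild  (const GraphParams & p)       { params = p; built = true; /* build + alloc here */ }
    };

    static int n_reused = 0;

    static void process(GraphResult & prev, const GraphParams & p) {
        if (prev.can_reuse(p)) {
            n_reused++;       // cheap path: topology unchanged, only inputs are refreshed
        } else {
            prev.rebuild(p);  // expensive path: reset scheduler, rebuild and re-allocate
        }
        // set inputs and compute on either path
    }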
@@ -731,16 +768,19 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
  const auto & hparams = model.hparams;
 
- const int64_t n_embd = hparams.n_embd;
+ const int64_t n_embd = hparams.n_embd;
+ const int32_t n_vocab = model.vocab.n_tokens();
 
  // note: during encode, we always pass the full sequence starting from pos = 0
- if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
  return -1;
  }
 
  const uint32_t n_tokens = balloc->get_n_tokens();
 
+ // [TAG_NO_CACHE_PAD]
+ // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
  const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -767,9 +807,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
  n_outputs = n_tokens;
 
- ggml_backend_sched_reset(sched.get());
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
  const auto causal_attn_org = cparams.causal_attn;
 
  // always use non-causal attention for encoder graphs
@@ -778,7 +815,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
  cparams.causal_attn = false;
 
  ggml_status status;
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
  cparams.causal_attn = causal_attn_org;
 
@@ -791,10 +828,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
  }
  }
 
+ auto * t_logits = res->get_logits();
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
+ // extract logits
+ if (logits && t_logits) {
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+ GGML_ASSERT(backend_res != nullptr);
+ GGML_ASSERT(logits != nullptr);
+
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+ }
+
  // extract embeddings
- if (t_embd) {
+ if (embd && t_embd) {
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
  GGML_ASSERT(backend_embd != nullptr);
 
@@ -844,10 +891,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
  }
  }
 
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
- // overlap with device computation.
- ggml_backend_sched_reset(sched.get());
-
  // TODO: hacky solution
  if (model.arch == LLM_ARCH_T5 && t_embd) {
  //cross.t_embd = t_embd;
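With this change encode() can copy logits as well as embeddings out of the graph result. On the caller side that surfaces through the existing C API; a hedged sketch, assuming a model whose encoder actually produces logits (whether either pointer is meaningful depends on the model and the context settings):

    // sketch: after llama_encode(), embeddings and (for some models) logits
    // can be read back through the usual accessors from llama.h
    #include "llama.h"

    static int run_encoder(llama_context * ctx, llama_batch batch) {
        if (llama_encode(ctx, batch) != 0) {
            return -1;
        }
        const float * embd   = llama_get_embeddings(ctx); // may be NULL
        const float * logits = llama_get_logits(ctx);     // may be NULL / unused for pure encoders
        (void) embd;
        (void) logits;
        return 0;
    }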
@@ -899,7 +942,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
  // when computing embeddings, all tokens are output
  const bool output_all = cparams.embeddings;
 
- if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
  return -1;
  }
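Batch initialization in encode(), decode() and opt_epoch_iter() now receives the effective sequence-id limit (LLAMA_MAX_SEQ for a unified KV cache, n_seq_max otherwise). From the application side nothing changes in how batches are built; a hedged sketch of a two-sequence batch using the public llama_batch helpers, with hypothetical token values:

    // sketch: two sequences in one batch; seq_id values must stay below the
    // context's n_seq_max (or LLAMA_MAX_SEQ when the KV cache is unified)
    #include "llama.h"

    static void decode_two_seqs(llama_context * ctx, llama_token tok_a, llama_token tok_b) {
        llama_batch batch = llama_batch_init(/*n_tokens*/ 8, /*embd*/ 0, /*n_seq_max*/ 2);

        auto push = [&](llama_token tok, llama_pos pos, llama_seq_id seq, bool want_logits) {
            const int i = batch.n_tokens++;
            batch.token   [i]    = tok;
            batch.pos     [i]    = pos;
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = seq;
            batch.logits  [i]    = want_logits;
        };

        push(tok_a, /*pos*/ 0, /*seq*/ 0, /*logits*/ false);
        push(tok_b, /*pos*/ 0, /*seq*/ 1, /*logits*/ true);

        llama_decode(ctx, batch);

        llama_batch_free(batch);
    }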
@@ -1005,11 +1048,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
  n_outputs = n_outputs_new;
  }
 
- ggml_backend_sched_reset(sched.get());
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
  ggml_status status;
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
  if (!res) {
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1190,10 +1230,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
  // wait for the computation to finish (automatically done when obtaining the model output)
  //synchronize();
 
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
- // overlap with device computation.
- ggml_backend_sched_reset(sched.get());
-
  return 0;
  }
 
@@ -1275,20 +1311,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
  // graph
  //
 
- int32_t llama_context::graph_max_nodes() const {
- return std::max<int32_t>(65536, 5*model.n_tensors());
+ uint32_t llama_context::graph_max_nodes() const {
+ return std::max<uint32_t>(1024u, 8u*model.n_tensors());
  }
 
- ggml_cgraph * llama_context::graph_init() {
- ggml_init_params params = {
- /*.mem_size =*/ buf_compute_meta.size(),
- /*.mem_buffer =*/ buf_compute_meta.data(),
- /*.no_alloc =*/ true,
- };
-
- ctx_compute.reset(ggml_init(params));
-
- return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+ llm_graph_result * llama_context::get_gf_res_reserve() const {
+ return static_cast<llm_graph_result *>(gf_res_reserve.get());
  }
 
  ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1329,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
  }
 
+ ggml_backend_sched_reset(sched.get());
+
+ // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
+ gf_res_prev->reset();
+
  // store the n_outputs as it is, and restore it afterwards
  // TODO: not sure if needed, might simplify in the future by removing this
  const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1343,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
  llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
  llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
- auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+ auto * res = gf_res_reserve.get();
 
- this->n_outputs = save_n_outputs;
+ const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
 
- if (!res) {
- LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
- return nullptr;
- }
+ res->reset();
 
- ggml_backend_sched_reset(sched.get());
+ auto * gf = model.build_graph(gparams);
+
+ this->n_outputs = save_n_outputs;
 
  // initialize scheduler with the specified graph
  if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1362,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
  return gf;
  }
 
- llm_graph_result_ptr llama_context::graph_build(
- ggml_context * ctx,
- ggml_cgraph * gf,
- const llama_ubatch & ubatch,
- llm_graph_type gtype,
- const llama_memory_context_i * mctx) {
- return model.build_graph(
- {
- /*.ctx =*/ ctx,
- /*.arch =*/ model.arch,
- /*.hparams =*/ model.hparams,
- /*.cparams =*/ cparams,
- /*.ubatch =*/ ubatch,
- /*.sched =*/ sched.get(),
- /*.backend_cpu =*/ backend_cpu,
- /*.cvec =*/ &cvec,
- /*.loras =*/ &loras,
- /*.mctx =*/ mctx,
- /*.cross =*/ &cross,
- /*.n_outputs =*/ n_outputs,
- /*.cb =*/ graph_get_cb(),
- }, gf, gtype);
+ llm_graph_params llama_context::graph_params(
+ llm_graph_result * res,
+ const llama_ubatch & ubatch,
+ const llama_memory_context_i * mctx,
+ llm_graph_type gtype) const {
+ return {
+ /*.arch =*/ model.arch,
+ /*.hparams =*/ model.hparams,
+ /*.cparams =*/ cparams,
+ /*.ubatch =*/ ubatch,
+ /*.gtype =*/ gtype,
+ /*.sched =*/ sched.get(),
+ /*.backend_cpu =*/ backend_cpu,
+ /*.cvec =*/ &cvec,
+ /*.loras =*/ &loras,
+ /*.mctx =*/ mctx,
+ /*.cross =*/ &cross,
+ /*.n_outputs =*/ n_outputs,
+ /*.cb =*/ graph_get_cb(),
+ /*.res =*/ res,
+ };
  }
 
  ggml_status llama_context::graph_compute(
@@ -1930,6 +1960,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
  data.t_eval_ms = 1e-3 * t_eval_us;
  data.n_p_eval = std::max(1, n_p_eval);
  data.n_eval = std::max(1, n_eval);
+ data.n_reused = std::max(0, n_reused);
 
  return data;
  }
@@ -1938,6 +1969,7 @@ void llama_context::perf_reset() {
  t_start_us = ggml_time_us();
  t_eval_us = n_eval = 0;
  t_p_eval_us = n_p_eval = 0;
+ n_reused = 0;
  }
 
  //
@@ -2028,7 +2060,7 @@ void llama_context::opt_epoch_iter(
  batch.logits [pos_batch] = true;
  }
 
- if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
  return;
  }
@@ -2064,8 +2096,13 @@ void llama_context::opt_epoch_iter(
  break;
  }
 
- auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+ auto * res = gf_res_prev.get();
+
+ const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+ res->reset();
+
+ auto * gf = model.build_graph(gparams);
 
  struct ggml_context * ctx_compute_opt;
  {
@@ -2187,6 +2224,7 @@ llama_context_params llama_context_default_params() {
  /*.no_perf =*/ true,
  /*.op_offload =*/ true,
  /*.swa_full =*/ true,
+ /*.kv_unified =*/ false,
  };
 
  return result;
@@ -2807,6 +2845,7 @@ void llama_perf_context_print(const llama_context * ctx) {
  LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
  __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
  LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+ LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
  }
 
  void llama_perf_context_reset(llama_context * ctx) {
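The new counter is also exposed to applications through llama_perf_context_data, so graph reuse can be monitored without parsing logs. A small hedged sketch using the perf API from the public llama.h (the n_reused field is the one added in this release):

    // sketch: reading the graph-reuse counter via the perf API
    #include <cstdio>
    #include "llama.h"

    static void log_graph_reuse(const llama_context * ctx) {
        const llama_perf_context_data d = llama_perf_context(ctx);
        fprintf(stderr, "eval calls: %d, graphs reused: %d\n", d.n_eval, d.n_reused);
    }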
package/src/llama.cpp/src/llama-context.h

@@ -35,8 +35,6 @@ struct llama_context {
 
  ggml_backend_sched_t get_sched() const;
 
- ggml_context * get_ctx_compute() const;
-
  uint32_t n_ctx() const;
  uint32_t n_ctx_per_seq() const;
  uint32_t n_batch() const;
@@ -96,7 +94,7 @@ struct llama_context {
  // if memory_context is provided, it will be applied first to the context's memory
  // ret contains the status of the graph computation
  // returns nullptr only if ret != GGML_STATUS_SUCCESS
- llm_graph_result_ptr process_ubatch(
+ llm_graph_result * process_ubatch(
  const llama_ubatch & ubatch,
  llm_graph_type gtype,
  llama_memory_context_i * mctx,
@@ -188,10 +186,10 @@ private:
  //
 
  public:
- int32_t graph_max_nodes() const;
+ uint32_t graph_max_nodes() const;
 
- // zero-out inputs and create the ctx_compute for the compute graph
- ggml_cgraph * graph_init();
+ // can reuse the llm_graph_result instance of the context (for example to update a memory module)
+ llm_graph_result * get_gf_res_reserve() const;
 
  // returns the result of ggml_backend_sched_graph_compute_async execution
  ggml_status graph_compute(ggml_cgraph * gf, bool batched);
@@ -200,12 +198,11 @@ public:
  ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
  private:
- llm_graph_result_ptr graph_build(
- ggml_context * ctx,
- ggml_cgraph * gf,
- const llama_ubatch & ubatch,
- llm_graph_type gtype,
- const llama_memory_context_i * mctx);
+ llm_graph_params graph_params(
+ llm_graph_result * res,
+ const llama_ubatch & ubatch,
+ const llama_memory_context_i * mctx,
+ llm_graph_type gtype) const;
 
  llm_graph_cb graph_get_cb() const;
 
@@ -258,8 +255,6 @@ private:
  ggml_backend_t backend_cpu = nullptr;
  std::vector<ggml_backend_ptr> backends;
 
- ggml_context_ptr ctx_compute;
-
  // training
  ggml_opt_context_t opt_ctx = nullptr;
 
@@ -275,8 +270,8 @@ private:
  std::vector<ggml_backend_t> backend_ptrs;
  std::vector<ggml_backend_buffer_type_t> backend_buft;
 
- // memory buffers used to evaluate the model
- std::vector<uint8_t> buf_compute_meta;
+ llm_graph_result_ptr gf_res_prev;
+ llm_graph_result_ptr gf_res_reserve;
 
  // host buffer for the model output (logits and embeddings)
  ggml_backend_buffer_ptr buf_output;
@@ -294,4 +289,6 @@ private:
 
  mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
  mutable int32_t n_eval = 0; // number of eval calls
+
+ mutable int32_t n_reused = 0; // number of times the previous graph was reused
  };
package/src/llama.cpp/src/llama-cparams.h

@@ -11,8 +11,8 @@ struct llama_cparams {
  uint32_t n_batch;
  uint32_t n_ubatch;
  uint32_t n_seq_max;
- int n_threads; // number of threads to use for generation
- int n_threads_batch; // number of threads to use for batch processing
+ int32_t n_threads; // number of threads to use for generation
+ int32_t n_threads_batch; // number of threads to use for batch processing
 
  float rope_freq_base;
  float rope_freq_scale;
@@ -33,6 +33,7 @@ struct llama_cparams {
  bool no_perf;
  bool warmup;
  bool op_offload;
+ bool kv_unified;
 
  enum llama_pooling_type pooling_type;