@fugood/llama.node 1.0.0-beta.5 → 1.0.0-beta.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/lib/binding.ts +3 -1
  2. package/lib/index.js +2 -0
  3. package/lib/index.ts +3 -1
  4. package/package.json +14 -14
  5. package/scripts/llama.cpp.patch +27 -26
  6. package/src/EmbeddingWorker.cpp +1 -1
  7. package/src/LlamaCompletionWorker.cpp +28 -7
  8. package/src/LlamaCompletionWorker.h +4 -0
  9. package/src/LlamaContext.cpp +14 -17
  10. package/src/common.hpp +7 -6
  11. package/src/llama.cpp/CMakeLists.txt +15 -4
  12. package/src/llama.cpp/common/CMakeLists.txt +15 -24
  13. package/src/llama.cpp/common/arg.cpp +172 -110
  14. package/src/llama.cpp/common/chat-parser.cpp +385 -0
  15. package/src/llama.cpp/common/chat-parser.h +120 -0
  16. package/src/llama.cpp/common/chat.cpp +726 -596
  17. package/src/llama.cpp/common/chat.h +74 -8
  18. package/src/llama.cpp/common/common.cpp +56 -38
  19. package/src/llama.cpp/common/common.h +9 -3
  20. package/src/llama.cpp/common/json-partial.cpp +256 -0
  21. package/src/llama.cpp/common/json-partial.h +38 -0
  22. package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  23. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
  24. package/src/llama.cpp/common/sampling.cpp +7 -8
  25. package/src/llama.cpp/common/speculative.cpp +6 -4
  26. package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
  27. package/src/llama.cpp/ggml/include/ggml.h +22 -3
  28. package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
  29. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
  30. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  31. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  32. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  33. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  36. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  37. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  38. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  39. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  40. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  41. package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  42. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  43. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  44. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  45. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  46. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
  47. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
  48. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  49. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  50. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  51. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  52. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  53. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
  54. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  55. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  56. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  57. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
  58. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  59. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
  60. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  61. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  62. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
  63. package/src/llama.cpp/include/llama.h +145 -40
  64. package/src/llama.cpp/src/CMakeLists.txt +5 -1
  65. package/src/llama.cpp/src/llama-arch.cpp +99 -3
  66. package/src/llama.cpp/src/llama-arch.h +10 -1
  67. package/src/llama.cpp/src/llama-batch.cpp +728 -272
  68. package/src/llama.cpp/src/llama-batch.h +112 -54
  69. package/src/llama.cpp/src/llama-chat.cpp +19 -2
  70. package/src/llama.cpp/src/llama-chat.h +1 -0
  71. package/src/llama.cpp/src/llama-context.cpp +525 -339
  72. package/src/llama.cpp/src/llama-context.h +38 -17
  73. package/src/llama.cpp/src/llama-cparams.cpp +4 -0
  74. package/src/llama.cpp/src/llama-cparams.h +2 -0
  75. package/src/llama.cpp/src/llama-grammar.cpp +12 -2
  76. package/src/llama.cpp/src/llama-graph.cpp +413 -353
  77. package/src/llama.cpp/src/llama-graph.h +112 -56
  78. package/src/llama.cpp/src/llama-hparams.cpp +10 -2
  79. package/src/llama.cpp/src/llama-hparams.h +13 -2
  80. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
  81. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
  82. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
  83. package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
  84. package/src/llama.cpp/src/llama-kv-cells.h +415 -0
  85. package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  86. package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
  87. package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
  88. package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
  89. package/src/llama.cpp/src/llama-memory.cpp +41 -0
  90. package/src/llama.cpp/src/llama-memory.h +86 -5
  91. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  92. package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
  93. package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
  94. package/src/llama.cpp/src/llama-model.cpp +1137 -528
  95. package/src/llama.cpp/src/llama-model.h +4 -0
  96. package/src/llama.cpp/src/llama-quant.cpp +2 -1
  97. package/src/llama.cpp/src/llama-sampling.cpp +2 -2
  98. package/src/llama.cpp/src/llama-vocab.cpp +69 -32
  99. package/src/llama.cpp/src/llama-vocab.h +1 -0
  100. package/src/llama.cpp/src/llama.cpp +11 -7
  101. package/src/llama.cpp/src/unicode.cpp +5 -0
  102. package/src/tts_utils.h +1 -1
  103. package/src/llama.cpp/common/json.hpp +0 -24766
  104. package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
  105. package/src/llama.cpp/common/minja/minja.hpp +0 -2974
  106. package/src/llama.cpp/common/stb_image.h +0 -7988
  107. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  108. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
  109. package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
  110. package/src/llama.cpp/src/llama-kv-cache.h +0 -515
  111. /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  112. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  113. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
package/src/llama.cpp/src/llama-context.cpp
@@ -1,14 +1,16 @@
1
1
  #include "llama-context.h"
2
2
 
3
3
  #include "llama-impl.h"
4
+ #include "llama-batch.h"
4
5
  #include "llama-io.h"
6
+ #include "llama-memory.h"
5
7
  #include "llama-mmap.h"
6
8
  #include "llama-model.h"
7
- #include "llama-kv-cache.h"
8
9
 
10
+ #include <cinttypes>
9
11
  #include <cstring>
12
+ #include <limits>
10
13
  #include <stdexcept>
11
- #include <cinttypes>
12
14
 
13
15
  //
14
16
  // llama_context
@@ -17,7 +19,8 @@
17
19
  llama_context::llama_context(
18
20
  const llama_model & model,
19
21
  llama_context_params params) :
20
- model(model) {
22
+ model(model),
23
+ balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
21
24
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
22
25
 
23
26
  t_start_us = model.t_start_us;
@@ -25,7 +28,11 @@ llama_context::llama_context(
25
28
 
26
29
  const auto & hparams = model.hparams;
27
30
 
28
- cparams.n_seq_max = std::max(1u, params.n_seq_max);
31
+ cparams.n_seq_max = std::max(1u, params.n_seq_max);
32
+ if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
33
+ throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
34
+ }
35
+
29
36
  cparams.n_threads = params.n_threads;
30
37
  cparams.n_threads_batch = params.n_threads_batch;
31
38
  cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -118,6 +125,11 @@ llama_context::llama_context(
118
125
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
119
126
  }
120
127
 
128
+ if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
129
+ LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
130
+ __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
131
+ }
132
+
121
133
  if (!hparams.vocab_only) {
122
134
  // GPU backends
123
135
  for (auto * dev : model.devices) {
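The constructor hunks above turn n_seq_max into a hard limit (values above LLAMA_MAX_SEQ now throw during construction) and log a warning when several sequences are requested for an SWA model without swa_full. A minimal caller-side sketch of what this means in practice; the helper below is hypothetical and only uses the standard llama.cpp C API (llama_context_default_params, llama_init_from_model):

```cpp
#include "llama.h"
#include <cstdio>

// Sketch: create a context for n_parallel sequences. If n_seq_max exceeds
// LLAMA_MAX_SEQ, the constructor throws and the C API returns nullptr.
static llama_context * make_parallel_ctx(llama_model * model, uint32_t n_parallel) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx     = 4096;
    cparams.n_seq_max = n_parallel;
    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        fprintf(stderr, "context creation failed (n_seq_max must be <= LLAMA_MAX_SEQ)\n");
    }
    return ctx;
}
```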
@@ -255,15 +267,9 @@ llama_context::llama_context(
255
267
 
256
268
  // reserve worst-case graph
257
269
  if (!hparams.vocab_only && memory) {
258
- const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
270
+ const uint32_t n_seqs = cparams.n_seq_max;
259
271
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
260
272
 
261
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
262
-
263
- // restore later
264
- // TODO: something cleaner
265
- const auto n_outputs_save = n_outputs;
266
-
267
273
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
268
274
 
269
275
  int n_splits_pp = -1;
@@ -273,25 +279,18 @@ llama_context::llama_context(
273
279
  int n_nodes_tg = -1;
274
280
 
275
281
  // simulate full KV cache
276
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
277
282
 
278
- kv_self->set_full();
283
+ const auto mctx = memory->init_full();
284
+ if (!mctx) {
285
+ throw std::runtime_error("failed to initialize KV cache");
286
+ }
279
287
 
280
288
  cross.v_embd.clear();
281
289
 
282
290
  // reserve pp graph first so that buffers are only allocated once
283
291
  {
284
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
285
-
286
- // max number of outputs
287
- n_outputs = ubatch_pp.n_tokens;
288
-
289
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
290
-
291
- auto * gf = graph_init();
292
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
293
-
294
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
292
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
293
+ if (!gf) {
295
294
  throw std::runtime_error("failed to allocate compute pp buffers");
296
295
  }
297
296
 
@@ -301,16 +300,8 @@ llama_context::llama_context(
301
300
 
302
301
  // reserve with tg graph to get the number of splits and nodes
303
302
  {
304
- llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
305
-
306
- n_outputs = ubatch_tg.n_tokens;
307
-
308
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
309
-
310
- auto * gf = graph_init();
311
- graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
312
-
313
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
303
+ auto * gf = graph_reserve(1, 1, 1, mctx.get());
304
+ if (!gf) {
314
305
  throw std::runtime_error("failed to allocate compute tg buffers");
315
306
  }
316
307
 
@@ -320,22 +311,12 @@ llama_context::llama_context(
320
311
 
321
312
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
322
313
  {
323
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
324
-
325
- n_outputs = ubatch_pp.n_tokens;
326
-
327
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
328
-
329
- auto * gf = graph_init();
330
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
331
-
332
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
314
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
315
+ if (!gf) {
333
316
  throw std::runtime_error("failed to allocate compute pp buffers");
334
317
  }
335
318
  }
336
319
 
337
- n_outputs = n_outputs_save;
338
-
339
320
  for (size_t i = 0; i < backend_ptrs.size(); ++i) {
340
321
  ggml_backend_t backend = backend_ptrs[i];
341
322
  ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -439,46 +420,71 @@ uint32_t llama_context::n_threads_batch() const {
439
420
  return cparams.n_threads_batch;
440
421
  }
441
422
 
442
- llama_kv_cache * llama_context::get_kv_self() {
443
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
444
- return kv_self;
445
- }
446
-
447
- const llama_kv_cache * llama_context::get_kv_self() const {
448
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
449
- return kv_self;
423
+ llama_memory_t llama_context::get_memory() const {
424
+ return memory.get();
450
425
  }
451
426
 
452
- void llama_context::kv_self_update() {
453
- bool need_reserve = false;
427
+ // deprecated
428
+ void llama_context::kv_self_defrag_sched() {
429
+ if (!memory) {
430
+ return;
431
+ }
454
432
 
455
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
433
+ memory_force_optimize = true;
434
+ }
456
435
 
457
- need_reserve = kv_self->update(*this);
436
+ // deprecated
437
+ bool llama_context::kv_self_update(bool optimize) {
438
+ if (!memory) {
439
+ return false;
440
+ }
458
441
 
459
- // reserve a worst case graph if needed
460
- if (need_reserve) {
461
- LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
442
+ {
443
+ // TODO: remove in the future
444
+ optimize |= memory_force_optimize;
445
+ memory_force_optimize = false;
462
446
 
463
- // build worst-case graph
464
- uint32_t n_seqs = 1; // TODO: worst-case number of sequences
465
- uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
447
+ const auto mctx = memory->init_update(this, optimize);
448
+ switch (mctx->get_status()) {
449
+ case LLAMA_MEMORY_STATUS_SUCCESS:
450
+ {
451
+ // noop
452
+ } break;
453
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
454
+ {
455
+ // no updates need to be performed
456
+ return false;
457
+ }
458
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
459
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
460
+ {
461
+ LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
462
+ return false;
463
+ }
464
+ }
466
465
 
467
- // simulate full KV cache
468
- kv_self->set_full();
466
+ if (!mctx->apply()) {
467
+ LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
468
+ }
469
+ }
469
470
 
470
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
471
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
471
+ // if the memory module did any computation, we have to reserve a new worst-case graph
472
+ {
473
+ const auto mctx = memory->init_full();
474
+ if (!mctx) {
475
+ throw std::runtime_error("failed to initialize memory context");
476
+ }
472
477
 
473
- auto * gf = graph_init();
474
- graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
478
+ const uint32_t n_seqs = cparams.n_seq_max;
479
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
475
480
 
476
- // initialize scheduler with the worst-case graph
477
- ggml_backend_sched_reset(sched.get());
478
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
479
- LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
481
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
482
+ if (!gf) {
483
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
480
484
  }
481
485
  }
486
+
487
+ return true;
482
488
  }
483
489
 
484
490
  enum llama_pooling_type llama_context::pooling_type() const {
@@ -490,7 +496,7 @@ float * llama_context::get_logits() {
490
496
  }
491
497
 
492
498
  float * llama_context::get_logits_ith(int32_t i) {
493
- int32_t j = -1;
499
+ int64_t j = -1;
494
500
 
495
501
  try {
496
502
  if (logits == nullptr) {
@@ -513,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
513
519
  }
514
520
  if (j >= n_outputs) {
515
521
  // This should not happen
516
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
522
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
517
523
  }
518
524
 
519
525
  return logits + j*model.vocab.n_tokens();
@@ -532,7 +538,7 @@ float * llama_context::get_embeddings() {
532
538
  }
533
539
 
534
540
  float * llama_context::get_embeddings_ith(int32_t i) {
535
- int32_t j = -1;
541
+ int64_t j = -1;
536
542
 
537
543
  try {
538
544
  if (embd == nullptr) {
@@ -555,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
555
561
  }
556
562
  if (j >= n_outputs) {
557
563
  // This should not happen
558
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
564
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
559
565
  }
560
566
 
561
567
  return embd + j*model.hparams.n_embd;
@@ -672,63 +678,95 @@ bool llama_context::apply_adapter_cvec(
672
678
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
673
679
  }
674
680
 
675
- int llama_context::encode(llama_batch & inp_batch) {
676
- if (inp_batch.n_tokens == 0) {
677
- LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
678
- return -1;
681
+ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
682
+ if (mctx && !mctx->apply()) {
683
+ LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
684
+ ret = GGML_STATUS_FAILED;
685
+ return nullptr;
679
686
  }
680
687
 
681
- // temporary allocate memory for the input batch if needed
682
- // note: during encode, we always pass the full sequence starting from pos = 0
683
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
688
+ auto * gf = graph_init();
689
+ if (!gf) {
690
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
691
+ ret = GGML_STATUS_FAILED;
692
+ return nullptr;
693
+ }
694
+
695
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
696
+ if (!res) {
697
+ LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
698
+ ret = GGML_STATUS_FAILED;
699
+ return nullptr;
700
+ }
701
+
702
+ // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
703
+
704
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
705
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
706
+ ret = GGML_STATUS_ALLOC_FAILED;
707
+ return nullptr;
708
+ }
709
+
710
+ res->set_inputs(&ubatch);
711
+
712
+ const auto status = graph_compute(gf, ubatch.n_tokens > 1);
713
+ if (status != GGML_STATUS_SUCCESS) {
714
+ LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
715
+ ret = status;
716
+ return nullptr;
717
+ }
718
+
719
+ ret = GGML_STATUS_SUCCESS;
684
720
 
685
- const llama_batch & batch = batch_allocr.batch;
686
- const int32_t n_tokens = batch.n_tokens;
721
+ return res;
722
+ }
723
+
724
+ int llama_context::encode(const llama_batch & batch_inp) {
725
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
726
+
727
+ if (batch_inp.n_tokens == 0) {
728
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
729
+ return -1;
730
+ }
687
731
 
688
732
  const auto & hparams = model.hparams;
689
733
 
690
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
734
+ const int64_t n_embd = hparams.n_embd;
691
735
 
692
- if (batch.token) {
693
- for (int32_t i = 0; i < n_tokens; ++i) {
694
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
695
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
696
- return -1;
697
- }
698
- }
736
+ // note: during encode, we always pass the full sequence starting from pos = 0
737
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
738
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
739
+ return -1;
699
740
  }
700
741
 
742
+ const uint32_t n_tokens = balloc->get_n_tokens();
743
+
744
+ const llama_ubatch ubatch = balloc->split_simple(n_tokens);
745
+
701
746
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
702
- GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
747
+ GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
703
748
 
704
749
  if (t_compute_start_us == 0) {
705
750
  t_compute_start_us = ggml_time_us();
706
751
  }
707
752
 
753
+ // TODO: this clear of the buffer can easily be forgotten - need something better
708
754
  embd_seq.clear();
709
755
 
710
756
  n_queued_tokens += n_tokens;
711
757
 
712
- const int64_t n_embd = hparams.n_embd;
713
-
714
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
715
-
716
- const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
717
-
718
758
  // reserve output buffer
719
759
  if (output_reserve(n_tokens) < n_tokens) {
720
760
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
721
761
  return -2;
722
762
  };
723
763
 
724
- for (int32_t i = 0; i < n_tokens; ++i) {
764
+ for (uint32_t i = 0; i < n_tokens; ++i) {
725
765
  output_ids[i] = i;
726
766
  }
727
767
 
728
768
  n_outputs = n_tokens;
729
769
 
730
- //batch_manager->prepare(ubatch);
731
-
732
770
  ggml_backend_sched_reset(sched.get());
733
771
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
734
772
 
@@ -739,26 +777,18 @@ int llama_context::encode(llama_batch & inp_batch) {
739
777
  // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
740
778
  cparams.causal_attn = false;
741
779
 
742
- auto * gf = graph_init();
743
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
744
-
745
- ggml_backend_sched_alloc_graph(sched.get(), gf);
746
-
747
- res->set_inputs(&ubatch);
780
+ ggml_status status;
781
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
748
782
 
749
783
  cparams.causal_attn = causal_attn_org;
750
784
 
751
- const auto compute_status = graph_compute(gf, n_tokens > 1);
752
- switch (compute_status) {
753
- case GGML_STATUS_SUCCESS:
754
- break;
755
- case GGML_STATUS_ABORTED:
756
- return 2;
757
- case GGML_STATUS_ALLOC_FAILED:
758
- return -2;
759
- case GGML_STATUS_FAILED:
760
- default:
761
- return -3;
785
+ if (!res) {
786
+ switch (status) {
787
+ case GGML_STATUS_ABORTED: return 2;
788
+ case GGML_STATUS_ALLOC_FAILED: return -2;
789
+ case GGML_STATUS_FAILED: return -3;
790
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
791
+ }
762
792
  }
763
793
 
764
794
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
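The new process_ubatch() path above centralizes the mapping from ggml compute status to the encode/decode return value: 2 when the graph run was aborted via the abort callback, -2 on allocation failure, -3 on other failures. A hedged caller-side sketch of handling these codes (hypothetical helper, assuming llama_encode forwards the value unchanged):

```cpp
#include "llama.h"
#include <cstdio>

// Sketch: interpret llama_encode() results following the status mapping above.
// 0 = success, 2 = aborted by the abort callback, negative = hard error.
static bool try_encode(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_encode(ctx, batch);
    if (ret == 0) {
        return true;
    }
    if (ret == 2) {
        fprintf(stderr, "encode aborted by callback, can be retried\n");
    } else {
        fprintf(stderr, "encode failed, ret = %d\n", ret);
    }
    return false;
}
```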
@@ -783,31 +813,28 @@ int llama_context::encode(llama_batch & inp_batch) {
783
813
  {
784
814
  // extract sequence embeddings
785
815
  auto & embd_seq_out = embd_seq;
786
- embd_seq_out.clear();
787
816
 
788
- GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
817
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
818
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
819
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
789
820
 
790
- for (int32_t i = 0; i < n_tokens; i++) {
791
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
792
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
793
- continue;
794
- }
795
821
  embd_seq_out[seq_id].resize(n_embd);
796
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
822
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
797
823
  }
798
824
  } break;
799
825
  case LLAMA_POOLING_TYPE_RANK:
800
826
  {
801
- // extract the rerank score - a single float per sequence
827
+ // extract the rerank score - n_cls_out floats per sequence
802
828
  auto & embd_seq_out = embd_seq;
803
829
 
804
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
805
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
806
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
807
- continue;
808
- }
809
- embd_seq_out[seq_id].resize(1);
810
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
830
+ const uint32_t n_cls_out = hparams.n_cls_out;
831
+
832
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
833
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
834
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
835
+
836
+ embd_seq_out[seq_id].resize(n_cls_out);
837
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
811
838
  }
812
839
  } break;
813
840
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -832,12 +859,16 @@ int llama_context::encode(llama_batch & inp_batch) {
832
859
  cross.v_embd.resize(cross.n_embd*cross.n_enc);
833
860
  memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
834
861
 
862
+ const auto & batch = balloc->get_batch();
863
+
835
864
  // remember the sequence ids used during the encoding - needed for cross attention later
836
865
  cross.seq_ids_enc.resize(n_tokens);
837
- for (int32_t i = 0; i < n_tokens; i++) {
866
+ for (uint32_t i = 0; i < n_tokens; i++) {
838
867
  cross.seq_ids_enc[i].clear();
839
- for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
840
- llama_seq_id seq_id = ubatch.seq_id[i][s];
868
+
869
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
870
+ const llama_seq_id seq_id = batch.seq_id[i][s];
871
+
841
872
  cross.seq_ids_enc[i].insert(seq_id);
842
873
  }
843
874
  }
@@ -846,49 +877,42 @@ int llama_context::encode(llama_batch & inp_batch) {
846
877
  return 0;
847
878
  }
848
879
 
849
- int llama_context::decode(llama_batch & inp_batch) {
880
+ int llama_context::decode(const llama_batch & batch_inp) {
881
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
882
+
850
883
  if (!memory) {
851
- LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
852
- return encode(inp_batch);
884
+ LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
885
+ return encode(batch_inp);
853
886
  }
854
887
 
855
- if (inp_batch.n_tokens == 0) {
888
+ if (batch_inp.n_tokens == 0) {
856
889
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
857
890
  return -1;
858
891
  }
859
892
 
860
- if (!inp_batch.pos) {
861
- if (inp_batch.seq_id) {
862
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
863
- return -1;
864
- }
865
- }
866
-
867
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
868
-
869
- // temporary allocate memory for the input batch if needed
870
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
871
-
872
- const llama_batch & batch = batch_allocr.batch;
873
-
874
893
  const auto & vocab = model.vocab;
875
894
  const auto & hparams = model.hparams;
876
895
 
877
896
  const int32_t n_vocab = vocab.n_tokens();
897
+ const int64_t n_embd = hparams.n_embd;
878
898
 
879
- const int64_t n_tokens_all = batch.n_tokens;
880
- const int64_t n_embd = hparams.n_embd;
899
+ // when computing embeddings, all tokens are output
900
+ const bool output_all = cparams.embeddings;
881
901
 
882
- llama_kv_cache_guard kv_guard(kv_self);
902
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
903
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
904
+ return -1;
905
+ }
883
906
 
884
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
907
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
908
+ const uint32_t n_outputs_all = balloc->get_n_outputs();
885
909
 
886
- if (batch.token) {
887
- for (int64_t i = 0; i < n_tokens_all; ++i) {
888
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
889
- LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
890
- throw std::runtime_error("invalid token");
891
- }
910
+ if (output_all) {
911
+ // require that all tokens are output
912
+ if (n_outputs_all != n_tokens_all) {
913
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
914
+ __func__, n_outputs_all, n_tokens_all);
915
+ return -1;
892
916
  }
893
917
  }
894
918
 
@@ -901,49 +925,77 @@ int llama_context::decode(llama_batch & inp_batch) {
901
925
  }
902
926
  n_queued_tokens += n_tokens_all;
903
927
 
904
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
905
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
906
-
928
+ // TODO: this clear of the buffer can easily be forgotten - need something better
907
929
  embd_seq.clear();
908
930
 
909
- int64_t n_outputs_all = 0;
931
+ bool did_optimize = false;
932
+
933
+ // handle any pending defrags/shifts
934
+ kv_self_update(false);
935
+
936
+ llama_memory_context_ptr mctx;
910
937
 
911
- // count outputs
912
- if (batch.logits && !embd_pooled) {
913
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
914
- n_outputs_all += batch.logits[i] != 0;
938
+ while (true) {
939
+ mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
940
+ if (!mctx) {
941
+ return -2;
915
942
  }
916
- } else if (embd_pooled) {
917
- n_outputs_all = n_tokens_all;
918
- } else {
919
- // keep last output only
920
- n_outputs_all = 1;
921
- }
922
943
 
923
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
944
+ switch (mctx->get_status()) {
945
+ case LLAMA_MEMORY_STATUS_SUCCESS:
946
+ {
947
+ } break;
948
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
949
+ {
950
+ LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
951
+
952
+ return -2;
953
+ }
954
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
955
+ {
956
+ if (!did_optimize) {
957
+ did_optimize = true;
958
+
959
+ if (kv_self_update(true)) {
960
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
961
+
962
+ continue;
963
+ }
964
+ }
965
+
966
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
967
+
968
+ return 1;
969
+ }
970
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
971
+ {
972
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
973
+
974
+ return -2;
975
+ }
976
+ }
977
+
978
+ break;
979
+ }
924
980
 
925
981
  // reserve output buffer
926
982
  if (output_reserve(n_outputs_all) < n_outputs_all) {
927
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
983
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
928
984
  return -2;
929
985
  };
930
986
 
931
- // handle any pending defrags/shifts
932
- kv_self_update();
933
-
934
987
  int64_t n_outputs_prev = 0;
935
988
 
936
- while (sbatch.n_tokens > 0) {
937
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
989
+ do {
990
+ const auto & ubatch = mctx->get_ubatch();
938
991
 
939
- // count the outputs in this u_batch
992
+ // count the outputs in this ubatch
940
993
  {
941
994
  int32_t n_outputs_new = 0;
942
995
 
943
996
  if (n_outputs_all == n_tokens_all) {
944
997
  n_outputs_new = ubatch.n_tokens;
945
998
  } else {
946
- GGML_ASSERT(ubatch.output);
947
999
  for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
948
1000
  n_outputs_new += (int32_t) (ubatch.output[i] != 0);
949
1001
  }
@@ -953,33 +1005,41 @@ int llama_context::decode(llama_batch & inp_batch) {
953
1005
  n_outputs = n_outputs_new;
954
1006
  }
955
1007
 
956
- // find KV slot
957
- if (!kv_self->find_slot(ubatch)) {
958
- return 1;
959
- }
960
-
961
1008
  ggml_backend_sched_reset(sched.get());
962
1009
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
963
1010
 
964
- auto * gf = graph_init();
965
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
1011
+ ggml_status status;
1012
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
966
1013
 
967
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1014
+ if (!res) {
1015
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1016
+ llama_pos pos_min[LLAMA_MAX_SEQ];
1017
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1018
+ pos_min[s] = std::numeric_limits<llama_pos>::max();
1019
+ }
968
1020
 
969
- ggml_backend_sched_alloc_graph(sched.get(), gf);
1021
+ // TODO: fix sequence indexing
1022
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1023
+ const auto & seq_id = ubatch.seq_id[i][0];
970
1024
 
971
- res->set_inputs(&ubatch);
1025
+ pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
1026
+ }
972
1027
 
973
- const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
974
- if (compute_status != GGML_STATUS_SUCCESS) {
975
- switch (compute_status) {
976
- case GGML_STATUS_ABORTED:
977
- return 2;
978
- case GGML_STATUS_ALLOC_FAILED:
979
- return -2;
980
- case GGML_STATUS_FAILED:
981
- default:
982
- return -3;
1028
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1029
+ if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
1030
+ continue;
1031
+ }
1032
+
1033
+ LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1034
+
1035
+ memory->seq_rm(s, pos_min[s], -1);
1036
+ }
1037
+
1038
+ switch (status) {
1039
+ case GGML_STATUS_ABORTED: return 2;
1040
+ case GGML_STATUS_ALLOC_FAILED: return -2;
1041
+ case GGML_STATUS_FAILED: return -3;
1042
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
983
1043
  }
984
1044
  }
985
1045
 
@@ -988,7 +1048,7 @@ int llama_context::decode(llama_batch & inp_batch) {
988
1048
  // ggml_graph_dump_dot(gf, NULL, "llama.dot");
989
1049
  //}
990
1050
 
991
- auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
1051
+ auto * t_logits = res->get_logits();
992
1052
  auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
993
1053
 
994
1054
  if (t_embd && res->get_embd_pooled()) {
@@ -1035,27 +1095,27 @@ int llama_context::decode(llama_batch & inp_batch) {
1035
1095
  // extract sequence embeddings (cleared before processing each batch)
1036
1096
  auto & embd_seq_out = embd_seq;
1037
1097
 
1038
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1039
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
1040
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1041
- continue;
1042
- }
1098
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1099
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1100
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1101
+
1043
1102
  embd_seq_out[seq_id].resize(n_embd);
1044
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1103
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
1045
1104
  }
1046
1105
  } break;
1047
1106
  case LLAMA_POOLING_TYPE_RANK:
1048
1107
  {
1049
- // extract the rerank score - a single float per sequence
1108
+ // extract the rerank score - n_cls_out floats per sequence
1050
1109
  auto & embd_seq_out = embd_seq;
1051
1110
 
1052
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1053
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
1054
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1055
- continue;
1056
- }
1057
- embd_seq_out[seq_id].resize(1);
1058
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
1111
+ const uint32_t n_cls_out = hparams.n_cls_out;
1112
+
1113
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1114
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1115
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1116
+
1117
+ embd_seq_out[seq_id].resize(n_cls_out);
1118
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
1059
1119
  }
1060
1120
  } break;
1061
1121
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1066,23 +1126,20 @@ int llama_context::decode(llama_batch & inp_batch) {
1066
1126
  }
1067
1127
 
1068
1128
  n_outputs_prev += n_outputs;
1069
- }
1070
-
1071
- // finalize the batch processing
1072
- kv_guard.commit();
1129
+ } while (mctx->next());
1073
1130
 
1074
1131
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1075
1132
  n_outputs = n_outputs_all;
1076
1133
 
1077
1134
  // set output mappings
1078
- {
1135
+ if (n_outputs > 0) {
1079
1136
  bool sorted_output = true;
1080
1137
 
1081
- auto & out_ids = sbatch.out_ids;
1138
+ auto & out_ids = balloc->get_out_ids();
1082
1139
 
1083
- GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1140
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1084
1141
 
1085
- for (int64_t i = 0; i < n_outputs_all; ++i) {
1142
+ for (int64_t i = 0; i < n_outputs; ++i) {
1086
1143
  int64_t out_id = out_ids[i];
1087
1144
  output_ids[out_id] = i;
1088
1145
  if (out_id != i) {
@@ -1094,20 +1151,22 @@ int llama_context::decode(llama_batch & inp_batch) {
1094
1151
  // note: this is mostly relevant for recurrent models atm
1095
1152
  if (!sorted_output) {
1096
1153
  const uint32_t n_vocab = model.vocab.n_tokens();
1097
- const uint32_t n_embd = model.hparams.n_embd;
1154
+ const uint64_t n_embd = model.hparams.n_embd;
1098
1155
 
1099
1156
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1100
1157
 
1101
1158
  // TODO: is there something more efficient which also minimizes swaps?
1102
1159
  // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1103
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
1104
- int32_t j_min = i;
1105
- for (int32_t j = i + 1; j < n_outputs; ++j) {
1160
+ for (uint32_t i = 0; i < n_outputs - 1; ++i) {
1161
+ uint32_t j_min = i;
1162
+ for (uint32_t j = i + 1; j < n_outputs; ++j) {
1106
1163
  if (out_ids[j] < out_ids[j_min]) {
1107
1164
  j_min = j;
1108
1165
  }
1109
1166
  }
1110
- if (j_min == i) { continue; }
1167
+ if (j_min == i) {
1168
+ continue;
1169
+ }
1111
1170
  std::swap(out_ids[i], out_ids[j_min]);
1112
1171
  if (logits_size > 0) {
1113
1172
  for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1120,8 +1179,10 @@ int llama_context::decode(llama_batch & inp_batch) {
1120
1179
  }
1121
1180
  }
1122
1181
  }
1182
+
1123
1183
  std::fill(output_ids.begin(), output_ids.end(), -1);
1124
- for (int32_t i = 0; i < n_outputs; ++i) {
1184
+
1185
+ for (uint32_t i = 0; i < n_outputs; ++i) {
1125
1186
  output_ids[out_ids[i]] = i;
1126
1187
  }
1127
1188
  }
@@ -1130,11 +1191,6 @@ int llama_context::decode(llama_batch & inp_batch) {
1130
1191
  // wait for the computation to finish (automatically done when obtaining the model output)
1131
1192
  //synchronize();
1132
1193
 
1133
- // decide if we need to defrag the kv cache
1134
- if (cparams.defrag_thold > 0.0f) {
1135
- kv_self->defrag_sched(cparams.defrag_thold);
1136
- }
1137
-
1138
1194
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1139
1195
  // overlap with device computation.
1140
1196
  ggml_backend_sched_reset(sched.get());
@@ -1146,7 +1202,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1146
1202
  // output
1147
1203
  //
1148
1204
 
1149
- int32_t llama_context::output_reserve(int32_t n_outputs) {
1205
+ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1150
1206
  const auto & hparams = model.hparams;
1151
1207
  const auto & vocab = model.vocab;
1152
1208
 
@@ -1156,9 +1212,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1156
1212
  const auto n_vocab = vocab.n_tokens();
1157
1213
  const auto n_embd = hparams.n_embd;
1158
1214
 
1159
- // TODO: use a per-batch flag for logits presence instead
1160
- bool has_logits = !cparams.embeddings;
1161
- bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1215
+ bool has_logits = true;
1216
+ bool has_embd = cparams.embeddings;
1162
1217
 
1163
1218
  // TODO: hacky enc-dec support
1164
1219
  if (model.arch == LLM_ARCH_T5) {
@@ -1212,8 +1267,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1212
1267
  // set all ids as invalid (negative)
1213
1268
  std::fill(output_ids.begin(), output_ids.end(), -1);
1214
1269
 
1215
- this->n_outputs = 0;
1216
- this->n_outputs_max = n_outputs_max;
1270
+ this->n_outputs = 0;
1217
1271
 
1218
1272
  return n_outputs_max;
1219
1273
  }
@@ -1238,11 +1292,52 @@ ggml_cgraph * llama_context::graph_init() {
1238
1292
  return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1239
1293
  }
1240
1294
 
1295
+ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
1296
+ LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1297
+
1298
+ if (n_tokens % n_seqs != 0) {
1299
+ n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1300
+ n_outputs = std::min(n_outputs, n_tokens);
1301
+
1302
+ LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1303
+ }
1304
+
1305
+ // store the n_outputs as it is, and restore it afterwards
1306
+ // TODO: not sure if needed, might simplify in the future by removing this
1307
+ const auto save_n_outputs = this->n_outputs;
1308
+
1309
+ this->n_outputs = n_outputs;
1310
+
1311
+ llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1312
+ llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1313
+
1314
+ auto * gf = graph_init();
1315
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
1316
+
1317
+ this->n_outputs = save_n_outputs;
1318
+
1319
+ if (!res) {
1320
+ LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
1321
+ return nullptr;
1322
+ }
1323
+
1324
+ ggml_backend_sched_reset(sched.get());
1325
+
1326
+ // initialize scheduler with the specified graph
1327
+ if (!ggml_backend_sched_reserve(sched.get(), gf)) {
1328
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
1329
+ return nullptr;
1330
+ }
1331
+
1332
+ return gf;
1333
+ }
1334
+
1241
1335
  llm_graph_result_ptr llama_context::graph_build(
1242
- ggml_context * ctx,
1243
- ggml_cgraph * gf,
1244
- const llama_ubatch & ubatch,
1245
- llm_graph_type gtype) {
1336
+ ggml_context * ctx,
1337
+ ggml_cgraph * gf,
1338
+ const llama_ubatch & ubatch,
1339
+ llm_graph_type gtype,
1340
+ const llama_memory_context_i * mctx) {
1246
1341
  return model.build_graph(
1247
1342
  {
1248
1343
  /*.ctx =*/ ctx,
@@ -1254,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
1254
1349
  /*.backend_cpu =*/ backend_cpu,
1255
1350
  /*.cvec =*/ &cvec,
1256
1351
  /*.loras =*/ &loras,
1257
- /*.memory =*/ memory.get(),
1352
+ /*.mctx =*/ mctx,
1258
1353
  /*.cross =*/ &cross,
1259
1354
  /*.n_outputs =*/ n_outputs,
1260
1355
  /*.cb =*/ graph_get_cb(),
@@ -1663,14 +1758,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1663
1758
 
1664
1759
  std::vector<int32_t> w_output_pos;
1665
1760
 
1666
- GGML_ASSERT(n_outputs <= n_outputs_max);
1667
-
1668
1761
  w_output_pos.resize(n_outputs);
1669
1762
 
1670
1763
  // build a more compact representation of the output ids
1671
1764
  for (size_t i = 0; i < n_batch(); ++i) {
1672
1765
  // map an output id to a position in the batch
1673
- int32_t pos = output_ids[i];
1766
+ int64_t pos = output_ids[i];
1674
1767
  if (pos >= 0) {
1675
1768
  GGML_ASSERT(pos < n_outputs);
1676
1769
  w_output_pos[pos] = i;
@@ -1710,11 +1803,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1710
1803
  }
1711
1804
  }
1712
1805
 
1713
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1714
-
1715
- if (kv_self != nullptr) {
1806
+ if (memory != nullptr) {
1716
1807
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1717
- kv_self->state_write(io);
1808
+ memory->state_write(io);
1718
1809
  }
1719
1810
 
1720
1811
  return io.n_bytes();
@@ -1801,9 +1892,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1801
1892
  if (memory) {
1802
1893
  LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
1803
1894
 
1804
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1805
-
1806
- kv_self->state_read(io);
1895
+ memory->state_read(io);
1807
1896
  }
1808
1897
 
1809
1898
  return io.n_bytes();
@@ -1813,9 +1902,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
1813
1902
  GGML_UNUSED(seq_id);
1814
1903
 
1815
1904
  if (memory) {
1816
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1817
-
1818
- kv_self->state_write(io, seq_id);
1905
+ memory->state_write(io, seq_id);
1819
1906
  }
1820
1907
 
1821
1908
  return io.n_bytes();
@@ -1825,9 +1912,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
1825
1912
  GGML_UNUSED(seq_id);
1826
1913
 
1827
1914
  if (memory) {
1828
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1829
-
1830
- kv_self->state_read(io, seq_id);
1915
+ memory->state_read(io, seq_id);
1831
1916
  }
1832
1917
 
1833
1918
  return io.n_bytes();
@@ -1932,10 +2017,7 @@ void llama_context::opt_epoch_iter(
1932
2017
  const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
1933
2018
  const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
1934
2019
 
1935
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1936
-
1937
- kv_self->clear();
1938
- llama_kv_cache_guard kv_guard(kv_self);
2020
+ memory->clear(true);
1939
2021
 
1940
2022
  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
1941
2023
  batch.n_tokens = n_batch;
@@ -1947,39 +2029,44 @@ void llama_context::opt_epoch_iter(
1947
2029
  batch.logits [pos_batch] = true;
1948
2030
  }
1949
2031
 
1950
- const auto n_tokens_all = batch.n_tokens;
2032
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
2033
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2034
+ return;
2035
+ }
1951
2036
 
1952
- n_queued_tokens += n_tokens_all;
2037
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
1953
2038
 
1954
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
1955
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
2039
+ n_queued_tokens += n_tokens_all;
1956
2040
 
1957
2041
  embd_seq.clear();
1958
2042
 
1959
- int64_t n_outputs_all = n_tokens_all;
2043
+ uint32_t n_outputs_all = n_tokens_all;
1960
2044
 
1961
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
2045
+ auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
2046
+ if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2047
+ LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2048
+ break;
2049
+ }
1962
2050
 
1963
2051
  // reserve output buffer
1964
2052
  if (output_reserve(n_outputs_all) < n_outputs_all) {
1965
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
2053
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
1966
2054
  GGML_ABORT("TODO: handle this error");
1967
2055
  };
1968
2056
 
1969
- for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
1970
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
2057
+ uint32_t pos_batch = 0;
2058
+ do {
2059
+ const auto & ubatch = mctx->get_ubatch();
1971
2060
 
1972
2061
  n_outputs = ubatch.n_tokens;
1973
2062
 
1974
- // TODO: not sure if this is needed
1975
- if (!kv_self->find_slot(ubatch)) {
1976
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
1977
-
1978
- GGML_ABORT("TODO: handle this error");
2063
+ if (!mctx->apply()) {
2064
+ LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
2065
+ break;
1979
2066
  }
1980
2067
 
1981
2068
  auto * gf = graph_init();
1982
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
2069
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
1983
2070
 
1984
2071
  struct ggml_context * ctx_compute_opt;
1985
2072
  {
@@ -1994,6 +2081,7 @@ void llama_context::opt_epoch_iter(
1994
2081
  }
1995
2082
  ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
1996
2083
  ggml_opt_alloc(opt_ctx, train);
2084
+
1997
2085
  res->set_inputs(&ubatch);
1998
2086
  {
1999
2087
  struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2011,10 +2099,10 @@ void llama_context::opt_epoch_iter(
2011
2099
  callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
2012
2100
  }
2013
2101
  ggml_free(ctx_compute_opt);
2014
- }
2015
- }
2016
2102
 
2017
- kv_guard.commit();
2103
+ pos_batch += ubatch.n_tokens;
2104
+ } while (mctx->next());
2105
+ }
2018
2106
  }
2019
2107
 
2020
2108
  void llama_context::opt_epoch(
@@ -2174,12 +2262,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
2174
2262
  return &ctx->get_model();
2175
2263
  }
2176
2264
 
2265
+ // deprecated
2177
2266
  llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2178
- return ctx->get_kv_self();
2267
+ return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
2179
2268
  }
2180
2269
 
2270
+ // deprecated
2181
2271
  void llama_kv_self_update(llama_context * ctx) {
2182
- ctx->kv_self_update();
2272
+ ctx->kv_self_update(false);
2183
2273
  }
2184
2274
 
2185
2275
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2294,13 +2384,118 @@ int32_t llama_apply_adapter_cvec(
2294
2384
  return res ? 0 : -1;
2295
2385
  }
2296
2386
 
2387
+ //
2388
+ // memory
2389
+ //
2390
+
2391
+ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
2392
+ return ctx->get_memory();
2393
+ }
2394
+
2395
+ void llama_memory_clear(llama_memory_t mem, bool data) {
2396
+ if (!mem) {
2397
+ return;
2398
+ }
2399
+
2400
+ mem->clear(data);
2401
+ }
2402
+
2403
+ bool llama_memory_seq_rm(
2404
+ llama_memory_t mem,
2405
+ llama_seq_id seq_id,
2406
+ llama_pos p0,
2407
+ llama_pos p1) {
2408
+ if (!mem) {
2409
+ return true;
2410
+ }
2411
+
2412
+ return mem->seq_rm(seq_id, p0, p1);
2413
+ }
2414
+
2415
+ void llama_memory_seq_cp(
2416
+ llama_memory_t mem,
2417
+ llama_seq_id seq_id_src,
2418
+ llama_seq_id seq_id_dst,
2419
+ llama_pos p0,
2420
+ llama_pos p1) {
2421
+ if (!mem) {
2422
+ return;
2423
+ }
2424
+
2425
+ mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2426
+ }
2427
+
2428
+ void llama_memory_seq_keep(
2429
+ llama_memory_t mem,
2430
+ llama_seq_id seq_id) {
2431
+ if (!mem) {
2432
+ return;
2433
+ }
2434
+
2435
+ mem->seq_keep(seq_id);
2436
+ }
2437
+
2438
+ void llama_memory_seq_add(
2439
+ llama_memory_t mem,
2440
+ llama_seq_id seq_id,
2441
+ llama_pos p0,
2442
+ llama_pos p1,
2443
+ llama_pos delta) {
2444
+ if (!mem) {
2445
+ return;
2446
+ }
2447
+
2448
+ mem->seq_add(seq_id, p0, p1, delta);
2449
+ }
2450
+
2451
+ void llama_memory_seq_div(
2452
+ llama_memory_t mem,
2453
+ llama_seq_id seq_id,
2454
+ llama_pos p0,
2455
+ llama_pos p1,
2456
+ int d) {
2457
+ if (!mem) {
2458
+ return;
2459
+ }
2460
+
2461
+ mem->seq_div(seq_id, p0, p1, d);
2462
+ }
2463
+
2464
+ llama_pos llama_memory_seq_pos_min(
2465
+ llama_memory_t mem,
2466
+ llama_seq_id seq_id) {
2467
+ if (!mem) {
2468
+ return -1;
2469
+ }
2470
+
2471
+ return mem->seq_pos_min(seq_id);
2472
+ }
2473
+
2474
+ llama_pos llama_memory_seq_pos_max(
2475
+ llama_memory_t mem,
2476
+ llama_seq_id seq_id) {
2477
+ if (!mem) {
2478
+ return -1;
2479
+ }
2480
+
2481
+ return mem->seq_pos_max(seq_id);
2482
+ }
2483
+
2484
+ bool llama_memory_can_shift(llama_memory_t mem) {
2485
+ if (!mem) {
2486
+ return false;
2487
+ }
2488
+
2489
+ return mem->get_can_shift();
2490
+ }
2491
+
2297
2492
  //
2298
2493
  // kv cache
2299
2494
  //
2300
2495
 
2301
2496
  // deprecated
2302
2497
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2303
- const auto * kv = ctx->get_kv_self();
2498
+ const auto * kv = llama_get_memory(ctx);
2304
2499
  if (!kv) {
2305
2500
  return 0;
2306
2501
  }
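The llama_memory_* functions added above operate on the llama_memory_t handle returned by llama_get_memory() and supersede the context-bound llama_kv_self_* calls. A minimal usage sketch, restricted to the signatures introduced in this hunk (the helper itself is hypothetical):

```cpp
#include "llama.h"

// Sketch: drop the tail of one sequence via the new memory API.
static void rollback_sequence(llama_context * ctx, llama_seq_id seq_id, llama_pos keep_until) {
    llama_memory_t mem = llama_get_memory(ctx);
    if (mem == nullptr) {
        return; // context without a memory module
    }

    // remove everything from position keep_until onward for this sequence
    llama_memory_seq_rm(mem, seq_id, keep_until, -1);

    // a full clear of all sequences would be: llama_memory_clear(mem, true);
}
```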
@@ -2322,7 +2517,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2322
2517
  // deprecated
2323
2518
  // note: this is the same as above - will be removed anyway, so it's ok
2324
2519
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2325
- const auto * kv = ctx->get_kv_self();
2520
+ const auto * kv = llama_get_memory(ctx);
2326
2521
  if (!kv) {
2327
2522
  return 0;
2328
2523
  }
@@ -2341,114 +2536,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2341
2536
  return res;
2342
2537
  }
2343
2538
 
2539
+ // deprecated
2344
2540
  void llama_kv_self_clear(llama_context * ctx) {
2345
- auto * kv = ctx->get_kv_self();
2541
+ auto * kv = llama_get_memory(ctx);
2346
2542
  if (!kv) {
2347
2543
  return;
2348
2544
  }
2349
2545
 
2350
- kv->clear();
2546
+ llama_memory_clear(kv, true);
2351
2547
  }
2352
2548
 
2549
+ // deprecated
2353
2550
  bool llama_kv_self_seq_rm(
2354
2551
  llama_context * ctx,
2355
2552
  llama_seq_id seq_id,
2356
2553
  llama_pos p0,
2357
2554
  llama_pos p1) {
2358
- auto * kv = ctx->get_kv_self();
2555
+ auto * kv = llama_get_memory(ctx);
2359
2556
  if (!kv) {
2360
2557
  return true;
2361
2558
  }
2362
2559
 
2363
- return kv->seq_rm(seq_id, p0, p1);
2560
+ return llama_memory_seq_rm(kv, seq_id, p0, p1);
2364
2561
  }
2365
2562
 
2563
+ // deprecated
2366
2564
  void llama_kv_self_seq_cp(
2367
2565
  llama_context * ctx,
2368
2566
  llama_seq_id seq_id_src,
2369
2567
  llama_seq_id seq_id_dst,
2370
2568
  llama_pos p0,
2371
2569
  llama_pos p1) {
2372
- auto * kv = ctx->get_kv_self();
2570
+ auto * kv = llama_get_memory(ctx);
2373
2571
  if (!kv) {
2374
2572
  return;
2375
2573
  }
2376
2574
 
2377
- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2575
+ llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
2378
2576
  }
2379
2577
 
2578
+ // deprecated
2380
2579
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2381
- auto * kv = ctx->get_kv_self();
2580
+ auto * kv = llama_get_memory(ctx);
2382
2581
  if (!kv) {
2383
2582
  return;
2384
2583
  }
2385
2584
 
2386
- kv->seq_keep(seq_id);
2585
+ llama_memory_seq_keep(kv, seq_id);
2387
2586
  }
2388
2587
 
2588
+ // deprecated
2389
2589
  void llama_kv_self_seq_add(
2390
2590
  llama_context * ctx,
2391
2591
  llama_seq_id seq_id,
2392
2592
  llama_pos p0,
2393
2593
  llama_pos p1,
2394
2594
  llama_pos delta) {
2395
- auto * kv = ctx->get_kv_self();
2595
+ auto * kv = llama_get_memory(ctx);
2396
2596
  if (!kv) {
2397
2597
  return;
2398
2598
  }
2399
2599
 
2400
- kv->seq_add(seq_id, p0, p1, delta);
2600
+ llama_memory_seq_add(kv, seq_id, p0, p1, delta);
2401
2601
  }
2402
2602
 
2603
+ // deprecated
2403
2604
  void llama_kv_self_seq_div(
2404
2605
  llama_context * ctx,
2405
2606
  llama_seq_id seq_id,
2406
2607
  llama_pos p0,
2407
2608
  llama_pos p1,
2408
2609
  int d) {
2409
- auto * kv = ctx->get_kv_self();
2610
+ auto * kv = llama_get_memory(ctx);
2410
2611
  if (!kv) {
2411
2612
  return;
2412
2613
  }
2413
2614
 
2414
- kv->seq_div(seq_id, p0, p1, d);
2615
+ llama_memory_seq_div(kv, seq_id, p0, p1, d);
2415
2616
  }
2416
2617
 
2618
+ // deprecated
2417
2619
  llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2418
- const auto * kv = ctx->get_kv_self();
2620
+ auto * kv = llama_get_memory(ctx);
2419
2621
  if (!kv) {
2420
2622
  return -1;
2421
2623
  }
2422
2624
 
2423
- return kv->seq_pos_min(seq_id);
2625
+ return llama_memory_seq_pos_min(kv, seq_id);
2424
2626
  }
2425
2627
 
2628
+ // deprecated
2426
2629
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2427
- const auto * kv = ctx->get_kv_self();
2630
+ auto * kv = llama_get_memory(ctx);
2428
2631
  if (!kv) {
2429
2632
  return -1;
2430
2633
  }
2431
2634
 
2432
- return kv->seq_pos_max(seq_id);
2635
+ return llama_memory_seq_pos_max(kv, seq_id);
2433
2636
  }
2434
2637
 
2638
+ // deprecated
2435
2639
  void llama_kv_self_defrag(llama_context * ctx) {
2436
- auto * kv = ctx->get_kv_self();
2437
- if (!kv) {
2438
- return;
2439
- }
2440
-
2441
2640
  // force defrag
2442
- kv->defrag_sched(-1.0f);
2641
+ ctx->kv_self_defrag_sched();
2443
2642
  }
2444
2643
 
2644
+ // deprecated
2445
2645
  bool llama_kv_self_can_shift(const llama_context * ctx) {
2446
- const auto * kv = ctx->get_kv_self();
2646
+ auto * kv = llama_get_memory(ctx);
2447
2647
  if (!kv) {
2448
2648
  return false;
2449
2649
  }
2450
2650
 
2451
- return kv->get_can_shift();
2651
+ return llama_memory_can_shift(kv);
2452
2652
  }
2453
2653
 
2454
2654
  // llama state API
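As the hunk above shows, every llama_kv_self_* entry point is now a thin deprecated wrapper over llama_get_memory() plus the corresponding llama_memory_* call, so the migration is mechanical. A hedged before/after sketch (the old calls are kept as comments):

```cpp
#include "llama.h"

// Sketch: deprecated wrappers and their direct replacements.
static void migrate_kv_calls(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
    llama_memory_t mem = llama_get_memory(ctx);

    // was: llama_kv_self_clear(ctx);
    llama_memory_clear(mem, true);

    // was: llama_kv_self_seq_cp(ctx, src, dst, 0, -1);
    llama_memory_seq_cp(mem, src, dst, 0, -1);

    // was: llama_pos p = llama_kv_self_seq_pos_max(ctx, src);
    const llama_pos p = llama_memory_seq_pos_max(mem, src);
    (void) p;
}
```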
@@ -2573,22 +2773,8 @@ int32_t llama_encode(
2573
2773
  int32_t llama_decode(
2574
2774
  llama_context * ctx,
2575
2775
  llama_batch batch) {
2576
- int ret = ctx->decode(batch);
2577
-
2578
- // defrag and try again
2579
- // TODO: distinguish return code when we are sure that even after defrag there is no space available
2580
- if (ret == 1) {
2581
- llama_kv_self_defrag(ctx);
2582
- ret = ctx->decode(batch);
2583
-
2584
- if (ret == 1) {
2585
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
2586
-
2587
- return ret;
2588
- }
2589
- }
2590
-
2591
- if (ret != 0) {
2776
+ const int ret = ctx->decode(batch);
2777
+ if (ret != 0 && ret != 1) {
2592
2778
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2593
2779
  }
2594
2780
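The final hunk removes the defrag-and-retry loop from llama_decode(): cache optimization is now attempted inside decode() itself (the did_optimize path above) before the memory module gives up, and a return value of 1 (no free slot for the batch) is passed straight back to the caller. A hedged sketch of caller-side handling under the new behavior (hypothetical helper):

```cpp
#include "llama.h"
#include <cstdio>

// Sketch: with the internal retry gone, the caller decides what to do when
// no KV cache slot can be found for the batch (ret == 1).
static int32_t decode_once(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // no slot available: free finished sequences or split the batch, then retry
        fprintf(stderr, "no KV slot for batch of %d tokens\n", batch.n_tokens);
    } else if (ret != 0) {
        fprintf(stderr, "llama_decode failed, ret = %d\n", ret);
    }
    return ret;
}
```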