@fugood/llama.node 1.0.0-beta.5 → 1.0.0-beta.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/binding.ts +3 -1
- package/lib/index.js +2 -0
- package/lib/index.ts +3 -1
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +27 -26
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +28 -7
- package/src/LlamaCompletionWorker.h +4 -0
- package/src/LlamaContext.cpp +14 -17
- package/src/common.hpp +7 -6
- package/src/llama.cpp/CMakeLists.txt +15 -4
- package/src/llama.cpp/common/CMakeLists.txt +15 -24
- package/src/llama.cpp/common/arg.cpp +172 -110
- package/src/llama.cpp/common/chat-parser.cpp +385 -0
- package/src/llama.cpp/common/chat-parser.h +120 -0
- package/src/llama.cpp/common/chat.cpp +726 -596
- package/src/llama.cpp/common/chat.h +74 -8
- package/src/llama.cpp/common/common.cpp +56 -38
- package/src/llama.cpp/common/common.h +9 -3
- package/src/llama.cpp/common/json-partial.cpp +256 -0
- package/src/llama.cpp/common/json-partial.h +38 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/src/llama.cpp/common/sampling.cpp +7 -8
- package/src/llama.cpp/common/speculative.cpp +6 -4
- package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
- package/src/llama.cpp/ggml/include/ggml.h +22 -3
- package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
- package/src/llama.cpp/include/llama.h +145 -40
- package/src/llama.cpp/src/CMakeLists.txt +5 -1
- package/src/llama.cpp/src/llama-arch.cpp +99 -3
- package/src/llama.cpp/src/llama-arch.h +10 -1
- package/src/llama.cpp/src/llama-batch.cpp +728 -272
- package/src/llama.cpp/src/llama-batch.h +112 -54
- package/src/llama.cpp/src/llama-chat.cpp +19 -2
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +525 -339
- package/src/llama.cpp/src/llama-context.h +38 -17
- package/src/llama.cpp/src/llama-cparams.cpp +4 -0
- package/src/llama.cpp/src/llama-cparams.h +2 -0
- package/src/llama.cpp/src/llama-grammar.cpp +12 -2
- package/src/llama.cpp/src/llama-graph.cpp +413 -353
- package/src/llama.cpp/src/llama-graph.h +112 -56
- package/src/llama.cpp/src/llama-hparams.cpp +10 -2
- package/src/llama.cpp/src/llama-hparams.h +13 -2
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
- package/src/llama.cpp/src/llama-kv-cells.h +415 -0
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
- package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
- package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
- package/src/llama.cpp/src/llama-memory.cpp +41 -0
- package/src/llama.cpp/src/llama-memory.h +86 -5
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/src/llama.cpp/src/llama-model.cpp +1137 -528
- package/src/llama.cpp/src/llama-model.h +4 -0
- package/src/llama.cpp/src/llama-quant.cpp +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2 -2
- package/src/llama.cpp/src/llama-vocab.cpp +69 -32
- package/src/llama.cpp/src/llama-vocab.h +1 -0
- package/src/llama.cpp/src/llama.cpp +11 -7
- package/src/llama.cpp/src/unicode.cpp +5 -0
- package/src/tts_utils.h +1 -1
- package/src/llama.cpp/common/json.hpp +0 -24766
- package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
- package/src/llama.cpp/common/minja/minja.hpp +0 -2974
- package/src/llama.cpp/common/stb_image.h +0 -7988
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
- package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
- package/src/llama.cpp/src/llama-kv-cache.h +0 -515
- /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
package/src/llama.cpp/src/llama-context.cpp

@@ -1,14 +1,16 @@
 #include "llama-context.h"
 
 #include "llama-impl.h"
+#include "llama-batch.h"
 #include "llama-io.h"
+#include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
-#include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -17,7 +19,8 @@
 llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
-    model(model)
+    model(model),
+    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -25,7 +28,11 @@ llama_context::llama_context(
 
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max
+    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
+    }
+
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -118,6 +125,11 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
+    if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
+        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+    }
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -255,15 +267,9 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs =
+        const uint32_t n_seqs = cparams.n_seq_max;
        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
-        // restore later
-        // TODO: something cleaner
-        const auto n_outputs_save = n_outputs;
-
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
@@ -273,25 +279,18 @@ llama_context::llama_context(
         int n_nodes_tg = -1;
 
         // simulate full KV cache
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-
-
-            // max number of outputs
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
@@ -301,16 +300,8 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-
-
-            n_outputs = ubatch_tg.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
 
@@ -320,22 +311,12 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-
-
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
         }
 
-        n_outputs = n_outputs_save;
-
         for (size_t i = 0; i < backend_ptrs.size(); ++i) {
             ggml_backend_t backend = backend_ptrs[i];
             ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -439,46 +420,71 @@ uint32_t llama_context::n_threads_batch() const {
     return cparams.n_threads_batch;
 }
 
-
-
-    return kv_self;
-}
-
-const llama_kv_cache * llama_context::get_kv_self() const {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
+llama_memory_t llama_context::get_memory() const {
+    return memory.get();
 }
 
-
-
+// deprecated
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
 
-
+    memory_force_optimize = true;
+}
 
-
+// deprecated
+bool llama_context::kv_self_update(bool optimize) {
+    if (!memory) {
+        return false;
+    }
 
-
-
-
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
 
-
-
-
+        const auto mctx = memory->init_update(this, optimize);
+        switch (mctx->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
 
-
-
+        if (!mctx->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
+    }
 
-
-
+    // if the memory module did any computation, we have to reserve a new worst-case graph
+    {
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory context");
+        }
 
-
-
+        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-
-
-
-        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
         }
     }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -490,7 +496,7 @@ float * llama_context::get_logits() {
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
-
+    int64_t j = -1;
 
     try {
         if (logits == nullptr) {
@@ -513,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
     }
     if (j >= n_outputs) {
         // This should not happen
-        throw std::runtime_error(format("corrupt output buffer (j=%
+        throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
     }
 
     return logits + j*model.vocab.n_tokens();
@@ -532,7 +538,7 @@ float * llama_context::get_embeddings() {
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
-
+    int64_t j = -1;
 
     try {
         if (embd == nullptr) {
@@ -555,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
     }
     if (j >= n_outputs) {
         // This should not happen
-        throw std::runtime_error(format("corrupt output buffer (j=%
+        throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
     }
 
     return embd + j*model.hparams.n_embd;
@@ -672,63 +678,95 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-
-    if (
-        LLAMA_LOG_ERROR("%s:
-
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
     }
 
-
-
-
+    auto * gf = graph_init();
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        ret = GGML_STATUS_ALLOC_FAILED;
+        return nullptr;
+    }
+
+    res->set_inputs(&ubatch);
+
+    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+        ret = status;
+        return nullptr;
+    }
+
+    ret = GGML_STATUS_SUCCESS;
 
-
-
+    return res;
+}
+
+int llama_context::encode(const llama_batch & batch_inp) {
+    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
+    if (batch_inp.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
 
     const auto & hparams = model.hparams;
 
-
+    const int64_t n_embd = hparams.n_embd;
 
-
-
-
-
-            return -1;
-        }
-    }
+    // note: during encode, we always pass the full sequence starting from pos = 0
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
     }
 
+    const uint32_t n_tokens = balloc->get_n_tokens();
+
+    const llama_ubatch ubatch = balloc->split_simple(n_tokens);
+
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    GGML_ASSERT(cparams.n_ubatch >=
+    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
     if (t_compute_start_us == 0) {
         t_compute_start_us = ggml_time_us();
    }
 
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
     n_queued_tokens += n_tokens;
 
-    const int64_t n_embd = hparams.n_embd;
-
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-
-    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
     // reserve output buffer
     if (output_reserve(n_tokens) < n_tokens) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
        return -2;
    };
 
-    for (
+    for (uint32_t i = 0; i < n_tokens; ++i) {
        output_ids[i] = i;
    }
 
    n_outputs = n_tokens;
 
-    //batch_manager->prepare(ubatch);
-
    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -739,26 +777,18 @@ int llama_context::encode(llama_batch & inp_batch) {
    // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
    cparams.causal_attn = false;
 
-
-    auto res =
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    res->set_inputs(&ubatch);
+    ggml_status status;
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
    cparams.causal_attn = causal_attn_org;
 
-
-
-
-
-
-
-
-                return -2;
-            case GGML_STATUS_FAILED:
-            default:
-                return -3;
+    if (!res) {
+        switch (status) {
+            case GGML_STATUS_ABORTED:      return 2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED:       return -3;
+            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+        }
    }
 
    auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -783,31 +813,28 @@ int llama_context::encode(llama_batch & inp_batch) {
            {
                // extract sequence embeddings
                auto & embd_seq_out = embd_seq;
-                embd_seq_out.clear();
 
-
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                    const int32_t seq_idx = ubatch.seq_idx[seq_id];
 
-                for (int32_t i = 0; i < n_tokens; i++) {
-                    const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                    if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                        continue;
-                    }
                    embd_seq_out[seq_id].resize(n_embd);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                }
            } break;
        case LLAMA_POOLING_TYPE_RANK:
            {
-                // extract the rerank score -
+                // extract the rerank score - n_cls_out floats per sequence
                auto & embd_seq_out = embd_seq;
 
-
-
-
-
-
-
-
+                const uint32_t n_cls_out = hparams.n_cls_out;
+
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                    const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+                    embd_seq_out[seq_id].resize(n_cls_out);
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                }
            } break;
        case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -832,12 +859,16 @@ int llama_context::encode(llama_batch & inp_batch) {
        cross.v_embd.resize(cross.n_embd*cross.n_enc);
        memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
 
+        const auto & batch = balloc->get_batch();
+
        // remember the sequence ids used during the encoding - needed for cross attention later
        cross.seq_ids_enc.resize(n_tokens);
-        for (
+        for (uint32_t i = 0; i < n_tokens; i++) {
            cross.seq_ids_enc[i].clear();
-
-
+
+            for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
                cross.seq_ids_enc[i].insert(seq_id);
            }
        }
@@ -846,49 +877,42 @@ int llama_context::encode(llama_batch & inp_batch) {
    return 0;
 }
 
-int llama_context::decode(llama_batch &
+int llama_context::decode(const llama_batch & batch_inp) {
+    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
    if (!memory) {
-
-        return encode(
+        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
+        return encode(batch_inp);
    }
 
-    if (
+    if (batch_inp.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }
 
-    if (!inp_batch.pos) {
-        if (inp_batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return -1;
-        }
-    }
-
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-
    const auto & vocab   = model.vocab;
    const auto & hparams = model.hparams;
 
    const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd  = hparams.n_embd;
 
-
-    const
+    // when computing embeddings, all tokens are output
+    const bool output_all = cparams.embeddings;
 
-
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
 
-
+    const uint32_t n_tokens_all  = balloc->get_n_tokens();
+    const uint32_t n_outputs_all = balloc->get_n_outputs();
 
-    if (
-
-
-
-
-
+    if (output_all) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
        }
    }
 
@@ -901,49 +925,77 @@ int llama_context::decode(llama_batch & inp_batch) {
    }
    n_queued_tokens += n_tokens_all;
 
-    // this
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
    embd_seq.clear();
 
-
+    bool did_optimize = false;
+
+    // handle any pending defrags/shifts
+    kv_self_update(false);
+
+    llama_memory_context_ptr mctx;
 
-
-
-
-
+    while (true) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
+            return -2;
        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
 
-
+        switch (mctx->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
+
+                    return -2;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_optimize) {
+                        did_optimize = true;
+
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
+
+                    return -2;
+                }
+        }
+
+        break;
+    }
 
    // reserve output buffer
    if (output_reserve(n_outputs_all) < n_outputs_all) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
        return -2;
    };
 
-    // handle any pending defrags/shifts
-    kv_self_update();
-
    int64_t n_outputs_prev = 0;
 
-
-
+    do {
+        const auto & ubatch = mctx->get_ubatch();
 
-        // count the outputs in this
+        // count the outputs in this ubatch
        {
            int32_t n_outputs_new = 0;
 
            if (n_outputs_all == n_tokens_all) {
                n_outputs_new = ubatch.n_tokens;
            } else {
-                GGML_ASSERT(ubatch.output);
                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                }
@@ -953,33 +1005,41 @@ int llama_context::decode(llama_batch & inp_batch) {
            n_outputs = n_outputs_new;
        }
 
-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            return 1;
-        }
-
        ggml_backend_sched_reset(sched.get());
        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-
-        auto res =
+        ggml_status status;
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
-
+        if (!res) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_SEQ];
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+                pos_min[s] = std::numeric_limits<llama_pos>::max();
+            }
 
-
+            // TODO: fix sequence indexing
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
 
-
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
 
-
-
-
-
-
-
-
-
-
-
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
+
+                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+                memory->seq_rm(s, pos_min[s], -1);
+            }
+
+            switch (status) {
+                case GGML_STATUS_ABORTED:      return 2;
+                case GGML_STATUS_ALLOC_FAILED: return -2;
+                case GGML_STATUS_FAILED:       return -3;
+                case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
            }
        }
 
@@ -988,7 +1048,7 @@ int llama_context::decode(llama_batch & inp_batch) {
        // ggml_graph_dump_dot(gf, NULL, "llama.dot");
        //}
 
-        auto * t_logits =
+        auto * t_logits = res->get_logits();
        auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
        if (t_embd && res->get_embd_pooled()) {
@@ -1035,27 +1095,27 @@ int llama_context::decode(llama_batch & inp_batch) {
                        // extract sequence embeddings (cleared before processing each batch)
                        auto & embd_seq_out = embd_seq;
 
-                        for (uint32_t s = 0; s < ubatch.
-                            const llama_seq_id seq_id
-
-
-                            }
+                        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                            const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
                            embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_RANK:
                    {
-                        // extract the rerank score -
+                        // extract the rerank score - n_cls_out floats per sequence
                        auto & embd_seq_out = embd_seq;
 
-
-
-
-
-
-
-
+                        const uint32_t n_cls_out = hparams.n_cls_out;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                            const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+                            embd_seq_out[seq_id].resize(n_cls_out);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1066,23 +1126,20 @@ int llama_context::decode(llama_batch & inp_batch) {
        }
 
        n_outputs_prev += n_outputs;
-    }
-
-    // finalize the batch processing
-    kv_guard.commit();
+    } while (mctx->next());
 
    // set to total number of outputs in the batch, for use in llama_get_logits_ith
    n_outputs = n_outputs_all;
 
    // set output mappings
-    {
+    if (n_outputs > 0) {
        bool sorted_output = true;
 
-        auto & out_ids =
+        auto & out_ids = balloc->get_out_ids();
 
-        GGML_ASSERT(out_ids.size() == (size_t)
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
 
-        for (int64_t i = 0; i <
+        for (int64_t i = 0; i < n_outputs; ++i) {
            int64_t out_id = out_ids[i];
            output_ids[out_id] = i;
            if (out_id != i) {
@@ -1094,20 +1151,22 @@ int llama_context::decode(llama_batch & inp_batch) {
        // note: this is mostly relevant for recurrent models atm
        if (!sorted_output) {
            const uint32_t n_vocab = model.vocab.n_tokens();
-            const
+            const uint64_t n_embd  = model.hparams.n_embd;
 
            GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
            // TODO: is there something more efficient which also minimizes swaps?
            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-            for (
-
-                for (
+            for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+                uint32_t j_min = i;
+                for (uint32_t j = i + 1; j < n_outputs; ++j) {
                    if (out_ids[j] < out_ids[j_min]) {
                        j_min = j;
                    }
                }
-                if (j_min == i) {
+                if (j_min == i) {
+                    continue;
+                }
                std::swap(out_ids[i], out_ids[j_min]);
                if (logits_size > 0) {
                    for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1120,8 +1179,10 @@ int llama_context::decode(llama_batch & inp_batch) {
                    }
                }
            }
+
            std::fill(output_ids.begin(), output_ids.end(), -1);
-
+
+            for (uint32_t i = 0; i < n_outputs; ++i) {
                output_ids[out_ids[i]] = i;
            }
        }
@@ -1130,11 +1191,6 @@ int llama_context::decode(llama_batch & inp_batch) {
    // wait for the computation to finish (automatically done when obtaining the model output)
    //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
    // overlap with device computation.
    ggml_backend_sched_reset(sched.get());
@@ -1146,7 +1202,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 // output
 //
 
-
+uint32_t llama_context::output_reserve(int32_t n_outputs) {
    const auto & hparams = model.hparams;
    const auto & vocab   = model.vocab;
 
@@ -1156,9 +1212,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
    const auto n_vocab = vocab.n_tokens();
    const auto n_embd  = hparams.n_embd;
 
-
-    bool
-    bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
    // TODO: hacky enc-dec support
    if (model.arch == LLM_ARCH_T5) {
@@ -1212,8 +1267,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
    // set all ids as invalid (negative)
    std::fill(output_ids.begin(), output_ids.end(), -1);
 
-    this->n_outputs
-    this->n_outputs_max = n_outputs_max;
+    this->n_outputs = 0;
 
    return n_outputs_max;
 }
@@ -1238,11 +1292,52 @@ ggml_cgraph * llama_context::graph_init() {
    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+    LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+
+    if (n_tokens % n_seqs != 0) {
+        n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
+        n_outputs = std::min(n_outputs, n_tokens);
+
+        LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
+    }
+
+    // store the n_outputs as it is, and restore it afterwards
+    // TODO: not sure if needed, might simplify in the future by removing this
+    const auto save_n_outputs = this->n_outputs;
+
+    this->n_outputs = n_outputs;
+
+    llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
+    llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
+
+    auto * gf = graph_init();
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+
+    this->n_outputs = save_n_outputs;
+
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
+        return nullptr;
+    }
+
+    ggml_backend_sched_reset(sched.get());
+
+    // initialize scheduler with the specified graph
+    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        return nullptr;
+    }
+
+    return gf;
+}
+
 llm_graph_result_ptr llama_context::graph_build(
-
-
-
-
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
+        llm_graph_type gtype,
+        const llama_memory_context_i * mctx) {
    return model.build_graph(
        {
            /*.ctx         =*/ ctx,
@@ -1254,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
            /*.backend_cpu =*/ backend_cpu,
            /*.cvec        =*/ &cvec,
            /*.loras       =*/ &loras,
-            /*.
+            /*.mctx        =*/ mctx,
            /*.cross       =*/ &cross,
            /*.n_outputs   =*/ n_outputs,
            /*.cb          =*/ graph_get_cb(),
@@ -1663,14 +1758,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
 
    std::vector<int32_t> w_output_pos;
 
-    GGML_ASSERT(n_outputs <= n_outputs_max);
-
    w_output_pos.resize(n_outputs);
 
    // build a more compact representation of the output ids
    for (size_t i = 0; i < n_batch(); ++i) {
        // map an output id to a position in the batch
-
+        int64_t pos = output_ids[i];
        if (pos >= 0) {
            GGML_ASSERT(pos < n_outputs);
            w_output_pos[pos] = i;
@@ -1710,11 +1803,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
        }
    }
 
-
-
-    if (kv_self != nullptr) {
+    if (memory != nullptr) {
        LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
-
+        memory->state_write(io);
    }
 
    return io.n_bytes();
@@ -1801,9 +1892,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
    if (memory) {
        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
-
-
-        kv_self->state_read(io);
+        memory->state_read(io);
    }
 
    return io.n_bytes();
@@ -1813,9 +1902,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
    GGML_UNUSED(seq_id);
 
    if (memory) {
-
-
-        kv_self->state_write(io, seq_id);
+        memory->state_write(io, seq_id);
    }
 
    return io.n_bytes();
@@ -1825,9 +1912,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
    GGML_UNUSED(seq_id);
 
    if (memory) {
-
-
-        kv_self->state_read(io, seq_id);
+        memory->state_read(io, seq_id);
    }
 
    return io.n_bytes();
@@ -1932,10 +2017,7 @@ void llama_context::opt_epoch_iter(
    const uint32_t n_batch  = std::min(this->n_batch(),  n_ctx);
    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
 
-
-
-    kv_self->clear();
-    llama_kv_cache_guard kv_guard(kv_self);
+    memory->clear(true);
 
    for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
        batch.n_tokens = n_batch;
@@ -1947,39 +2029,44 @@ void llama_context::opt_epoch_iter(
            batch.logits [pos_batch] = true;
        }
 
-
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+            LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+            return;
+        }
 
-
+        const uint32_t n_tokens_all = balloc->get_n_tokens();
 
-
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+        n_queued_tokens += n_tokens_all;
 
        embd_seq.clear();
 
-
+        uint32_t n_outputs_all = n_tokens_all;
 
-
+        auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+        if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+            break;
+        }
 
        // reserve output buffer
        if (output_reserve(n_outputs_all) < n_outputs_all) {
-            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
            GGML_ABORT("TODO: handle this error");
        };
 
-
-
+        uint32_t pos_batch = 0;
+        do {
+            const auto & ubatch = mctx->get_ubatch();
 
            n_outputs = ubatch.n_tokens;
 
-
-
-
-
-                GGML_ABORT("TODO: handle this error");
+            if (!mctx->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
+                break;
            }
 
            auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
            struct ggml_context * ctx_compute_opt;
            {
@@ -1994,6 +2081,7 @@ void llama_context::opt_epoch_iter(
            }
            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
            ggml_opt_alloc(opt_ctx, train);
+
            res->set_inputs(&ubatch);
            {
                struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2011,10 +2099,10 @@ void llama_context::opt_epoch_iter(
                callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
            }
            ggml_free(ctx_compute_opt);
-        }
-    }
 
-
+            pos_batch += ubatch.n_tokens;
+        } while (mctx->next());
+    }
 }
 
 void llama_context::opt_epoch(
@@ -2174,12 +2262,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
    return &ctx->get_model();
 }
 
+// deprecated
 llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return ctx->
+    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
 }
 
+// deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2294,13 +2384,118 @@ int32_t llama_apply_adapter_cvec(
    return res ? 0 : -1;
 }
 
+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem, bool data) {
+    if (!mem) {
+        return;
+    }
+
+    mem->clear(data);
+}
+
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1) {
+    if (!mem) {
+        return true;
+    }
+
+    return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+        llama_seq_id seq_id_src,
+        llama_seq_id seq_id_dst,
+        llama_pos p0,
+        llama_pos p1) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        llama_pos delta) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        int d) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+    if (!mem) {
+        return false;
+    }
+
+    return mem->get_can_shift();
+}
+
 //
 // kv cache
 //
 
 // deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = ctx
+    const auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return 0;
    }
@@ -2322,7 +2517,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
 // deprecated
 // note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = ctx
+    const auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return 0;
    }
@@ -2341,114 +2536,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
    return res;
 }
 
+// deprecated
 void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return;
    }
 
-    kv
+    llama_memory_clear(kv, true);
 }
 
+// deprecated
 bool llama_kv_self_seq_rm(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return true;
    }
 
-    return kv
+    return llama_memory_seq_rm(kv, seq_id, p0, p1);
 }
 
+// deprecated
 void llama_kv_self_seq_cp(
        llama_context * ctx,
        llama_seq_id seq_id_src,
        llama_seq_id seq_id_dst,
        llama_pos p0,
        llama_pos p1) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return;
    }
 
-    kv
+    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
 }
 
+// deprecated
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return;
    }
 
-    kv
+    llama_memory_seq_keep(kv, seq_id);
 }
 
+// deprecated
 void llama_kv_self_seq_add(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        llama_pos delta) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return;
    }
 
-    kv
+    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
 }
 
+// deprecated
 void llama_kv_self_seq_div(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        int d) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return;
    }
 
-    kv
+    llama_memory_seq_div(kv, seq_id, p0, p1, d);
 }
 
+// deprecated
 llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return -1;
    }
 
-    return kv
+    return llama_memory_seq_pos_min(kv, seq_id);
 }
 
+// deprecated
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return -1;
    }
 
-    return kv
+    return llama_memory_seq_pos_max(kv, seq_id);
 }
 
+// deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
    // force defrag
-
+    ctx->kv_self_defrag_sched();
 }
 
+// deprecated
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return false;
    }
 
-    return kv
+    return llama_memory_can_shift(kv);
 }
 
 // llama state API
@@ -2573,22 +2773,8 @@ int32_t llama_encode(
 int32_t llama_decode(
        llama_context * ctx,
        llama_batch batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }
 
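The vendored llama-context.cpp diff above replaces the old llama_get_kv_self()/llama_kv_self_* entry points with a llama_memory_t handle (llama_get_memory(), llama_memory_clear(), llama_memory_seq_rm(), ...), and llama_decode() no longer defrags and retries internally when no KV slot is found; return value 1 is now passed through to the caller. The snippet below is a minimal, hypothetical caller-side sketch of that migration. It is not part of this package, it assumes a llama_context * created through the usual model/context setup (not shown), and it only uses functions whose signatures appear in the diff.

// Hypothetical migration sketch - not shipped with @fugood/llama.node.
#include "llama.h"

// Drop every cached position of one sequence via the new memory API.
static void reset_sequence(llama_context * ctx, llama_seq_id seq_id) {
    llama_memory_t mem = llama_get_memory(ctx); // replaces the deprecated llama_get_kv_self()

    // negative p0/p1 cover the whole position range of the sequence
    llama_memory_seq_rm(mem, seq_id, -1, -1);
}

// llama_decode() now surfaces "no memory slot found" (ret == 1) instead of
// defragging and retrying internally, so the caller decides how to react.
static int decode_or_clear(llama_context * ctx, llama_batch batch) {
    int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // crude fallback for the sketch: wipe the memory and try once more
        llama_memory_clear(llama_get_memory(ctx), /*data =*/ true);
        ret = llama_decode(ctx, batch);
    }
    return ret;
}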