@fugood/llama.node 0.3.14 → 0.3.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/llama.cpp/.github/workflows/build.yml +30 -1
- package/src/llama.cpp/CMakeLists.txt +9 -1
- package/src/llama.cpp/cmake/common.cmake +2 -0
- package/src/llama.cpp/common/arg.cpp +20 -2
- package/src/llama.cpp/common/common.cpp +6 -3
- package/src/llama.cpp/common/speculative.cpp +4 -4
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +2 -2
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +6 -6
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
- package/src/llama.cpp/examples/run/run.cpp +91 -46
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
- package/src/llama.cpp/examples/server/server.cpp +37 -15
- package/src/llama.cpp/examples/server/utils.hpp +3 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/tts/tts.cpp +20 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +24 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
- package/src/llama.cpp/ggml/src/ggml.c +85 -2
- package/src/llama.cpp/include/llama.h +86 -22
- package/src/llama.cpp/src/CMakeLists.txt +5 -2
- package/src/llama.cpp/src/llama-adapter.cpp +19 -20
- package/src/llama.cpp/src/llama-adapter.h +11 -9
- package/src/llama.cpp/src/llama-arch.cpp +103 -16
- package/src/llama.cpp/src/llama-arch.h +18 -0
- package/src/llama.cpp/src/llama-batch.h +2 -2
- package/src/llama.cpp/src/llama-context.cpp +2253 -1222
- package/src/llama.cpp/src/llama-context.h +214 -77
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +1662 -0
- package/src/llama.cpp/src/llama-graph.h +574 -0
- package/src/llama.cpp/src/llama-hparams.cpp +8 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -0
- package/src/llama.cpp/src/llama-io.cpp +15 -0
- package/src/llama.cpp/src/llama-io.h +35 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
- package/src/llama.cpp/src/llama-kv-cache.h +178 -110
- package/src/llama.cpp/src/llama-memory.cpp +1 -0
- package/src/llama.cpp/src/llama-memory.h +21 -0
- package/src/llama.cpp/src/llama-model.cpp +8244 -173
- package/src/llama.cpp/src/llama-model.h +34 -1
- package/src/llama.cpp/src/llama-quant.cpp +10 -1
- package/src/llama.cpp/src/llama.cpp +51 -9984
- package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
|
@@ -1,732 +1,846 @@
|
|
|
1
1
|
#include "llama-context.h"
|
|
2
2
|
|
|
3
3
|
#include "llama-impl.h"
|
|
4
|
+
#include "llama-io.h"
|
|
4
5
|
#include "llama-mmap.h"
|
|
6
|
+
#include "llama-model.h"
|
|
7
|
+
#include "llama-kv-cache.h"
|
|
5
8
|
|
|
6
9
|
#include <cassert>
|
|
7
|
-
#include <cmath>
|
|
8
10
|
#include <cstring>
|
|
9
11
|
#include <stdexcept>
|
|
12
|
+
#include <cinttypes>
|
|
10
13
|
|
|
11
|
-
|
|
12
|
-
|
|
14
|
+
//
|
|
15
|
+
// llama_context
|
|
16
|
+
//
|
|
13
17
|
|
|
14
|
-
|
|
18
|
+
llama_context::llama_context(
|
|
19
|
+
const llama_model & model,
|
|
20
|
+
llama_context_params params) :
|
|
21
|
+
model(model) {
|
|
22
|
+
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
|
15
23
|
|
|
16
|
-
|
|
24
|
+
t_start_us = model.t_start_us;
|
|
25
|
+
t_load_us = model.t_load_us;
|
|
17
26
|
|
|
18
|
-
|
|
19
|
-
data[i] = lctx.kv_self.cells[i].delta;
|
|
20
|
-
}
|
|
21
|
-
}
|
|
27
|
+
const auto & hparams = model.hparams;
|
|
22
28
|
|
|
23
|
-
|
|
24
|
-
|
|
29
|
+
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
|
30
|
+
cparams.n_threads = params.n_threads;
|
|
31
|
+
cparams.n_threads_batch = params.n_threads_batch;
|
|
32
|
+
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
|
33
|
+
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
|
34
|
+
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
|
35
|
+
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
|
36
|
+
cparams.defrag_thold = params.defrag_thold;
|
|
37
|
+
cparams.embeddings = params.embeddings;
|
|
38
|
+
cparams.offload_kqv = params.offload_kqv;
|
|
39
|
+
cparams.flash_attn = params.flash_attn;
|
|
40
|
+
cparams.no_perf = params.no_perf;
|
|
41
|
+
cparams.pooling_type = params.pooling_type;
|
|
42
|
+
cparams.warmup = false;
|
|
25
43
|
|
|
26
|
-
|
|
44
|
+
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
|
45
|
+
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
|
46
|
+
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
|
27
47
|
|
|
28
|
-
|
|
48
|
+
cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
|
|
49
|
+
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
|
|
50
|
+
hparams.n_ctx_train;
|
|
29
51
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
}
|
|
33
|
-
}
|
|
52
|
+
cparams.cb_eval = params.cb_eval;
|
|
53
|
+
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
|
34
54
|
|
|
35
|
-
|
|
55
|
+
auto rope_scaling_type = params.rope_scaling_type;
|
|
56
|
+
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
|
|
57
|
+
rope_scaling_type = hparams.rope_scaling_type_train;
|
|
58
|
+
}
|
|
36
59
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
60
|
+
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
|
|
61
|
+
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
|
|
62
|
+
}
|
|
40
63
|
|
|
41
|
-
if (
|
|
42
|
-
|
|
64
|
+
if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
|
|
65
|
+
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
|
43
66
|
}
|
|
44
67
|
|
|
45
|
-
|
|
68
|
+
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
|
|
46
69
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
70
|
+
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
|
71
|
+
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
|
72
|
+
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
|
73
|
+
} else {
|
|
74
|
+
cparams.pooling_type = hparams.pooling_type;
|
|
75
|
+
}
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
|
|
79
|
+
cparams.causal_attn = hparams.causal_attn;
|
|
52
80
|
} else {
|
|
53
|
-
|
|
81
|
+
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
|
|
54
82
|
}
|
|
55
|
-
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
|
|
56
|
-
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
|
|
57
|
-
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
|
|
58
|
-
return relative_bucket;
|
|
59
|
-
}
|
|
60
83
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
// set input data
|
|
64
|
-
//
|
|
84
|
+
// with causal attention, the batch size is limited by the context size
|
|
85
|
+
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
|
65
86
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
87
|
+
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
|
88
|
+
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
|
89
|
+
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
|
90
|
+
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
|
|
91
|
+
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
|
92
|
+
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
|
93
|
+
cparams.n_batch = GGML_KQ_MASK_PAD;
|
|
94
|
+
}
|
|
69
95
|
|
|
70
|
-
|
|
71
|
-
const int64_t n_tokens = ubatch.n_tokens;
|
|
96
|
+
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
|
72
97
|
|
|
73
|
-
|
|
74
|
-
}
|
|
98
|
+
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
|
75
99
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
100
|
+
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
|
101
|
+
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
|
102
|
+
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
|
|
103
|
+
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
|
104
|
+
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
|
105
|
+
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
|
106
|
+
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
|
107
|
+
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
|
108
|
+
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
|
79
109
|
|
|
80
|
-
|
|
110
|
+
if (n_ctx_per_seq < hparams.n_ctx_train) {
|
|
111
|
+
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
|
112
|
+
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
|
81
113
|
}
|
|
82
114
|
|
|
83
|
-
if (
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
|
|
115
|
+
if (n_ctx_per_seq > hparams.n_ctx_train) {
|
|
116
|
+
LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
|
117
|
+
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
|
87
118
|
}
|
|
88
119
|
|
|
89
|
-
|
|
90
|
-
//GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
|
|
91
|
-
|
|
92
|
-
if (!lctx.inp_out_ids) {
|
|
93
|
-
LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__);
|
|
94
|
-
} else {
|
|
95
|
-
const int64_t n_tokens = ubatch.n_tokens;
|
|
120
|
+
logits_all = params.logits_all;
|
|
96
121
|
|
|
97
|
-
|
|
98
|
-
|
|
122
|
+
if (!hparams.vocab_only) {
|
|
123
|
+
// GPU backends
|
|
124
|
+
for (auto * dev : model.devices) {
|
|
125
|
+
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
|
126
|
+
if (backend == nullptr) {
|
|
127
|
+
throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
|
|
128
|
+
}
|
|
129
|
+
backends.emplace_back(backend);
|
|
130
|
+
}
|
|
99
131
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
132
|
+
// add ACCEL backends (such as BLAS)
|
|
133
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
|
134
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
|
135
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
|
136
|
+
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
|
137
|
+
if (backend == nullptr) {
|
|
138
|
+
throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
|
|
103
139
|
}
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
140
|
+
backends.emplace_back(backend);
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
// add CPU backend
|
|
145
|
+
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
|
146
|
+
if (backend_cpu == nullptr) {
|
|
147
|
+
throw std::runtime_error("failed to initialize CPU backend");
|
|
148
|
+
}
|
|
149
|
+
backends.emplace_back(backend_cpu);
|
|
150
|
+
|
|
151
|
+
// create a list of the set_n_threads functions in the backends
|
|
152
|
+
for (auto & backend : backends) {
|
|
153
|
+
ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
|
|
154
|
+
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
|
155
|
+
if (reg) {
|
|
156
|
+
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
|
157
|
+
if (ggml_backend_set_n_threads_fn) {
|
|
158
|
+
set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
|
|
110
159
|
}
|
|
111
|
-
// the graph needs to have been passed the correct number of outputs
|
|
112
|
-
GGML_ASSERT(lctx.n_outputs == n_outputs);
|
|
113
|
-
} else if (lctx.n_outputs == 1) {
|
|
114
|
-
// only keep last output
|
|
115
|
-
data[0] = n_tokens - 1;
|
|
116
|
-
} else {
|
|
117
|
-
GGML_ASSERT(lctx.n_outputs == 0);
|
|
118
160
|
}
|
|
119
161
|
}
|
|
120
|
-
}
|
|
121
162
|
|
|
122
|
-
|
|
123
|
-
// (!a || b) is a logical implication (a -> b)
|
|
124
|
-
// !hparams.causal_attn -> !cparams.causal_attn
|
|
125
|
-
(hparams.causal_attn || !cparams.causal_attn) &&
|
|
126
|
-
"causal attention is not supported by this model"
|
|
127
|
-
);
|
|
163
|
+
llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);
|
|
128
164
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
const int64_t n_seqs = ubatch.n_seqs;
|
|
165
|
+
// graph outputs buffer
|
|
166
|
+
{
|
|
167
|
+
// resized during inference when a batch uses more outputs
|
|
168
|
+
if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
|
|
169
|
+
throw std::runtime_error("failed to reserve initial output buffer");
|
|
170
|
+
}
|
|
136
171
|
|
|
172
|
+
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
|
|
173
|
+
ggml_backend_buffer_name (buf_output.get()),
|
|
174
|
+
ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
|
|
175
|
+
}
|
|
176
|
+
}
|
|
137
177
|
|
|
138
|
-
|
|
139
|
-
|
|
178
|
+
// init the memory module
|
|
179
|
+
// TODO: for now, always create a unified KV cache
|
|
180
|
+
if (!hparams.vocab_only) {
|
|
181
|
+
kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
|
|
140
182
|
|
|
141
|
-
|
|
142
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
|
143
|
-
data = (float *) lctx.inp_KQ_mask->data;
|
|
144
|
-
}
|
|
183
|
+
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
|
145
184
|
|
|
146
|
-
|
|
147
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
|
|
148
|
-
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
|
|
149
|
-
}
|
|
185
|
+
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
|
|
150
186
|
|
|
151
|
-
|
|
152
|
-
// of the correct sequence for each token of the ubatch.
|
|
153
|
-
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
|
154
|
-
for (int h = 0; h < 1; ++h) {
|
|
155
|
-
for (int s = 0; s < n_seqs; ++s) {
|
|
156
|
-
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
157
|
-
|
|
158
|
-
for (int j = 0; j < n_seq_tokens; ++j) {
|
|
159
|
-
const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
|
|
160
|
-
|
|
161
|
-
for (int i = 0; i < n_kv; ++i) {
|
|
162
|
-
float f;
|
|
163
|
-
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
|
164
|
-
f = -INFINITY;
|
|
165
|
-
} else {
|
|
166
|
-
if (hparams.use_alibi) {
|
|
167
|
-
f = -std::abs(kv_self.cells[i].pos - pos);
|
|
168
|
-
} else {
|
|
169
|
-
f = 0.0f;
|
|
170
|
-
}
|
|
171
|
-
}
|
|
187
|
+
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
|
172
188
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
189
|
+
uint32_t kv_size = cparams.n_ctx;
|
|
190
|
+
ggml_type type_k = params.type_k;
|
|
191
|
+
ggml_type type_v = params.type_v;
|
|
176
192
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
}
|
|
185
|
-
}
|
|
186
|
-
}
|
|
193
|
+
if (llama_model_is_recurrent(&model)) {
|
|
194
|
+
// Mamba needs at least as many KV cells as there are sequences kept at any time
|
|
195
|
+
kv_size = std::max((uint32_t) 1, params.n_seq_max);
|
|
196
|
+
// it's probably best to keep as much precision as possible for the states
|
|
197
|
+
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
|
|
198
|
+
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
|
|
199
|
+
}
|
|
187
200
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
for (int j = 0; j < n_kv; ++j) {
|
|
191
|
-
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
|
192
|
-
}
|
|
193
|
-
}
|
|
194
|
-
}
|
|
201
|
+
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
|
|
202
|
+
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
|
195
203
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
|
200
|
-
}
|
|
201
|
-
}
|
|
202
|
-
}
|
|
203
|
-
}
|
|
204
|
-
} else {
|
|
205
|
-
const int64_t n_tokens = ubatch.n_tokens;
|
|
206
|
-
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
207
|
-
const int64_t n_seqs = ubatch.n_seqs;
|
|
208
|
-
// when using kv cache, the mask needs to match the kv cache size
|
|
209
|
-
const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
|
|
210
|
-
|
|
211
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
|
212
|
-
|
|
213
|
-
float * data = (float *) lctx.inp_KQ_mask->data;
|
|
214
|
-
|
|
215
|
-
for (int h = 0; h < 1; ++h) {
|
|
216
|
-
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
|
217
|
-
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
|
|
218
|
-
|
|
219
|
-
for (int j = 0; j < n_seq_tokens; ++j) {
|
|
220
|
-
const int32_t tj = s1*n_seq_tokens + j;
|
|
221
|
-
|
|
222
|
-
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
|
223
|
-
for (int i = 0; i < n_seq_tokens; ++i) {
|
|
224
|
-
const int32_t ti = s0*n_seq_tokens + i;
|
|
225
|
-
float f = -INFINITY;
|
|
226
|
-
|
|
227
|
-
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
|
|
228
|
-
if (ubatch.seq_id[s0][s] == seq_id) {
|
|
229
|
-
if (hparams.use_alibi) {
|
|
230
|
-
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
|
|
231
|
-
} else {
|
|
232
|
-
f = 0.0f;
|
|
233
|
-
}
|
|
234
|
-
break;
|
|
235
|
-
}
|
|
236
|
-
}
|
|
237
|
-
|
|
238
|
-
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
|
239
|
-
}
|
|
240
|
-
}
|
|
204
|
+
if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
|
205
|
+
throw std::runtime_error("failed to initialize self-attention cache");
|
|
206
|
+
}
|
|
241
207
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
208
|
+
{
|
|
209
|
+
const size_t memory_size_k = kv_self->size_k_bytes();
|
|
210
|
+
const size_t memory_size_v = kv_self->size_v_bytes();
|
|
211
|
+
|
|
212
|
+
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
|
213
|
+
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
|
214
|
+
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
|
215
|
+
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
|
248
216
|
}
|
|
249
217
|
}
|
|
250
218
|
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
const int64_t n_seqs = ubatch.n_seqs;
|
|
219
|
+
// init backends
|
|
220
|
+
if (!hparams.vocab_only) {
|
|
221
|
+
LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__);
|
|
255
222
|
|
|
256
|
-
|
|
257
|
-
|
|
223
|
+
backend_buft.clear();
|
|
224
|
+
backend_ptrs.clear();
|
|
258
225
|
|
|
259
|
-
|
|
260
|
-
|
|
226
|
+
for (auto & backend : backends) {
|
|
227
|
+
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
|
228
|
+
auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
|
261
229
|
|
|
262
|
-
|
|
230
|
+
if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
|
|
231
|
+
// use the host buffer of the first device CPU for faster transfer of the intermediate state
|
|
232
|
+
auto * dev = model.devices[0];
|
|
233
|
+
auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
|
|
234
|
+
if (host_buft) {
|
|
235
|
+
buft = host_buft;
|
|
236
|
+
}
|
|
237
|
+
}
|
|
263
238
|
|
|
264
|
-
|
|
265
|
-
|
|
239
|
+
backend_buft.push_back(buft);
|
|
240
|
+
backend_ptrs.push_back(backend.get());
|
|
241
|
+
}
|
|
266
242
|
|
|
267
|
-
|
|
268
|
-
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
|
|
243
|
+
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
|
|
269
244
|
|
|
270
|
-
|
|
271
|
-
}
|
|
245
|
+
const size_t max_nodes = this->graph_max_nodes();
|
|
272
246
|
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
247
|
+
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
|
248
|
+
|
|
249
|
+
// buffer used to store the computation graph and the tensor meta data
|
|
250
|
+
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
|
251
|
+
|
|
252
|
+
// TODO: move these checks to ggml_backend_sched
|
|
253
|
+
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
|
254
|
+
bool pipeline_parallel =
|
|
255
|
+
model.n_devices() > 1 &&
|
|
256
|
+
model.params.n_gpu_layers > (int) model.hparams.n_layer &&
|
|
257
|
+
model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
|
|
258
|
+
cparams.offload_kqv;
|
|
259
|
+
|
|
260
|
+
// pipeline parallelism requires support for async compute and events in all devices
|
|
261
|
+
if (pipeline_parallel) {
|
|
262
|
+
for (auto & backend : backends) {
|
|
263
|
+
auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
|
264
|
+
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
|
265
|
+
// ignore CPU backend
|
|
266
|
+
continue;
|
|
267
|
+
}
|
|
268
|
+
auto * dev = ggml_backend_get_device(backend.get());
|
|
269
|
+
ggml_backend_dev_props props;
|
|
270
|
+
ggml_backend_dev_get_props(dev, &props);
|
|
271
|
+
if (!props.caps.async || !props.caps.events) {
|
|
272
|
+
// device does not support async compute or events
|
|
273
|
+
pipeline_parallel = false;
|
|
274
|
+
break;
|
|
275
|
+
}
|
|
278
276
|
}
|
|
279
277
|
}
|
|
280
278
|
|
|
281
|
-
|
|
282
|
-
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
279
|
+
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
|
|
283
280
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
}
|
|
281
|
+
if (pipeline_parallel) {
|
|
282
|
+
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
|
|
287
283
|
}
|
|
288
284
|
}
|
|
289
285
|
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
const
|
|
294
|
-
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
295
|
-
const int64_t n_seqs = ubatch.n_seqs;
|
|
286
|
+
// reserve worst-case graph
|
|
287
|
+
if (!hparams.vocab_only) {
|
|
288
|
+
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
|
289
|
+
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
296
290
|
|
|
297
|
-
|
|
298
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
|
291
|
+
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
|
299
292
|
|
|
300
|
-
|
|
301
|
-
|
|
293
|
+
// restore later
|
|
294
|
+
// TODO: something cleaner
|
|
295
|
+
const auto n_outputs_save = n_outputs;
|
|
302
296
|
|
|
303
|
-
|
|
304
|
-
|
|
297
|
+
// max number of outputs
|
|
298
|
+
n_outputs = n_tokens;
|
|
305
299
|
|
|
306
|
-
|
|
307
|
-
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
|
|
300
|
+
LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
308
301
|
|
|
309
|
-
|
|
310
|
-
|
|
302
|
+
int n_splits_pp = -1;
|
|
303
|
+
int n_nodes_pp = -1;
|
|
311
304
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
}
|
|
315
|
-
}
|
|
316
|
-
}
|
|
317
|
-
}
|
|
305
|
+
int n_splits_tg = -1;
|
|
306
|
+
int n_nodes_tg = -1;
|
|
318
307
|
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
322
|
-
const int64_t n_seqs = ubatch.n_seqs;
|
|
308
|
+
// simulate full KV cache
|
|
309
|
+
kv_self->n = kv_self->size;
|
|
323
310
|
|
|
324
|
-
|
|
325
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
|
311
|
+
cross.v_embd.clear();
|
|
326
312
|
|
|
327
|
-
|
|
328
|
-
|
|
313
|
+
// reserve pp graph first so that buffers are only allocated once
|
|
314
|
+
{
|
|
315
|
+
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
316
|
+
auto * gf = graph_init();
|
|
317
|
+
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
|
318
|
+
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
319
|
+
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
320
|
+
}
|
|
329
321
|
|
|
330
|
-
|
|
331
|
-
|
|
322
|
+
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
|
323
|
+
n_nodes_pp = ggml_graph_n_nodes(gf);
|
|
324
|
+
}
|
|
332
325
|
|
|
333
|
-
|
|
334
|
-
|
|
326
|
+
// reserve with tg graph to get the number of splits and nodes
|
|
327
|
+
{
|
|
328
|
+
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
329
|
+
auto * gf = graph_init();
|
|
330
|
+
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
|
331
|
+
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
332
|
+
throw std::runtime_error("failed to allocate compute tg buffers");
|
|
333
|
+
}
|
|
334
|
+
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
|
335
|
+
n_nodes_tg = ggml_graph_n_nodes(gf);
|
|
336
|
+
}
|
|
335
337
|
|
|
336
|
-
|
|
337
|
-
|
|
338
|
+
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
|
339
|
+
{
|
|
340
|
+
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
341
|
+
auto * gf = graph_init();
|
|
342
|
+
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
|
343
|
+
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
344
|
+
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
345
|
+
}
|
|
346
|
+
}
|
|
338
347
|
|
|
339
|
-
|
|
340
|
-
const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
|
|
348
|
+
n_outputs = n_outputs_save;
|
|
341
349
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
350
|
+
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
351
|
+
ggml_backend_t backend = backend_ptrs[i];
|
|
352
|
+
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
353
|
+
size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
354
|
+
if (size > 1) {
|
|
355
|
+
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
|
356
|
+
ggml_backend_buft_name(buft),
|
|
357
|
+
size / 1024.0 / 1024.0);
|
|
346
358
|
}
|
|
347
359
|
}
|
|
348
360
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
361
|
+
if (n_nodes_pp == n_nodes_tg) {
|
|
362
|
+
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
|
|
363
|
+
} else {
|
|
364
|
+
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
|
|
353
365
|
}
|
|
354
|
-
}
|
|
355
366
|
|
|
356
|
-
|
|
357
|
-
|
|
367
|
+
if (n_splits_pp == n_splits_tg) {
|
|
368
|
+
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
|
|
369
|
+
} else {
|
|
370
|
+
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
|
371
|
+
}
|
|
372
|
+
}
|
|
373
|
+
}
|
|
358
374
|
|
|
359
|
-
|
|
360
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
|
|
361
|
-
float * data = (float *) lctx.inp_s_mask->data;
|
|
375
|
+
llama_context::~llama_context() = default;
|
|
362
376
|
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
const uint32_t cell_id = i + kv_self.head;
|
|
366
|
-
llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
|
|
377
|
+
void llama_context::synchronize() {
|
|
378
|
+
ggml_backend_sched_synchronize(sched.get());
|
|
367
379
|
|
|
368
|
-
|
|
380
|
+
// FIXME: if multiple single tokens are evaluated without a synchronization,
|
|
381
|
+
// the stats will be added to the prompt evaluation stats
|
|
382
|
+
// this should only happen when using batch size 1 to evaluate a batch
|
|
369
383
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
}
|
|
384
|
+
// add the evaluation to the stats
|
|
385
|
+
if (n_queued_tokens == 1) {
|
|
386
|
+
if (!cparams.no_perf) {
|
|
387
|
+
t_eval_us += ggml_time_us() - t_compute_start_us;
|
|
375
388
|
}
|
|
389
|
+
n_eval++;
|
|
390
|
+
} else if (n_queued_tokens > 1) {
|
|
391
|
+
if (!cparams.no_perf) {
|
|
392
|
+
t_p_eval_us += ggml_time_us() - t_compute_start_us;
|
|
393
|
+
}
|
|
394
|
+
n_p_eval += n_queued_tokens;
|
|
395
|
+
}
|
|
376
396
|
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
397
|
+
// get a more accurate load time, upon first eval
|
|
398
|
+
if (n_queued_tokens > 0 && !has_evaluated_once) {
|
|
399
|
+
t_load_us = ggml_time_us() - t_start_us;
|
|
400
|
+
has_evaluated_once = true;
|
|
401
|
+
}
|
|
380
402
|
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
|
|
403
|
+
n_queued_tokens = 0;
|
|
404
|
+
t_compute_start_us = 0;
|
|
405
|
+
}
|
|
385
406
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
}
|
|
407
|
+
const llama_model & llama_context::get_model() const {
|
|
408
|
+
return model;
|
|
409
|
+
}
|
|
390
410
|
|
|
391
|
-
|
|
411
|
+
uint32_t llama_context::n_ctx() const {
|
|
412
|
+
return cparams.n_ctx;
|
|
413
|
+
}
|
|
392
414
|
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
}
|
|
397
|
-
}
|
|
398
|
-
}
|
|
399
|
-
}
|
|
415
|
+
uint32_t llama_context::n_ctx_per_seq() const {
|
|
416
|
+
return cparams.n_ctx / cparams.n_seq_max;
|
|
417
|
+
}
|
|
400
418
|
|
|
401
|
-
|
|
402
|
-
|
|
419
|
+
uint32_t llama_context::n_batch() const {
|
|
420
|
+
return cparams.n_batch;
|
|
421
|
+
}
|
|
403
422
|
|
|
404
|
-
|
|
405
|
-
|
|
423
|
+
uint32_t llama_context::n_ubatch() const {
|
|
424
|
+
return cparams.n_ubatch;
|
|
425
|
+
}
|
|
406
426
|
|
|
407
|
-
|
|
427
|
+
uint32_t llama_context::n_seq_max() const {
|
|
428
|
+
return cparams.n_seq_max;
|
|
429
|
+
}
|
|
408
430
|
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
431
|
+
uint32_t llama_context::n_threads() const {
|
|
432
|
+
return cparams.n_threads;
|
|
433
|
+
}
|
|
434
|
+
|
|
435
|
+
uint32_t llama_context::n_threads_batch() const {
|
|
436
|
+
return cparams.n_threads_batch;
|
|
437
|
+
}
|
|
438
|
+
|
|
439
|
+
llama_kv_cache * llama_context::get_kv_self() {
|
|
440
|
+
return kv_self.get();
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
const llama_kv_cache * llama_context::get_kv_self() const {
|
|
444
|
+
return kv_self.get();
|
|
445
|
+
}
|
|
446
|
+
|
|
447
|
+
ggml_tensor * llama_context::build_rope_shift(
|
|
448
|
+
ggml_context * ctx0,
|
|
449
|
+
ggml_tensor * cur,
|
|
450
|
+
ggml_tensor * shift,
|
|
451
|
+
ggml_tensor * factors,
|
|
452
|
+
float freq_base,
|
|
453
|
+
float freq_scale,
|
|
454
|
+
ggml_backend_buffer * bbuf) const {
|
|
455
|
+
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
|
456
|
+
|
|
457
|
+
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
|
458
|
+
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
|
459
|
+
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
|
460
|
+
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
|
461
|
+
|
|
462
|
+
const auto & hparams = model.hparams;
|
|
463
|
+
|
|
464
|
+
const auto & n_rot = hparams.n_rot;
|
|
465
|
+
const auto & rope_type = hparams.rope_type;
|
|
466
|
+
|
|
467
|
+
ggml_tensor * tmp;
|
|
468
|
+
|
|
469
|
+
if (ggml_is_quantized(cur->type)) {
|
|
470
|
+
// dequantize to f32 -> RoPE -> quantize back
|
|
471
|
+
tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
|
|
472
|
+
|
|
473
|
+
if (bbuf) {
|
|
474
|
+
for (const auto & backend : backends) {
|
|
475
|
+
// Figure out which backend KV cache belongs to
|
|
476
|
+
if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
|
|
477
|
+
ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
|
|
478
|
+
break;
|
|
424
479
|
}
|
|
425
480
|
}
|
|
426
481
|
}
|
|
427
|
-
}
|
|
428
482
|
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
483
|
+
tmp = ggml_rope_ext_inplace(ctx0, tmp,
|
|
484
|
+
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
485
|
+
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
|
432
486
|
|
|
433
|
-
|
|
487
|
+
tmp = ggml_cpy(ctx0, tmp, cur);
|
|
488
|
+
} else {
|
|
489
|
+
// we rotate only the first n_rot dimensions
|
|
490
|
+
tmp = ggml_rope_ext_inplace(ctx0, cur,
|
|
491
|
+
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
492
|
+
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
|
434
493
|
}
|
|
435
494
|
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
const int64_t n_tokens = ubatch.n_tokens;
|
|
495
|
+
return tmp;
|
|
496
|
+
}
|
|
439
497
|
|
|
440
|
-
|
|
441
|
-
|
|
498
|
+
class llm_graph_input_k_shift : public llm_graph_input_i {
|
|
499
|
+
public:
|
|
500
|
+
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
|
501
|
+
virtual ~llm_graph_input_k_shift() = default;
|
|
442
502
|
|
|
443
|
-
|
|
503
|
+
void set_input(const llama_ubatch * ubatch) override;
|
|
444
504
|
|
|
445
|
-
|
|
446
|
-
for (int j = 0; j < n_tokens; ++j) {
|
|
447
|
-
for (int i = 0; i < n_output_enc; ++i) {
|
|
448
|
-
float f = -INFINITY;
|
|
449
|
-
for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
|
|
450
|
-
const llama_seq_id seq_id = ubatch.seq_id[j][s];
|
|
451
|
-
if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
|
|
452
|
-
f = 0.0f;
|
|
453
|
-
}
|
|
454
|
-
}
|
|
455
|
-
data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
|
|
456
|
-
}
|
|
457
|
-
}
|
|
505
|
+
ggml_tensor * k_shift; // I32 [kv_size]
|
|
458
506
|
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
507
|
+
const llama_kv_cache_unified * kv_self;
|
|
508
|
+
};
|
|
509
|
+
|
|
510
|
+
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
|
511
|
+
GGML_UNUSED(ubatch);
|
|
512
|
+
|
|
513
|
+
if (k_shift) {
|
|
514
|
+
assert(ggml_backend_buffer_is_host(k_shift->buffer));
|
|
515
|
+
|
|
516
|
+
int32_t * data = (int32_t *) k_shift->data;
|
|
517
|
+
|
|
518
|
+
for (uint32_t i = 0; i < kv_self->size; ++i) {
|
|
519
|
+
data[i] = kv_self->cells[i].delta;
|
|
464
520
|
}
|
|
465
521
|
}
|
|
466
522
|
}
|
|
467
523
|
|
|
468
|
-
|
|
524
|
+
llm_graph_result_ptr llama_context::build_kv_self_shift(
|
|
525
|
+
ggml_context * ctx0,
|
|
526
|
+
ggml_cgraph * gf) const {
|
|
527
|
+
auto res = std::make_unique<llm_graph_result>();
|
|
469
528
|
|
|
470
|
-
|
|
471
|
-
const auto & cparams = lctx.cparams;
|
|
472
|
-
const auto & hparams = lctx.model.hparams;
|
|
473
|
-
const auto & vocab = lctx.model.vocab;
|
|
529
|
+
const auto & hparams = model.hparams;
|
|
474
530
|
|
|
475
|
-
const
|
|
531
|
+
const auto & n_layer = hparams.n_layer;
|
|
476
532
|
|
|
477
|
-
const auto
|
|
478
|
-
|
|
479
|
-
const auto n_embd = hparams.n_embd;
|
|
533
|
+
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
|
534
|
+
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
|
480
535
|
|
|
481
|
-
//
|
|
482
|
-
const bool has_logits = !cparams.embeddings;
|
|
483
|
-
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
|
536
|
+
//GGML_ASSERT(kv_self->size == n_ctx);
|
|
484
537
|
|
|
485
|
-
|
|
486
|
-
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
|
538
|
+
auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
|
|
487
539
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
lctx.output_ids.resize(n_batch);
|
|
491
|
-
}
|
|
540
|
+
inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
|
|
541
|
+
ggml_set_input(inp->k_shift);
|
|
492
542
|
|
|
493
|
-
|
|
494
|
-
|
|
543
|
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
544
|
+
const int64_t n_head_kv = hparams.n_head_kv(il);
|
|
545
|
+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
495
546
|
|
|
496
|
-
|
|
497
|
-
// TODO: also consider shrinking the buffer
|
|
498
|
-
if (!lctx.buf_output || prev_size < new_size) {
|
|
499
|
-
if (lctx.buf_output) {
|
|
500
|
-
#ifndef NDEBUG
|
|
501
|
-
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
|
502
|
-
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
|
503
|
-
#endif
|
|
504
|
-
lctx.buf_output = nullptr;
|
|
505
|
-
lctx.logits = nullptr;
|
|
506
|
-
lctx.embd = nullptr;
|
|
507
|
-
}
|
|
547
|
+
const bool is_swa = hparams.is_swa(il);
|
|
508
548
|
|
|
509
|
-
|
|
510
|
-
//
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
549
|
+
// note: the swa rope params could become part of the cparams in the future
|
|
550
|
+
// if we decide to make them configurable, like the non-sliding ones
|
|
551
|
+
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
|
|
552
|
+
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
|
|
553
|
+
|
|
554
|
+
ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
|
|
555
|
+
|
|
556
|
+
ggml_tensor * k =
|
|
557
|
+
ggml_view_3d(ctx0, kv_self->k_l[il],
|
|
558
|
+
n_embd_head_k, n_head_kv, kv_self->size,
|
|
559
|
+
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
|
560
|
+
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
|
561
|
+
0);
|
|
562
|
+
|
|
563
|
+
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
|
|
564
|
+
|
|
565
|
+
ggml_build_forward_expand(gf, cur);
|
|
521
566
|
}
|
|
522
567
|
|
|
523
|
-
|
|
568
|
+
res->add_input(std::move(inp));
|
|
524
569
|
|
|
525
|
-
|
|
526
|
-
|
|
570
|
+
return res;
|
|
571
|
+
}
|
|
527
572
|
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
573
|
+
llm_graph_result_ptr llama_context::build_kv_self_defrag(
|
|
574
|
+
ggml_context * ctx0,
|
|
575
|
+
ggml_cgraph * gf) const {
|
|
576
|
+
auto res = std::make_unique<llm_graph_result>();
|
|
531
577
|
|
|
532
|
-
|
|
533
|
-
std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
|
|
578
|
+
const auto & hparams = model.hparams;
|
|
534
579
|
|
|
535
|
-
|
|
580
|
+
const auto & ids = kv_self->defrag_info.ids;
|
|
536
581
|
|
|
537
|
-
|
|
582
|
+
#if 0
|
|
583
|
+
// CPU defrag
|
|
584
|
+
//
|
|
585
|
+
// TODO: optimizations are possible:
|
|
586
|
+
// - multiple threads
|
|
587
|
+
// - avoid copying to the host memory when already there
|
|
588
|
+
//
|
|
589
|
+
// likely not worth the effort, as we have ggml_graph based defrag
|
|
590
|
+
//
|
|
538
591
|
|
|
539
|
-
|
|
540
|
-
|
|
592
|
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
|
593
|
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
|
541
594
|
|
|
542
|
-
|
|
543
|
-
std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
|
|
544
|
-
if (!out_ids.empty()) {
|
|
545
|
-
const uint32_t n_vocab = ctx.model.vocab.n_tokens();
|
|
546
|
-
const uint32_t n_embd = ctx.model.hparams.n_embd;
|
|
595
|
+
const uint32_t kv_size = size;
|
|
547
596
|
|
|
548
|
-
|
|
549
|
-
|
|
597
|
+
std::vector<uint8_t> buf_k;
|
|
598
|
+
std::vector<uint8_t> buf_v;
|
|
550
599
|
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
600
|
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
601
|
+
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
|
602
|
+
const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
|
|
603
|
+
|
|
604
|
+
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
|
605
|
+
const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
|
|
606
|
+
|
|
607
|
+
buf_k.resize(k_size);
|
|
608
|
+
buf_v.resize(v_size);
|
|
609
|
+
|
|
610
|
+
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
|
|
611
|
+
ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
|
|
612
|
+
|
|
613
|
+
// batch move [i, i+nm) to [id, id+nm)
|
|
614
|
+
// note: cells can move only to a lower index
|
|
615
|
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
|
616
|
+
const uint32_t id = ids[i];
|
|
617
|
+
|
|
618
|
+
if (i == id || id == n_kv) {
|
|
619
|
+
continue;
|
|
559
620
|
}
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
}
|
|
621
|
+
|
|
622
|
+
uint32_t nm = 1;
|
|
623
|
+
|
|
624
|
+
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
|
625
|
+
nm++;
|
|
566
626
|
}
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
627
|
+
|
|
628
|
+
// move keys
|
|
629
|
+
{
|
|
630
|
+
const int64_t os = i*k_size_row;
|
|
631
|
+
const int64_t od = id*k_size_row;
|
|
632
|
+
|
|
633
|
+
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
|
634
|
+
}
|
|
635
|
+
|
|
636
|
+
// move values (note: they are transposed)
|
|
637
|
+
{
|
|
638
|
+
const int64_t os = i;
|
|
639
|
+
const int64_t od = id;
|
|
640
|
+
|
|
641
|
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
642
|
+
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
|
570
643
|
}
|
|
571
644
|
}
|
|
645
|
+
|
|
646
|
+
i += nm - 1;
|
|
572
647
|
}
|
|
573
|
-
std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
|
|
574
|
-
for (int32_t i = 0; i < n_outputs; ++i) {
|
|
575
|
-
ctx.output_ids[out_ids[i]] = i;
|
|
576
|
-
}
|
|
577
|
-
out_ids.clear();
|
|
578
|
-
}
|
|
579
|
-
}
|
|
580
648
|
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
649
|
+
ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
|
|
650
|
+
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
|
651
|
+
}
|
|
652
|
+
#else
|
|
653
|
+
for (uint32_t i = 0; i < ids.size(); ++i) {
|
|
654
|
+
const uint32_t id = ids[i];
|
|
584
655
|
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
}
|
|
656
|
+
if (i == id || id == ids.size()) {
|
|
657
|
+
continue;
|
|
658
|
+
}
|
|
588
659
|
|
|
589
|
-
uint32_t
|
|
590
|
-
return ctx->cparams.n_ctx;
|
|
591
|
-
}
|
|
660
|
+
uint32_t nm = 1;
|
|
592
661
|
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
}
|
|
662
|
+
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
|
663
|
+
nm++;
|
|
664
|
+
}
|
|
596
665
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
666
|
+
for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
|
|
667
|
+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
668
|
+
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
669
|
+
|
|
670
|
+
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
|
|
671
|
+
n_embd_k_gqa, nm,
|
|
672
|
+
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
|
673
|
+
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
|
|
674
|
+
|
|
675
|
+
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
|
|
676
|
+
n_embd_k_gqa, nm,
|
|
677
|
+
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
|
678
|
+
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
|
|
679
|
+
|
|
680
|
+
ggml_tensor * view_v_src;
|
|
681
|
+
ggml_tensor * view_v_dst;
|
|
682
|
+
|
|
683
|
+
if (cparams.flash_attn) {
|
|
684
|
+
// NOTE: the V cache is not transposed when using flash attention
|
|
685
|
+
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
|
|
686
|
+
n_embd_v_gqa, nm,
|
|
687
|
+
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
|
688
|
+
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
|
|
689
|
+
|
|
690
|
+
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
|
|
691
|
+
n_embd_v_gqa, nm,
|
|
692
|
+
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
|
693
|
+
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
|
|
694
|
+
} else {
|
|
695
|
+
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
|
|
696
|
+
nm, n_embd_v_gqa,
|
|
697
|
+
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
|
698
|
+
ggml_row_size(kv_self->v_l[il]->type, i));
|
|
699
|
+
|
|
700
|
+
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
|
|
701
|
+
nm, n_embd_v_gqa,
|
|
702
|
+
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
|
703
|
+
ggml_row_size(kv_self->v_l[il]->type, id));
|
|
704
|
+
}
|
|
600
705
|
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
}
|
|
706
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
|
707
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
|
708
|
+
}
|
|
604
709
|
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
}
|
|
710
|
+
i += nm - 1;
|
|
711
|
+
}
|
|
608
712
|
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
}
|
|
713
|
+
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
|
714
|
+
#endif
|
|
612
715
|
|
|
613
|
-
|
|
614
|
-
struct llama_context * ctx,
|
|
615
|
-
ggml_threadpool_t threadpool,
|
|
616
|
-
ggml_threadpool_t threadpool_batch) {
|
|
617
|
-
ctx->threadpool = threadpool;
|
|
618
|
-
ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
|
|
716
|
+
return res;
|
|
619
717
|
}
|
|
620
718
|
|
|
621
|
-
void
|
|
622
|
-
|
|
623
|
-
ctx->threadpool_batch = nullptr;
|
|
624
|
-
}
|
|
719
|
+
void llama_context::kv_self_update() {
|
|
720
|
+
auto & kv = kv_self;
|
|
625
721
|
|
|
626
|
-
|
|
627
|
-
ctx->cparams.n_threads = n_threads;
|
|
628
|
-
ctx->cparams.n_threads_batch = n_threads_batch;
|
|
629
|
-
}
|
|
722
|
+
bool need_reserve = false;
|
|
630
723
|
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
724
|
+
if (kv->has_shift) {
|
|
725
|
+
if (!kv->get_can_shift()) {
|
|
726
|
+
GGML_ABORT("The current context does not support K-shift");
|
|
727
|
+
}
|
|
634
728
|
|
|
635
|
-
|
|
636
|
-
return ctx->cparams.n_threads_batch;
|
|
637
|
-
}
|
|
729
|
+
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
|
|
638
730
|
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
731
|
+
// apply K-shift if needed
|
|
732
|
+
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
|
733
|
+
ggml_backend_sched_reset(sched.get());
|
|
642
734
|
|
|
643
|
-
|
|
644
|
-
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
|
|
645
|
-
auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
|
|
646
|
-
if (set_abort_callback_fn) {
|
|
647
|
-
set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
|
|
648
|
-
}
|
|
649
|
-
}
|
|
650
|
-
}
|
|
735
|
+
auto * gf = graph_init();
|
|
651
736
|
|
|
652
|
-
|
|
653
|
-
ctx->cparams.embeddings = embeddings;
|
|
654
|
-
}
|
|
737
|
+
auto res = build_kv_self_shift(ctx_compute.get(), gf);
|
|
655
738
|
|
|
656
|
-
|
|
657
|
-
ctx->cparams.causal_attn = causal_attn;
|
|
658
|
-
}
|
|
739
|
+
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
659
740
|
|
|
660
|
-
|
|
661
|
-
ggml_backend_sched_synchronize(ctx->sched.get());
|
|
741
|
+
res->set_inputs(nullptr);
|
|
662
742
|
|
|
663
|
-
|
|
664
|
-
// the stats will be added to the prompt evaluation stats
|
|
665
|
-
// this should only happen when using batch size 1 to evaluate a batch
|
|
743
|
+
graph_compute(gf, false);
|
|
666
744
|
|
|
667
|
-
|
|
668
|
-
if (ctx->n_queued_tokens == 1) {
|
|
669
|
-
if (!ctx->cparams.no_perf) {
|
|
670
|
-
ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
|
|
745
|
+
need_reserve = true;
|
|
671
746
|
}
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
747
|
+
|
|
748
|
+
{
|
|
749
|
+
kv->has_shift = false;
|
|
750
|
+
|
|
751
|
+
for (uint32_t i = 0; i < kv->size; ++i) {
|
|
752
|
+
kv->cells[i].delta = 0;
|
|
753
|
+
}
|
|
676
754
|
}
|
|
677
|
-
ctx->n_p_eval += ctx->n_queued_tokens;
|
|
678
755
|
}
|
|
679
756
|
|
|
680
|
-
//
|
|
681
|
-
if (
|
|
682
|
-
|
|
683
|
-
ctx->has_evaluated_once = true;
|
|
684
|
-
}
|
|
757
|
+
// defragment the KV cache if needed
|
|
758
|
+
if (kv->do_defrag) {
|
|
759
|
+
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
|
685
760
|
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
}
|
|
761
|
+
if (kv->defrag_prepare(graph_max_nodes())) {
|
|
762
|
+
ggml_backend_sched_reset(sched.get());
|
|
689
763
|
|
|
690
|
-
|
|
691
|
-
llama_synchronize(ctx);
|
|
764
|
+
auto * gf = graph_init();
|
|
692
765
|
|
|
693
|
-
|
|
694
|
-
// TODO: maybe deprecate this
|
|
695
|
-
llama_output_reorder(*ctx);
|
|
766
|
+
auto res = build_kv_self_defrag(ctx_compute.get(), gf);
|
|
696
767
|
|
|
697
|
-
|
|
698
|
-
}
|
|
768
|
+
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
699
769
|
|
|
700
|
-
|
|
701
|
-
int32_t j = -1;
|
|
770
|
+
res->set_inputs(nullptr);
|
|
702
771
|
|
|
703
|
-
|
|
772
|
+
graph_compute(gf, false);
|
|
704
773
|
|
|
705
|
-
|
|
706
|
-
if (ctx->logits == nullptr) {
|
|
707
|
-
throw std::runtime_error("no logits");
|
|
774
|
+
need_reserve = true;
|
|
708
775
|
}
|
|
709
776
|
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
777
|
+
kv->do_defrag = false;
|
|
778
|
+
}
|
|
779
|
+
|
|
780
|
+
// reserve a worst case graph if needed
|
|
781
|
+
if (need_reserve) {
|
|
782
|
+
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
|
|
783
|
+
|
|
784
|
+
// build worst-case graph
|
|
785
|
+
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
|
786
|
+
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
787
|
+
|
|
788
|
+
// simulate full KV cache
|
|
789
|
+
kv_self->n = kv_self->size;
|
|
790
|
+
|
|
791
|
+
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
|
792
|
+
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
793
|
+
|
|
794
|
+
auto * gf = graph_init();
|
|
795
|
+
graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
|
796
|
+
|
|
797
|
+
// initialize scheduler with the worst-case graph
|
|
798
|
+
ggml_backend_sched_reset(sched.get());
|
|
799
|
+
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
800
|
+
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
801
|
+
}
|
|
802
|
+
}
|
|
803
|
+
}
|
|
804
|
+
|
|
805
|
+
enum llama_pooling_type llama_context::pooling_type() const {
|
|
806
|
+
return cparams.pooling_type;
|
|
807
|
+
}
|
|
808
|
+
|
|
809
|
+
float * llama_context::get_logits() {
|
|
810
|
+
// reorder logits for backward compatibility
|
|
811
|
+
output_reorder();
|
|
812
|
+
|
|
813
|
+
return logits;
|
|
814
|
+
}
|
|
815
|
+
|
|
816
|
+
float * llama_context::get_logits_ith(int32_t i) {
|
|
817
|
+
int32_t j = -1;
|
|
818
|
+
|
|
819
|
+
try {
|
|
820
|
+
if (logits == nullptr) {
|
|
821
|
+
throw std::runtime_error("no logits");
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
if (i < 0) {
|
|
825
|
+
j = n_outputs + i;
|
|
826
|
+
if (j < 0) {
|
|
827
|
+
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
|
714
828
|
}
|
|
715
|
-
} else if ((size_t) i >=
|
|
716
|
-
throw std::runtime_error(format("out of range [0, %zu)",
|
|
829
|
+
} else if ((size_t) i >= output_ids.size()) {
|
|
830
|
+
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
|
717
831
|
} else {
|
|
718
|
-
j =
|
|
832
|
+
j = output_ids[i];
|
|
719
833
|
}
|
|
720
834
|
|
|
721
835
|
if (j < 0) {
|
|
722
836
|
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
723
837
|
}
|
|
724
|
-
if (j >=
|
|
838
|
+
if (j >= n_outputs) {
|
|
725
839
|
// This should not happen
|
|
726
|
-
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j,
|
|
840
|
+
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
|
|
727
841
|
}
|
|
728
842
|
|
|
729
|
-
return
|
|
843
|
+
return logits + j*model.vocab.n_tokens();
|
|
730
844
|
} catch (const std::exception & err) {
|
|
731
845
|
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
|
732
846
|
#ifndef NDEBUG
|
|
@@ -737,46 +851,41 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 }
 }

-float *
-llama_synchronize(ctx);
-
+float * llama_context::get_embeddings() {
 // reorder embeddings for backward compatibility
-
-llama_output_reorder(*ctx);
+output_reorder();

-return
+return embd;
 }

-float *
+float * llama_context::get_embeddings_ith(int32_t i) {
 int32_t j = -1;

-llama_synchronize(ctx);
-
 try {
-if (
+if (embd == nullptr) {
 throw std::runtime_error("no embeddings");
 }

 if (i < 0) {
-j =
+j = n_outputs + i;
 if (j < 0) {
-throw std::runtime_error(format("negative index out of range [0, %d)",
+throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
 }
-} else if ((size_t) i >=
-throw std::runtime_error(format("out of range [0, %zu)",
+} else if ((size_t) i >= output_ids.size()) {
+throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
 } else {
-j =
+j = output_ids[i];
 }

 if (j < 0) {
 throw std::runtime_error(format("batch.logits[%d] != true", i));
 }
-if (j >=
+if (j >= n_outputs) {
 // This should not happen
-throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j,
+throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
 }

-return
+return embd + j*model.hparams.n_embd;
 } catch (const std::exception & err) {
 LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -787,696 +896,943 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
 }
 }

-float *
-
-
-auto it = ctx->embd_seq.find(seq_id);
-if (it == ctx->embd_seq.end()) {
+float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
+auto it = embd_seq.find(seq_id);
+if (it == embd_seq.end()) {
 return nullptr;
 }

 return it->second.data();
 }

-
+void llama_context::attach_threadpool(
+ggml_threadpool_t threadpool,
+ggml_threadpool_t threadpool_batch) {
+LLAMA_LOG_DEBUG("%s: call\n", __func__);

-
-
-return llama_state_get_size(ctx);
+this->threadpool = threadpool;
+this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }

-
-
-return llama_state_get_data(ctx, dst, -1);
-}
+void llama_context::detach_threadpool() {
+LLAMA_LOG_DEBUG("%s: call\n", __func__);

-
-
-return llama_state_set_data(ctx, src, -1);
+this->threadpool = nullptr;
+this->threadpool_batch = nullptr;
 }

-
-
-return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-}
+void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) {
+LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch);

-
-
-return llama_state_save_file(ctx, path_session, tokens, n_token_count);
+cparams.n_threads = n_threads;
+cparams.n_threads_batch = n_threads_batch;
 }

-
-
-virtual void write(const void * src, size_t size) = 0;
-virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
-virtual size_t get_size_written() = 0;
-virtual ~llama_data_write() = default;
+void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) {
+LLAMA_LOG_DEBUG("%s: call\n", __func__);

-
-
+this->abort_callback = abort_callback;
+this->abort_callback_data = abort_callback_data;

-
-
+for (auto & backend : backends) {
+auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
+auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
+if (set_abort_callback_fn) {
+set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
+}
 }
+}

-
-
-write_string(arch_str);
-// TODO: add more model-specific info which should prevent loading the session file if not identical
-}
+void llama_context::set_embeddings(bool value) {
+LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

-
-
-// rng_ss << rng;
+cparams.embeddings = value;
+}

-
+void llama_context::set_causal_attn(bool value) {
+LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

-
-
+cparams.causal_attn = value;
+}

-
-
+void llama_context::set_warmup(bool value) {
+LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

-
+cparams.warmup = value;
+}

-
+void llama_context::set_adapter_lora(
+llama_adapter_lora * adapter,
+float scale) {
+LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);

-
-
+loras[adapter] = scale;
+}

-
+bool llama_context::rm_adapter_lora(
+llama_adapter_lora * adapter) {
+LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);

-
+auto pos = loras.find(adapter);
+if (pos != loras.end()) {
+loras.erase(pos);
+return true;
+}

-
-
-// map an output id to a position in the batch
-int32_t pos = output_ids[i];
-if (pos >= 0) {
-GGML_ASSERT((uint32_t) pos < n_outputs);
-output_pos[pos] = i;
-}
-}
+return false;
+}

-
+void llama_context::clear_adapter_lora() {
+LLAMA_LOG_DEBUG("%s: call\n", __func__);

-
-
-}
-}
+loras.clear();
+}

-
-
+bool llama_context::apply_adapter_cvec(
+const float * data,
+size_t len,
+int32_t n_embd,
+int32_t il_start,
+int32_t il_end) {
+LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);

-
+return cvec.apply(model, data, len, n_embd, il_start, il_end);
+}

893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
1002
|
+
int llama_context::encode(llama_batch & inp_batch) {
|
|
1003
|
+
if (inp_batch.n_tokens == 0) {
|
|
1004
|
+
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
1005
|
+
return -1;
|
|
896
1006
|
}
|
|
897
1007
|
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
write(&embeddings_size, sizeof(embeddings_size));
|
|
1008
|
+
// temporary allocate memory for the input batch if needed
|
|
1009
|
+
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
|
1010
|
+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
|
902
1011
|
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
}
|
|
906
|
-
}
|
|
1012
|
+
const llama_batch & batch = batch_allocr.batch;
|
|
1013
|
+
const int32_t n_tokens = batch.n_tokens;
|
|
907
1014
|
|
|
908
|
-
|
|
909
|
-
for (const auto & range : cell_ranges) {
|
|
910
|
-
for (uint32_t i = range.first; i < range.second; ++i) {
|
|
911
|
-
const auto & cell = kv_self.cells[i];
|
|
912
|
-
const llama_pos pos = cell.pos;
|
|
913
|
-
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
|
|
1015
|
+
const auto & hparams = model.hparams;
|
|
914
1016
|
|
|
915
|
-
|
|
916
|
-
write(&n_seq_id, sizeof(n_seq_id));
|
|
1017
|
+
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
917
1018
|
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
1019
|
+
if (batch.token) {
|
|
1020
|
+
for (int32_t i = 0; i < n_tokens; ++i) {
|
|
1021
|
+
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
|
1022
|
+
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
|
1023
|
+
return -1;
|
|
923
1024
|
}
|
|
924
1025
|
}
|
|
925
1026
|
}
|
|
926
1027
|
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
const struct llama_hparams & hparams = ctx->model.hparams;
|
|
1028
|
+
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
|
1029
|
+
GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
|
|
930
1030
|
|
|
931
|
-
|
|
932
|
-
|
|
1031
|
+
if (t_compute_start_us == 0) {
|
|
1032
|
+
t_compute_start_us = ggml_time_us();
|
|
1033
|
+
}
|
|
933
1034
|
|
|
934
|
-
|
|
935
|
-
write(&n_layer, sizeof(n_layer));
|
|
1035
|
+
n_queued_tokens += n_tokens;
|
|
936
1036
|
|
|
937
|
-
|
|
1037
|
+
const int64_t n_embd = hparams.n_embd;
|
|
938
1038
|
|
|
939
|
-
|
|
940
|
-
// Get whole range at a time
|
|
941
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
942
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
|
1039
|
+
sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
|
943
1040
|
|
|
944
|
-
|
|
945
|
-
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
|
946
|
-
write(&k_type_i, sizeof(k_type_i));
|
|
1041
|
+
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
|
947
1042
|
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
1043
|
+
// reserve output buffer
|
|
1044
|
+
if (output_reserve(n_tokens) < n_tokens) {
|
|
1045
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
|
1046
|
+
return -2;
|
|
1047
|
+
};
|
|
951
1048
|
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
const size_t buf_size = range_size * k_size_row;
|
|
956
|
-
write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
|
|
957
|
-
}
|
|
958
|
-
}
|
|
1049
|
+
for (int32_t i = 0; i < n_tokens; ++i) {
|
|
1050
|
+
output_ids[i] = i;
|
|
1051
|
+
}
|
|
959
1052
|
|
|
960
|
-
|
|
961
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
962
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
1053
|
+
n_outputs = n_tokens;
|
|
963
1054
|
|
|
964
|
-
|
|
965
|
-
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
|
966
|
-
write(&v_type_i, sizeof(v_type_i));
|
|
1055
|
+
//batch_manager->prepare(ubatch);
|
|
967
1056
|
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
write(&v_size_row, sizeof(v_size_row));
|
|
1057
|
+
ggml_backend_sched_reset(sched.get());
|
|
1058
|
+
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
|
971
1059
|
|
|
972
|
-
|
|
973
|
-
for (const auto & range : cell_ranges) {
|
|
974
|
-
const size_t range_size = range.second - range.first;
|
|
975
|
-
const size_t buf_size = range_size * v_size_row;
|
|
976
|
-
write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
|
|
977
|
-
}
|
|
978
|
-
}
|
|
979
|
-
} else {
|
|
980
|
-
// When v is transposed, we also need the element size and get the element ranges from each row
|
|
981
|
-
const uint32_t kv_size = kv_self.size;
|
|
982
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
983
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
1060
|
+
const auto causal_attn_org = cparams.causal_attn;
|
|
984
1061
|
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
1062
|
+
// always use non-causal attention for encoder graphs
|
|
1063
|
+
// TODO: this is a tmp solution until we have a proper way to support enc-dec models
|
|
1064
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
|
1065
|
+
cparams.causal_attn = false;
|
|
988
1066
|
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
write(&v_size_el, sizeof(v_size_el));
|
|
1067
|
+
auto * gf = graph_init();
|
|
1068
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
|
|
992
1069
|
|
|
993
|
-
|
|
994
|
-
write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
|
1070
|
+
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
995
1071
|
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1072
|
+
res->set_inputs(&ubatch);
|
|
1073
|
+
|
|
1074
|
+
cparams.causal_attn = causal_attn_org;
|
|
1075
|
+
|
|
1076
|
+
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
|
1077
|
+
switch (compute_status) {
|
|
1078
|
+
case GGML_STATUS_SUCCESS:
|
|
1079
|
+
break;
|
|
1080
|
+
case GGML_STATUS_ABORTED:
|
|
1081
|
+
return 2;
|
|
1082
|
+
case GGML_STATUS_ALLOC_FAILED:
|
|
1083
|
+
return -2;
|
|
1084
|
+
case GGML_STATUS_FAILED:
|
|
1085
|
+
default:
|
|
1086
|
+
return -3;
|
|
1008
1087
|
}
|
|
1009
1088
|
|
|
1010
|
-
|
|
1011
|
-
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
|
1012
|
-
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
|
1013
|
-
uint32_t cell_count = 0;
|
|
1089
|
+
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
|
1014
1090
|
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1091
|
+
// extract embeddings
|
|
1092
|
+
if (t_embd) {
|
|
1093
|
+
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
1094
|
+
GGML_ASSERT(backend_embd != nullptr);
|
|
1095
|
+
|
|
1096
|
+
GGML_ASSERT(embd != nullptr);
|
|
1097
|
+
|
|
1098
|
+
switch (cparams.pooling_type) {
|
|
1099
|
+
case LLAMA_POOLING_TYPE_NONE:
|
|
1100
|
+
{
|
|
1101
|
+
// extract token embeddings
|
|
1102
|
+
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
|
|
1103
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
|
|
1104
|
+
} break;
|
|
1105
|
+
case LLAMA_POOLING_TYPE_MEAN:
|
|
1106
|
+
case LLAMA_POOLING_TYPE_CLS:
|
|
1107
|
+
case LLAMA_POOLING_TYPE_LAST:
|
|
1108
|
+
{
|
|
1109
|
+
// extract sequence embeddings
|
|
1110
|
+
auto & embd_seq_out = embd_seq;
|
|
1111
|
+
embd_seq_out.clear();
|
|
1112
|
+
|
|
1113
|
+
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
|
1114
|
+
|
|
1115
|
+
for (int32_t i = 0; i < n_tokens; i++) {
|
|
1116
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
1117
|
+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
1118
|
+
continue;
|
|
1119
|
+
}
|
|
1120
|
+
embd_seq_out[seq_id].resize(n_embd);
|
|
1121
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
|
1122
|
+
}
|
|
1123
|
+
} break;
|
|
1124
|
+
case LLAMA_POOLING_TYPE_RANK:
|
|
1125
|
+
{
|
|
1126
|
+
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
|
|
1127
|
+
// wait for an encoder model that requires this pooling type in order to test it
|
|
1128
|
+
// https://github.com/ggerganov/llama.cpp/pull/9510
|
|
1129
|
+
GGML_ABORT("RANK pooling not implemented yet");
|
|
1024
1130
|
}
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
cell_range_begin = kv_self.size;
|
|
1131
|
+
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
1132
|
+
{
|
|
1133
|
+
GGML_ABORT("unknown pooling type");
|
|
1029
1134
|
}
|
|
1030
|
-
}
|
|
1031
1135
|
}
|
|
1032
|
-
|
|
1033
|
-
cell_ranges.emplace_back(cell_range_begin, kv_self.size);
|
|
1034
|
-
}
|
|
1035
|
-
|
|
1036
|
-
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
|
1037
|
-
uint32_t cell_count_check = 0;
|
|
1038
|
-
for (const auto & range : cell_ranges) {
|
|
1039
|
-
cell_count_check += range.second - range.first;
|
|
1040
|
-
}
|
|
1041
|
-
GGML_ASSERT(cell_count == cell_count_check);
|
|
1136
|
+
}
|
|
1042
1137
|
|
|
1043
|
-
|
|
1138
|
+
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
|
1139
|
+
// overlap with device computation.
|
|
1140
|
+
ggml_backend_sched_reset(sched.get());
|
|
1044
1141
|
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
};
|
|
1142
|
+
// TODO: hacky solution
|
|
1143
|
+
if (model.arch == LLM_ARCH_T5 && t_embd) {
|
|
1144
|
+
//cross.t_embd = t_embd;
|
|
1049
1145
|
|
|
1050
|
-
|
|
1051
|
-
virtual const uint8_t * read(size_t size) = 0;
|
|
1052
|
-
virtual void read_to(void * dst, size_t size) = 0;
|
|
1053
|
-
virtual size_t get_size_read() = 0;
|
|
1054
|
-
virtual ~llama_data_read() = default;
|
|
1146
|
+
synchronize();
|
|
1055
1147
|
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1148
|
+
cross.n_embd = t_embd->ne[0];
|
|
1149
|
+
cross.n_enc = t_embd->ne[1];
|
|
1150
|
+
cross.v_embd.resize(cross.n_embd*cross.n_enc);
|
|
1151
|
+
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
|
|
1059
1152
|
|
|
1060
|
-
|
|
1153
|
+
// remember the sequence ids used during the encoding - needed for cross attention later
|
|
1154
|
+
cross.seq_ids_enc.resize(n_tokens);
|
|
1155
|
+
for (int32_t i = 0; i < n_tokens; i++) {
|
|
1156
|
+
cross.seq_ids_enc[i].clear();
|
|
1157
|
+
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
|
|
1158
|
+
llama_seq_id seq_id = ubatch.seq_id[i][s];
|
|
1159
|
+
cross.seq_ids_enc[i].insert(seq_id);
|
|
1160
|
+
}
|
|
1161
|
+
}
|
|
1061
1162
|
}
|
|
1062
1163
|
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
const std::string cur_arch_str = llm_arch_name(ctx->model.arch);
|
|
1164
|
+
return 0;
|
|
1165
|
+
}
|
|
1066
1166
|
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
}
|
|
1072
|
-
// TODO: add more info which needs to be identical but which is not verified otherwise
|
|
1167
|
+
int llama_context::decode(llama_batch & inp_batch) {
|
|
1168
|
+
if (inp_batch.n_tokens == 0) {
|
|
1169
|
+
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
1170
|
+
return -1;
|
|
1073
1171
|
}
|
|
1074
1172
|
|
|
1075
|
-
//
|
|
1076
|
-
//
|
|
1077
|
-
|
|
1173
|
+
// temporary allocate memory for the input batch if needed
|
|
1174
|
+
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
|
1175
|
+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
|
1078
1176
|
|
|
1079
|
-
|
|
1080
|
-
// rng_ss >> rng;
|
|
1177
|
+
const llama_batch & batch = batch_allocr.batch;
|
|
1081
1178
|
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
// }
|
|
1085
|
-
//}
|
|
1179
|
+
const auto & vocab = model.vocab;
|
|
1180
|
+
const auto & hparams = model.hparams;
|
|
1086
1181
|
|
|
1087
|
-
|
|
1088
|
-
std::vector<int32_t> output_pos;
|
|
1182
|
+
const int32_t n_vocab = vocab.n_tokens();
|
|
1089
1183
|
|
|
1090
|
-
|
|
1091
|
-
|
|
1184
|
+
const int64_t n_tokens_all = batch.n_tokens;
|
|
1185
|
+
const int64_t n_embd = hparams.n_embd;
|
|
1092
1186
|
|
|
1093
|
-
|
|
1094
|
-
|
|
1187
|
+
// TODO: remove this stuff
|
|
1188
|
+
class batch_guard {
|
|
1189
|
+
public:
|
|
1190
|
+
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
|
|
1095
1191
|
}
|
|
1096
1192
|
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
|
|
1102
|
-
int32_t id = output_pos[i];
|
|
1103
|
-
if ((uint32_t) id >= ctx->cparams.n_batch) {
|
|
1104
|
-
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
|
|
1105
|
-
}
|
|
1106
|
-
ctx->output_ids[id] = i;
|
|
1193
|
+
~batch_guard() {
|
|
1194
|
+
if (!is_done) {
|
|
1195
|
+
kv_slot_restorer.restore();
|
|
1107
1196
|
}
|
|
1108
|
-
|
|
1109
|
-
ctx->n_outputs = n_outputs;
|
|
1110
1197
|
}
|
|
1111
|
-
}
|
|
1112
1198
|
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
read_to(&logits_size, sizeof(logits_size));
|
|
1116
|
-
|
|
1117
|
-
if (ctx->logits_size < logits_size) {
|
|
1118
|
-
throw std::runtime_error("logits buffer too small");
|
|
1199
|
+
void done() {
|
|
1200
|
+
is_done = true;
|
|
1119
1201
|
}
|
|
1120
1202
|
|
|
1121
|
-
|
|
1122
|
-
|
|
1203
|
+
void save(const llama_kv_cache_slot_info & slot_info) {
|
|
1204
|
+
kv_slot_restorer.save(slot_info);
|
|
1123
1205
|
}
|
|
1124
|
-
}
|
|
1125
1206
|
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
read_to(&embeddings_size, sizeof(embeddings_size));
|
|
1207
|
+
private:
|
|
1208
|
+
bool is_done = false;
|
|
1129
1209
|
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1210
|
+
llama_kv_slot_restorer kv_slot_restorer;
|
|
1211
|
+
};
|
|
1212
|
+
|
|
1213
|
+
batch_guard bg(*kv_self);
|
|
1133
1214
|
|
|
1134
|
-
|
|
1135
|
-
|
|
1215
|
+
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
1216
|
+
|
|
1217
|
+
if (batch.token) {
|
|
1218
|
+
for (int64_t i = 0; i < n_tokens_all; ++i) {
|
|
1219
|
+
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
|
1220
|
+
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
|
|
1221
|
+
throw std::runtime_error("invalid token");
|
|
1222
|
+
}
|
|
1136
1223
|
}
|
|
1137
1224
|
}
|
|
1138
1225
|
|
|
1139
|
-
|
|
1140
|
-
struct llama_kv_cache & kv_self = ctx->kv_self;
|
|
1226
|
+
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
|
1141
1227
|
|
|
1142
|
-
|
|
1143
|
-
// single sequence
|
|
1228
|
+
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
|
1144
1229
|
|
|
1145
|
-
|
|
1230
|
+
if (t_compute_start_us == 0) {
|
|
1231
|
+
t_compute_start_us = ggml_time_us();
|
|
1232
|
+
}
|
|
1233
|
+
n_queued_tokens += n_tokens_all;
|
|
1146
1234
|
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
batch.n_seq_tokens = cell_count;
|
|
1150
|
-
batch.n_seqs = 1;
|
|
1235
|
+
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
|
1236
|
+
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
|
1151
1237
|
|
|
1152
|
-
|
|
1153
|
-
llama_pos pos;
|
|
1154
|
-
uint32_t n_seq_id;
|
|
1238
|
+
embd_seq.clear();
|
|
1155
1239
|
|
|
1156
|
-
|
|
1157
|
-
read_to(&n_seq_id, sizeof(n_seq_id));
|
|
1240
|
+
int64_t n_outputs_all = 0;
|
|
1158
1241
|
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1242
|
+
// count outputs
|
|
1243
|
+
if (batch.logits && !embd_pooled) {
|
|
1244
|
+
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
|
1245
|
+
n_outputs_all += batch.logits[i] != 0;
|
|
1246
|
+
}
|
|
1247
|
+
} else if (logits_all || embd_pooled) {
|
|
1248
|
+
n_outputs_all = n_tokens_all;
|
|
1249
|
+
} else {
|
|
1250
|
+
// keep last output only
|
|
1251
|
+
n_outputs_all = 1;
|
|
1252
|
+
}
|
|
1163
1253
|
|
|
1164
|
-
|
|
1165
|
-
}
|
|
1166
|
-
batch.n_seq_id[0] = 1;
|
|
1167
|
-
batch.seq_id[0] = &dest_seq_id;
|
|
1168
|
-
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
|
1169
|
-
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
|
1170
|
-
return false;
|
|
1171
|
-
}
|
|
1254
|
+
const bool logits_all = n_outputs_all == n_tokens_all;
|
|
1172
1255
|
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
|
|
1177
|
-
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
|
1178
|
-
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
|
1179
|
-
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
|
1180
|
-
} else {
|
|
1181
|
-
// whole KV cache restore
|
|
1256
|
+
sbatch.from_batch(batch, n_embd,
|
|
1257
|
+
/* simple_split */ !kv_self->recurrent,
|
|
1258
|
+
/* logits_all */ logits_all);
|
|
1182
1259
|
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1260
|
+
// reserve output buffer
|
|
1261
|
+
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
1262
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
|
|
1263
|
+
return -2;
|
|
1264
|
+
};
|
|
1265
|
+
|
|
1266
|
+
int64_t n_outputs_prev = 0;
|
|
1187
1267
|
|
|
1188
|
-
|
|
1268
|
+
while (sbatch.n_tokens > 0) {
|
|
1269
|
+
llama_ubatch ubatch = llama_ubatch();
|
|
1189
1270
|
|
|
1190
|
-
|
|
1191
|
-
llama_kv_cell & cell = kv_self.cells[i];
|
|
1271
|
+
const auto & n_ubatch = cparams.n_ubatch;
|
|
1192
1272
|
|
|
1193
|
-
|
|
1194
|
-
|
|
1273
|
+
if (kv_self->recurrent) {
|
|
1274
|
+
if (embd_pooled) {
|
|
1275
|
+
// Pooled embeddings cannot be split across ubatches (yet)
|
|
1276
|
+
ubatch = sbatch.split_seq(cparams.n_ubatch);
|
|
1277
|
+
} else {
|
|
1278
|
+
// recurrent model architectures are easier to implement
|
|
1279
|
+
// with equal-length sequences
|
|
1280
|
+
ubatch = sbatch.split_equal(cparams.n_ubatch);
|
|
1281
|
+
}
|
|
1282
|
+
} else {
|
|
1283
|
+
ubatch = sbatch.split_simple(n_ubatch);
|
|
1284
|
+
}
|
|
1195
1285
|
|
|
1196
|
-
|
|
1197
|
-
|
|
1286
|
+
// count the outputs in this u_batch
|
|
1287
|
+
{
|
|
1288
|
+
int32_t n_outputs_new = 0;
|
|
1198
1289
|
|
|
1199
|
-
|
|
1290
|
+
if (n_outputs_all == n_tokens_all) {
|
|
1291
|
+
n_outputs_new = ubatch.n_tokens;
|
|
1292
|
+
} else {
|
|
1293
|
+
GGML_ASSERT(ubatch.output);
|
|
1294
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
|
1295
|
+
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
|
1296
|
+
}
|
|
1297
|
+
}
|
|
1200
1298
|
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1299
|
+
// needs to happen before the graph is built
|
|
1300
|
+
n_outputs = n_outputs_new;
|
|
1301
|
+
}
|
|
1204
1302
|
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
}
|
|
1303
|
+
// non-causal masks do not use the KV cache
|
|
1304
|
+
if (hparams.causal_attn) {
|
|
1305
|
+
kv_self_update();
|
|
1209
1306
|
|
|
1210
|
-
|
|
1307
|
+
// if we have enough unused cells before the current head ->
|
|
1308
|
+
// better to start searching from the beginning of the cache, hoping to fill it
|
|
1309
|
+
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
|
|
1310
|
+
kv_self->head = 0;
|
|
1311
|
+
}
|
|
1211
1312
|
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
return false;
|
|
1217
|
-
}
|
|
1218
|
-
tail = i;
|
|
1219
|
-
}
|
|
1220
|
-
}
|
|
1313
|
+
const auto slot_info = kv_self->find_slot(ubatch);
|
|
1314
|
+
if (!slot_info) {
|
|
1315
|
+
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
|
|
1316
|
+
return -3;
|
|
1221
1317
|
}
|
|
1222
1318
|
|
|
1223
|
-
|
|
1224
|
-
kv_self.used = cell_count;
|
|
1225
|
-
}
|
|
1319
|
+
bg.save(slot_info);
|
|
1226
1320
|
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
//
|
|
1231
|
-
|
|
1321
|
+
if (!kv_self->recurrent) {
|
|
1322
|
+
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
|
1323
|
+
// after enough generations, the benefit from this heuristic disappears
|
|
1324
|
+
// if we start defragmenting the cache, the benefit from this will be more important
|
|
1325
|
+
const uint32_t pad = kv_self->get_padding(cparams);
|
|
1326
|
+
kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
|
|
1232
1327
|
}
|
|
1233
1328
|
}
|
|
1234
1329
|
|
|
1235
|
-
|
|
1236
|
-
}
|
|
1330
|
+
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
|
|
1237
1331
|
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
struct llama_kv_cache & kv_self = ctx->kv_self;
|
|
1241
|
-
uint32_t v_trans;
|
|
1242
|
-
uint32_t n_layer;
|
|
1243
|
-
read_to(&v_trans, sizeof(v_trans));
|
|
1244
|
-
read_to(&n_layer, sizeof(n_layer));
|
|
1332
|
+
ggml_backend_sched_reset(sched.get());
|
|
1333
|
+
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
|
1245
1334
|
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
return false;
|
|
1249
|
-
}
|
|
1250
|
-
if (cell_count > kv_self.size) {
|
|
1251
|
-
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
|
|
1252
|
-
return false;
|
|
1253
|
-
}
|
|
1254
|
-
if (kv_self.v_trans != (bool) v_trans) {
|
|
1255
|
-
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
|
1256
|
-
return false;
|
|
1257
|
-
}
|
|
1335
|
+
auto * gf = graph_init();
|
|
1336
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
|
|
1258
1337
|
|
|
1259
|
-
//
|
|
1260
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
1261
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
|
1338
|
+
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
|
1262
1339
|
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
|
1267
|
-
if (k_type_i != k_type_i_ref) {
|
|
1268
|
-
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
|
1269
|
-
return false;
|
|
1270
|
-
}
|
|
1340
|
+
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
1341
|
+
|
|
1342
|
+
res->set_inputs(&ubatch);
|
|
1271
1343
|
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1344
|
+
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
|
1345
|
+
if (compute_status != GGML_STATUS_SUCCESS) {
|
|
1346
|
+
switch (compute_status) {
|
|
1347
|
+
case GGML_STATUS_ABORTED:
|
|
1348
|
+
return 2;
|
|
1349
|
+
case GGML_STATUS_ALLOC_FAILED:
|
|
1350
|
+
return -2;
|
|
1351
|
+
case GGML_STATUS_FAILED:
|
|
1352
|
+
default:
|
|
1353
|
+
return -3;
|
|
1279
1354
|
}
|
|
1355
|
+
}
|
|
1356
|
+
|
|
1357
|
+
// update the kv ring buffer
|
|
1358
|
+
{
|
|
1359
|
+
kv_self->head += ubatch.n_tokens;
|
|
1280
1360
|
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1361
|
+
// Ensure kv cache head points to a valid index.
|
|
1362
|
+
if (kv_self->head >= kv_self->size) {
|
|
1363
|
+
kv_self->head = 0;
|
|
1284
1364
|
}
|
|
1285
1365
|
}
|
|
1286
1366
|
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1367
|
+
// plot the computation graph in dot format (for debugging purposes)
|
|
1368
|
+
//if (n_past%100 == 0) {
|
|
1369
|
+
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
|
1370
|
+
//}
|
|
1290
1371
|
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
|
1294
|
-
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
|
1295
|
-
if (v_type_i != v_type_i_ref) {
|
|
1296
|
-
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
|
1297
|
-
return false;
|
|
1298
|
-
}
|
|
1372
|
+
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
|
|
1373
|
+
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
|
1299
1374
|
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
|
|
1304
|
-
if (v_size_row != v_size_row_ref) {
|
|
1305
|
-
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
|
1306
|
-
return false;
|
|
1307
|
-
}
|
|
1375
|
+
if (t_embd && res->get_embd_pooled()) {
|
|
1376
|
+
t_embd = res->get_embd_pooled();
|
|
1377
|
+
}
|
|
1308
1378
|
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
} else {
|
|
1315
|
-
// For each layer, read the values for each cell (transposed)
|
|
1316
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
1317
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
1318
|
-
|
|
1319
|
-
// Read type of value
|
|
1320
|
-
int32_t v_type_i_ref;
|
|
1321
|
-
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
|
1322
|
-
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
|
1323
|
-
if (v_type_i != v_type_i_ref) {
|
|
1324
|
-
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
|
1325
|
-
return false;
|
|
1326
|
-
}
|
|
1379
|
+
// extract logits
|
|
1380
|
+
if (t_logits && n_outputs > 0) {
|
|
1381
|
+
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
|
1382
|
+
GGML_ASSERT(backend_res != nullptr);
|
|
1383
|
+
GGML_ASSERT(logits != nullptr);
|
|
1327
1384
|
|
|
1328
|
-
|
|
1329
|
-
uint32_t v_size_el_ref;
|
|
1330
|
-
read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
|
1331
|
-
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
|
1332
|
-
if (v_size_el != v_size_el_ref) {
|
|
1333
|
-
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
|
1334
|
-
return false;
|
|
1335
|
-
}
|
|
1385
|
+
float * logits_out = logits + n_outputs_prev*n_vocab;
|
|
1336
1386
|
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
}
|
|
1387
|
+
if (n_outputs) {
|
|
1388
|
+
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
|
1389
|
+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
|
|
1390
|
+
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
|
1391
|
+
}
|
|
1392
|
+
}
|
|
1344
1393
|
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1394
|
+
// extract embeddings
|
|
1395
|
+
if (t_embd && n_outputs > 0) {
|
|
1396
|
+
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
1397
|
+
GGML_ASSERT(backend_embd != nullptr);
|
|
1398
|
+
|
|
1399
|
+
switch (cparams.pooling_type) {
|
|
1400
|
+
case LLAMA_POOLING_TYPE_NONE:
|
|
1401
|
+
{
|
|
1402
|
+
// extract token embeddings
|
|
1403
|
+
GGML_ASSERT(embd != nullptr);
|
|
1404
|
+
float * embd_out = embd + n_outputs_prev*n_embd;
|
|
1405
|
+
|
|
1406
|
+
if (n_outputs) {
|
|
1407
|
+
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
|
1408
|
+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
|
|
1409
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
|
|
1410
|
+
}
|
|
1411
|
+
} break;
|
|
1412
|
+
case LLAMA_POOLING_TYPE_MEAN:
|
|
1413
|
+
case LLAMA_POOLING_TYPE_CLS:
|
|
1414
|
+
case LLAMA_POOLING_TYPE_LAST:
|
|
1415
|
+
{
|
|
1416
|
+
// extract sequence embeddings (cleared before processing each batch)
|
|
1417
|
+
auto & embd_seq_out = embd_seq;
|
|
1418
|
+
|
|
1419
|
+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
|
1420
|
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
1421
|
+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
1422
|
+
continue;
|
|
1423
|
+
}
|
|
1424
|
+
embd_seq_out[seq_id].resize(n_embd);
|
|
1425
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
|
1426
|
+
}
|
|
1427
|
+
} break;
|
|
1428
|
+
case LLAMA_POOLING_TYPE_RANK:
|
|
1429
|
+
{
|
|
1430
|
+
// extract the rerank score - a single float per sequence
|
|
1431
|
+
auto & embd_seq_out = embd_seq;
|
|
1432
|
+
|
|
1433
|
+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
|
1434
|
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
1435
|
+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
1436
|
+
continue;
|
|
1437
|
+
}
|
|
1438
|
+
embd_seq_out[seq_id].resize(1);
|
|
1439
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
|
1440
|
+
}
|
|
1441
|
+
} break;
|
|
1442
|
+
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
1443
|
+
{
|
|
1444
|
+
GGML_ABORT("unknown pooling type");
|
|
1350
1445
|
}
|
|
1351
|
-
}
|
|
1352
1446
|
}
|
|
1353
1447
|
}
|
|
1354
|
-
|
|
1448
|
+
|
|
1449
|
+
n_outputs_prev += n_outputs;
|
|
1355
1450
|
}
|
|
1356
1451
|
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1452
|
+
// finalize the batch processing
|
|
1453
|
+
bg.done();
|
|
1454
|
+
|
|
1455
|
+
// set output mappings
|
|
1456
|
+
{
|
|
1457
|
+
bool sorted_output = true;
|
|
1360
1458
|
|
|
1361
|
-
|
|
1459
|
+
GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
|
|
1362
1460
|
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1461
|
+
for (int64_t i = 0; i < n_outputs_all; ++i) {
|
|
1462
|
+
int64_t out_id = sbatch.out_ids[i];
|
|
1463
|
+
output_ids[out_id] = i;
|
|
1464
|
+
if (out_id != i) {
|
|
1465
|
+
sorted_output = false;
|
|
1368
1466
|
}
|
|
1369
|
-
throw std::runtime_error("failed to restore kv cache");
|
|
1370
1467
|
}
|
|
1371
|
-
}
|
|
1372
|
-
};
|
|
1373
|
-
|
|
1374
|
-
struct llama_data_write_dummy : llama_data_write {
|
|
1375
|
-
size_t size_written = 0;
|
|
1376
1468
|
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
|
|
1380
|
-
size_written += size;
|
|
1469
|
+
if (sorted_output) {
|
|
1470
|
+
sbatch.out_ids.clear();
|
|
1471
|
+
}
|
|
1381
1472
|
}
|
|
1382
1473
|
|
|
1383
|
-
|
|
1384
|
-
|
|
1385
|
-
}
|
|
1474
|
+
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
|
1475
|
+
n_outputs = n_outputs_all;
|
|
1386
1476
|
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
}
|
|
1390
|
-
};
|
|
1477
|
+
// wait for the computation to finish (automatically done when obtaining the model output)
|
|
1478
|
+
//synchronize();
|
|
1391
1479
|
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1480
|
+
// decide if we need to defrag the kv cache
|
|
1481
|
+
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
|
|
1482
|
+
// - do not defrag small contexts (i.e. < 2048 tokens)
|
|
1483
|
+
// - count the padding towards the number of used tokens
|
|
1484
|
+
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
|
|
1396
1485
|
|
|
1397
|
-
|
|
1486
|
+
// queue defragmentation for next llama_kv_cache_update
|
|
1487
|
+
if (fragmentation > cparams.defrag_thold) {
|
|
1488
|
+
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
|
1398
1489
|
|
|
1399
|
-
|
|
1400
|
-
if (size > buf_size) {
|
|
1401
|
-
throw std::runtime_error("unexpectedly reached end of buffer");
|
|
1490
|
+
kv_self->defrag();
|
|
1402
1491
|
}
|
|
1403
|
-
memcpy(ptr, src, size);
|
|
1404
|
-
ptr += size;
|
|
1405
|
-
size_written += size;
|
|
1406
|
-
buf_size -= size;
|
|
1407
1492
|
}
|
|
1408
1493
|
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
1412
|
-
}
|
|
1413
|
-
ggml_backend_tensor_get(tensor, ptr, offset, size);
|
|
1414
|
-
ptr += size;
|
|
1415
|
-
size_written += size;
|
|
1416
|
-
buf_size -= size;
|
|
1417
|
-
}
|
|
1494
|
+
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
|
1495
|
+
// overlap with device computation.
|
|
1496
|
+
ggml_backend_sched_reset(sched.get());
|
|
1418
1497
|
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
}
|
|
1422
|
-
};
|
|
1498
|
+
return 0;
|
|
1499
|
+
}
|
|
1423
1500
|
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
|
|
1427
|
-
size_t size_read = 0;
|
|
1501
|
+
//
|
|
1502
|
+
// output
|
|
1503
|
+
//
|
|
1428
1504
|
|
|
1429
|
-
|
|
1505
|
+
int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
1506
|
+
const auto & hparams = model.hparams;
|
|
1507
|
+
const auto & vocab = model.vocab;
|
|
1430
1508
|
|
|
1431
|
-
const
|
|
1432
|
-
const uint8_t * base_ptr = ptr;
|
|
1433
|
-
if (size > buf_size) {
|
|
1434
|
-
throw std::runtime_error("unexpectedly reached end of buffer");
|
|
1435
|
-
}
|
|
1436
|
-
ptr += size;
|
|
1437
|
-
size_read += size;
|
|
1438
|
-
buf_size -= size;
|
|
1439
|
-
return base_ptr;
|
|
1440
|
-
}
|
|
1509
|
+
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
|
|
1441
1510
|
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1511
|
+
const auto n_batch = cparams.n_batch;
|
|
1512
|
+
const auto n_vocab = vocab.n_tokens();
|
|
1513
|
+
const auto n_embd = hparams.n_embd;
|
|
1445
1514
|
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
};
|
|
1515
|
+
// TODO: use a per-batch flag for logits presence instead
|
|
1516
|
+
bool has_logits = !cparams.embeddings;
|
|
1517
|
+
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
|
1450
1518
|
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1519
|
+
// TODO: hacky enc-dec support
|
|
1520
|
+
if (model.arch == LLM_ARCH_T5) {
|
|
1521
|
+
has_logits = true;
|
|
1522
|
+
has_embd = true;
|
|
1523
|
+
}
|
|
1455
1524
|
|
|
1456
|
-
|
|
1525
|
+
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
|
1526
|
+
embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
|
1457
1527
|
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1528
|
+
if (output_ids.empty()) {
|
|
1529
|
+
// init, never resized afterwards
|
|
1530
|
+
output_ids.resize(n_batch);
|
|
1461
1531
|
}
|
|
1462
1532
|
|
|
1463
|
-
|
|
1533
|
+
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
|
1534
|
+
const size_t new_size = (logits_size + embd_size) * sizeof(float);
|
|
1535
|
+
|
|
1536
|
+
// alloc only when more than the current capacity is required
|
|
1537
|
+
// TODO: also consider shrinking the buffer
|
|
1538
|
+
if (!buf_output || prev_size < new_size) {
|
|
1539
|
+
if (buf_output) {
|
|
1540
|
+
#ifndef NDEBUG
|
|
1541
|
+
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
|
1542
|
+
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
|
1543
|
+
#endif
|
|
1544
|
+
buf_output = nullptr;
|
|
1545
|
+
logits = nullptr;
|
|
1546
|
+
embd = nullptr;
|
|
1547
|
+
}
|
|
1548
|
+
|
|
1549
|
+
auto * buft = ggml_backend_cpu_buffer_type();
|
|
1550
|
+
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
|
|
1551
|
+
auto * output_dev = model.dev_output();
|
|
1552
|
+
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
|
|
1553
|
+
if (output_dev_host_buft) {
|
|
1554
|
+
buft = output_dev_host_buft;
|
|
1555
|
+
}
|
|
1556
|
+
buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
|
|
1557
|
+
if (buf_output == nullptr) {
|
|
1558
|
+
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
|
|
1559
|
+
return 0;
|
|
1560
|
+
}
|
|
1561
|
+
}
|
|
1562
|
+
|
|
1563
|
+
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
|
1564
|
+
|
|
1565
|
+
logits = has_logits ? output_base : nullptr;
|
|
1566
|
+
embd = has_embd ? output_base + logits_size : nullptr;
|
|
1567
|
+
|
|
1568
|
+
// set all ids as invalid (negative)
|
|
1569
|
+
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1570
|
+
|
|
1571
|
+
ggml_backend_buffer_clear(buf_output.get(), 0);
|
|
1572
|
+
|
|
1573
|
+
this->n_outputs = 0;
|
|
1574
|
+
this->n_outputs_max = n_outputs_max;
|
|
1575
|
+
|
|
1576
|
+
return n_outputs_max;
|
|
1577
|
+
}
|
|
1578
|
+
|
|
1579
|
+
void llama_context::output_reorder() {
|
|
1580
|
+
auto & out_ids = sbatch.out_ids;
|
|
1581
|
+
if (!out_ids.empty()) {
|
|
1582
|
+
const uint32_t n_vocab = model.vocab.n_tokens();
|
|
1583
|
+
const uint32_t n_embd = model.hparams.n_embd;
|
|
1584
|
+
|
|
1585
|
+
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
1586
|
+
|
|
1587
|
+
// TODO: is there something more efficient which also minimizes swaps?
|
|
1588
|
+
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
|
1589
|
+
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
|
1590
|
+
int32_t j_min = i;
|
|
1591
|
+
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
|
1592
|
+
if (out_ids[j] < out_ids[j_min]) {
|
|
1593
|
+
j_min = j;
|
|
1594
|
+
}
|
|
1595
|
+
}
|
|
1596
|
+
if (j_min == i) { continue; }
|
|
1597
|
+
std::swap(out_ids[i], out_ids[j_min]);
|
|
1598
|
+
if (logits_size > 0) {
|
|
1599
|
+
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
1600
|
+
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
|
|
1601
|
+
}
|
|
1602
|
+
}
|
|
1603
|
+
if (embd_size > 0) {
|
|
1604
|
+
for (uint32_t k = 0; k < n_embd; k++) {
|
|
1605
|
+
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
|
1606
|
+
}
|
|
1607
|
+
}
|
|
1608
|
+
}
|
|
1609
|
+
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1610
|
+
for (int32_t i = 0; i < n_outputs; ++i) {
|
|
1611
|
+
output_ids[out_ids[i]] = i;
|
|
1612
|
+
}
|
|
1613
|
+
out_ids.clear();
|
|
1614
|
+
}
|
|
1615
|
+
}
|
|
1616
|
+
|
|
1617
|
+
//
|
|
1618
|
+
// graph
|
|
1619
|
+
//
|
|
1620
|
+
|
|
1621
|
+
int32_t llama_context::graph_max_nodes() const {
|
|
1622
|
+
return std::max<int32_t>(65536, 5*model.n_tensors());
|
|
1623
|
+
}
|
|
1624
|
+
|
|
1625
|
+
ggml_cgraph * llama_context::graph_init() {
|
|
1626
|
+
ggml_init_params params = {
|
|
1627
|
+
/*.mem_size =*/ buf_compute_meta.size(),
|
|
1628
|
+
/*.mem_buffer =*/ buf_compute_meta.data(),
|
|
1629
|
+
/*.no_alloc =*/ true,
|
|
1630
|
+
};
|
|
1631
|
+
|
|
1632
|
+
ctx_compute.reset(ggml_init(params));
|
|
1633
|
+
|
|
1634
|
+
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
|
1635
|
+
}
|
|
1636
|
+
|
|
1637
|
+
llm_graph_result_ptr llama_context::graph_build(
|
|
1638
|
+
ggml_context * ctx,
|
|
1639
|
+
ggml_cgraph * gf,
|
|
1640
|
+
const llama_ubatch & ubatch,
|
|
1641
|
+
llm_graph_type gtype) {
|
|
1642
|
+
return model.build_graph(
|
|
1643
|
+
{
|
|
1644
|
+
/*.ctx =*/ ctx,
|
|
1645
|
+
/*.arch =*/ model.arch,
|
|
1646
|
+
/*.hparams =*/ model.hparams,
|
|
1647
|
+
/*.cparams =*/ cparams,
|
|
1648
|
+
/*.ubatch =*/ ubatch,
|
|
1649
|
+
/*.sched =*/ sched.get(),
|
|
1650
|
+
/*.backend_cpu =*/ backend_cpu,
|
|
1651
|
+
/*.cvec =*/ &cvec,
|
|
1652
|
+
/*.loras =*/ &loras,
|
|
1653
|
+
/*.memory =*/ kv_self.get(),
|
|
1654
|
+
/*.cross =*/ &cross,
|
|
1655
|
+
/*.n_outputs =*/ n_outputs,
|
|
1656
|
+
/*.cb =*/ graph_get_cb(),
|
|
1657
|
+
}, gf, gtype);
|
|
1658
|
+
}
|
|
1659
|
+
|
|
1660
|
+
ggml_status llama_context::graph_compute(
|
|
1661
|
+
ggml_cgraph * gf,
|
|
1662
|
+
bool batched) {
|
|
1663
|
+
int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
|
|
1664
|
+
ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
|
|
1665
|
+
|
|
1666
|
+
if (backend_cpu != nullptr) {
|
|
1667
|
+
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
|
|
1668
|
+
auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
|
|
1669
|
+
set_threadpool_fn(backend_cpu, tp);
|
|
1670
|
+
}
|
|
1671
|
+
|
|
1672
|
+
// set the number of threads for all the backends
|
|
1673
|
+
for (const auto & set_n_threads_fn : set_n_threads_fns) {
|
|
1674
|
+
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
|
|
1675
|
+
}
|
|
1676
|
+
|
|
1677
|
+
auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf);
|
|
1678
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
1679
|
+
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
|
|
1680
|
+
}
|
|
1681
|
+
|
|
1682
|
+
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
|
|
1683
|
+
|
|
1684
|
+
return status;
|
|
1685
|
+
}
|
|
1686
|
+
|
|
1687
|
+
llm_graph_cb llama_context::graph_get_cb() const {
|
|
1688
|
+
return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
|
|
1689
|
+
if (il >= 0) {
|
|
1690
|
+
ggml_format_name(cur, "%s-%d", name, il);
|
|
1691
|
+
} else {
|
|
1692
|
+
ggml_set_name(cur, name);
|
|
1693
|
+
}
|
|
1694
|
+
|
|
1695
|
+
if (!cparams.offload_kqv) {
|
|
1696
|
+
if (strcmp(name, "kqv_merged_cont") == 0) {
|
|
1697
|
+
// all nodes between the KV store and the attention output are run on the CPU
|
|
1698
|
+
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
|
|
1699
|
+
}
|
|
1700
|
+
}
|
|
1701
|
+
|
|
1702
|
+
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
|
|
1703
|
+
// FIXME: fix in ggml_backend_sched
|
|
1704
|
+
const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
|
|
1705
|
+
if (ubatch.n_tokens < 32 || full_offload) {
|
|
1706
|
+
if (il != -1 && strcmp(name, "norm") == 0) {
|
|
1707
|
+
const auto & dev_layer = model.dev_layer(il);
|
|
1708
|
+
for (const auto & backend : backends) {
|
|
1709
|
+
if (ggml_backend_get_device(backend.get()) == dev_layer) {
|
|
1710
|
+
if (ggml_backend_supports_op(backend.get(), cur)) {
|
|
1711
|
+
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
|
|
1712
|
+
}
|
|
1713
|
+
}
|
|
1714
|
+
}
|
|
1715
|
+
}
|
|
1716
|
+
}
|
|
1717
|
+
};
|
|
1718
|
+
}
|
|
1719
|
+
|
|
1720
|
+
//
|
|
1721
|
+
// state save/load
|
|
1722
|
+
//
|
|
1723
|
+
|
|
1724
|
+
class llama_io_write_dummy : public llama_io_write_i {
|
|
1725
|
+
public:
|
|
1726
|
+
llama_io_write_dummy() = default;
|
|
1727
|
+
|
|
1728
|
+
void write(const void * /* src */, size_t size) override {
|
|
1729
|
+
size_written += size;
|
|
1730
|
+
}
|
|
1731
|
+
|
|
1732
|
+
void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
|
|
1733
|
+
size_written += size;
|
|
1734
|
+
}
|
|
1735
|
+
|
|
1736
|
+
size_t n_bytes() override {
|
|
1737
|
+
return size_written;
|
|
1738
|
+
}
|
|
1739
|
+
|
|
1740
|
+
private:
|
|
1741
|
+
size_t size_written = 0;
|
|
1742
|
+
};
|
|
1743
|
+
|
|
1744
|
+
class llama_io_write_buffer : public llama_io_write_i {
|
|
1745
|
+
public:
|
|
1746
|
+
llama_io_write_buffer(
|
|
1747
|
+
uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
|
1748
|
+
|
|
1749
|
+
void write(const void * src, size_t size) override {
|
|
1750
|
+
if (size > buf_size) {
|
|
1751
|
+
throw std::runtime_error("unexpectedly reached end of buffer");
|
|
1752
|
+
}
|
|
1753
|
+
memcpy(ptr, src, size);
|
|
1754
|
+
ptr += size;
|
|
1755
|
+
size_written += size;
|
|
1756
|
+
buf_size -= size;
|
|
1757
|
+
}
|
|
1758
|
+
|
|
1759
|
+
void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
|
|
1760
|
+
if (size > buf_size) {
|
|
1761
|
+
throw std::runtime_error("unexpectedly reached end of buffer");
|
|
1762
|
+
}
|
|
1763
|
+
ggml_backend_tensor_get(tensor, ptr, offset, size);
|
|
1764
|
+
ptr += size;
|
|
1765
|
+
size_written += size;
|
|
1766
|
+
buf_size -= size;
|
|
1767
|
+
}
|
|
1768
|
+
|
|
1769
|
+
size_t n_bytes() override {
|
|
1770
|
+
return size_written;
|
|
1771
|
+
}
|
|
1772
|
+
|
|
1773
|
+
private:
|
|
1774
|
+
uint8_t * ptr;
|
|
1775
|
+
size_t buf_size = 0;
|
|
1776
|
+
size_t size_written = 0;
|
|
1777
|
+
};
|
|
1778
|
+
|
|
1779
|
+
class llama_io_read_buffer : public llama_io_read_i {
|
|
1780
|
+
public:
|
|
1781
|
+
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
|
1782
|
+
|
|
1783
|
+
const uint8_t * read(size_t size) override {
|
|
1784
|
+
const uint8_t * base_ptr = ptr;
|
|
1785
|
+
if (size > buf_size) {
|
|
1786
|
+
throw std::runtime_error("unexpectedly reached end of buffer");
|
|
1787
|
+
}
|
|
1788
|
+
ptr += size;
|
|
1789
|
+
size_read += size;
|
|
1790
|
+
buf_size -= size;
|
|
1791
|
+
return base_ptr;
|
|
1792
|
+
}
|
|
1793
|
+
|
|
1794
|
+
void read_to(void * dst, size_t size) override {
|
|
1795
|
+
memcpy(dst, read(size), size);
|
|
1796
|
+
}
|
|
1797
|
+
|
|
1798
|
+
size_t n_bytes() override {
|
|
1799
|
+
return size_read;
|
|
1800
|
+
}
|
|
1801
|
+
|
|
1802
|
+
private:
|
|
1803
|
+
const uint8_t * ptr;
|
|
1804
|
+
size_t buf_size = 0;
|
|
1805
|
+
size_t size_read = 0;
|
|
1806
|
+
};
|
|
1807
|
+
|
|
1808
|
+
class llama_io_write_file : public llama_io_write_i {
|
|
1809
|
+
public:
|
|
1810
|
+
llama_io_write_file(llama_file * f) : file(f) {}
|
|
1811
|
+
|
|
1812
|
+
void write(const void * src, size_t size) override {
|
|
1813
|
+
file->write_raw(src, size);
|
|
1814
|
+
size_written += size;
|
|
1815
|
+
}
|
|
1816
|
+
|
|
1817
|
+
void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
|
|
1464
1818
|
temp_buffer.resize(size);
|
|
1465
1819
|
ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
|
|
1466
1820
|
write(temp_buffer.data(), temp_buffer.size());
|
|
1467
1821
|
}
|
|
1468
1822
|
|
|
1469
|
-
size_t
|
|
1823
|
+
size_t n_bytes() override {
|
|
1470
1824
|
return size_written;
|
|
1471
1825
|
}
|
|
1472
|
-
};
|
|
1473
1826
|
|
|
1474
|
-
|
|
1827
|
+
private:
|
|
1475
1828
|
llama_file * file;
|
|
1476
|
-
size_t
|
|
1829
|
+
size_t size_written = 0;
|
|
1477
1830
|
std::vector<uint8_t> temp_buffer;
|
|
1831
|
+
};
|
|
1478
1832
|
|
|
1479
|
-
|
|
1833
|
+
class llama_io_read_file : public llama_io_read_i {
|
|
1834
|
+
public:
|
|
1835
|
+
llama_io_read_file(llama_file * f) : file(f) {}
|
|
1480
1836
|
|
|
1481
1837
|
void read_to(void * dst, size_t size) override {
|
|
1482
1838
|
file->read_raw(dst, size);
|
|
@@ -1489,89 +1845,78 @@ struct llama_data_read_file : llama_data_read {
         return temp_buffer.data();
     }

-    size_t
+    size_t n_bytes() override {
         return size_read;
     }
-};

-
- *
-
-
-
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- * buffer context:
- * std::vector<uint8_t> buf(max_size, 0);
- * llama_data_write_buffer data_ctx(buf.data(), max_size);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- */
-static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
-    llama_synchronize(ctx);
-
-    data_ctx.write_model_info(ctx);
-
-    // copy outputs
-    data_ctx.write_output_ids(ctx);
-    data_ctx.write_logits(ctx);
-    data_ctx.write_embeddings(ctx);
-
-    data_ctx.write_kv_cache(ctx);
+private:
+    llama_file * file;
+    size_t size_read = 0;
+    std::vector<uint8_t> temp_buffer;
+};

-
+size_t llama_context::state_get_size() {
+    llama_io_write_dummy io;
+    try {
+        return state_write_data(io);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
 }

-size_t
-
+size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
+    llama_io_write_buffer io(dst, size);
     try {
-        return
+        return state_write_data(io);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
         return 0;
     }
 }

-
-
-size_t llama_state_get_size(struct llama_context * ctx) {
-    llama_data_write_dummy data_ctx;
+size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
+    llama_io_read_buffer io(src, size);
     try {
-        return
+        return state_read_data(io);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
     }
 }

-
-
-
-
-
-
-
-
-
-
-    data_ctx.read_kv_cache(ctx);
+size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
+    llama_io_write_dummy io;
+    try {
+        return state_seq_write_data(io, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
+}

-
+size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
+    llama_io_write_buffer io(dst, size);
+    try {
+        return state_seq_write_data(io, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
+    }
 }

-
-
-    llama_data_read_buffer data_ctx(src, size);
+size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
+    llama_io_read_buffer io(src, size);
     try {
-        return
+        return state_seq_read_data(io, seq_id);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
     }
 }

-
-    llama_file file(
+bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");

     // sanity checks
     {
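Note: the new state_get_size()/state_get_data()/state_set_data() methods keep the existing two-pass pattern: a dummy writer pass measures the state, and a buffer-backed pass copies it. Through the public C wrappers that appear later in this diff, a caller-side usage sketch (assuming a llama_context * ctx that has already evaluated a prompt) might be:

    #include <cstdint>
    #include <vector>
    #include "llama.h"

    // Size the state, copy it out, and restore it later (errors are logged and reported as 0).
    void snapshot_and_restore(llama_context * ctx) {
        std::vector<uint8_t> buf(llama_state_get_size(ctx));
        const size_t n_written = llama_state_get_data(ctx, buf.data(), buf.size());
        if (n_written == 0) {
            return; // serialization failed
        }
        // ... later, on the same (or an identically configured) context:
        llama_state_set_data(ctx, buf.data(), buf.size());
    }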
@@ -1601,28 +1946,20 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
     {
         const size_t n_state_size_cur = file.size() - file.tell();

-
-        const size_t n_read =
+        llama_io_read_file io( &file);
+        const size_t n_read = state_read_data(io);

         if (n_read != n_state_size_cur) {
             LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
             return false;
         }
     }
-    return true;
-}

-
-    try {
-        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
-        return false;
-    }
+    return true;
 }

-
-    llama_file file(
+bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");

     file.write_u32(LLAMA_SESSION_MAGIC);
     file.write_u32(LLAMA_SESSION_VERSION);
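Note: state_load_file()/state_save_file() preserve the existing session-file layout (LLAMA_SESSION_MAGIC, LLAMA_SESSION_VERSION, the prompt tokens, then the streamed context state). A usage sketch of the corresponding public API is shown below; the file name and token buffer are illustrative, not part of this diff:

    #include <vector>
    #include "llama.h"

    // Save the evaluated prompt plus context state to disk, then restore it into `ctx`.
    bool roundtrip_session(llama_context * ctx, const std::vector<llama_token> & tokens) {
        if (!llama_state_save_file(ctx, "session.bin", tokens.data(), tokens.size())) {
            return false;
        }

        std::vector<llama_token> restored(llama_n_ctx(ctx));
        size_t n_restored = 0;
        if (!llama_state_load_file(ctx, "session.bin", restored.data(), restored.size(), &n_restored)) {
            return false;
        }
        restored.resize(n_restored); // number of prompt tokens recovered from the file
        return true;
    }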
@@ -1632,63 +1969,56 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);

     // save the context state using stream saving
-
-
+    llama_io_write_file io(&file);
+    state_write_data(io);

     return true;
 }

-
-
-        return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
-        return false;
-    }
-}
+size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");

-
-
+    // version checks
+    {
+        const uint32_t magic = file.read_u32();
+        const uint32_t version = file.read_u32();

-
+        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
+            return 0;
+        }
+    }

-
-
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();

-
-
-
-    }
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return 0;
+        }

-
-
-    try {
-        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
-        return 0;
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
     }
-}

-
-
-
-
+    // restore the context state
+    {
+        const size_t state_size = file.size() - file.tell();
+        llama_io_read_file io(&file);
+        const size_t nread = state_seq_read_data(io, seq_id);
+        if (!nread) {
+            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
+            return 0;
+        }
+        GGML_ASSERT(nread <= state_size);
+        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
+    }

-    return
+    return file.tell();
 }

-size_t
-    llama_data_read_buffer data_ctx(src, size);
-    try {
-        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
     llama_file file(filepath, "wb");

     file.write_u32(LLAMA_STATE_SEQ_MAGIC);
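Note: the sequence-state file format visible in this hunk and the next is three u32 header fields (LLAMA_STATE_SEQ_MAGIC, LLAMA_STATE_SEQ_VERSION, the token count), followed by the prompt tokens and the streamed per-sequence state; the GGML_ASSERT checks verify exactly that accounting. A sketch of the public per-sequence API (file name, sequence ids, and token list are illustrative):

    #include <vector>
    #include "llama.h"

    // Copy sequence 0 of `src` into sequence 1 of `dst` through a state file on disk.
    size_t copy_sequence(llama_context * src, llama_context * dst,
                         const std::vector<llama_token> & tokens) {
        // returns 0 on failure, otherwise the number of bytes written
        if (llama_state_seq_save_file(src, "seq0.bin", 0, tokens.data(), tokens.size()) == 0) {
            return 0;
        }

        std::vector<llama_token> restored(llama_n_ctx(dst));
        size_t n_restored = 0;
        // returns 0 on failure, otherwise the number of bytes read
        return llama_state_seq_load_file(dst, "seq0.bin", 1, restored.data(), restored.size(), &n_restored);
    }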
@@ -1699,77 +2029,778 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);

     // save the context state using stream saving
-
-
+    llama_io_write_file io(&file);
+    state_seq_write_data(io, seq_id);

     const size_t res = file.tell();
-    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count +
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
+
     return res;
 }

-
-
+size_t llama_context::state_write_data(llama_io_write_i & io) {
+    LLAMA_LOG_DEBUG("%s: writing state\n", __func__);

-    //
+    // write model info
     {
-
-        const uint32_t version = file.read_u32();
+        LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__);

-
-
-
+        const std::string arch_str = llm_arch_name(model.arch);
+        io.write_string(arch_str);
+        // TODO: add more model-specific info which should prevent loading the session file if not identical
+    }
+
+    // write output ids
+    {
+        LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
+
+        output_reorder();
+
+        const auto n_outputs = this->n_outputs;
+        const auto & output_ids = this->output_ids;
+
+        std::vector<int32_t> w_output_pos;
+
+        GGML_ASSERT(n_outputs <= n_outputs_max);
+
+        w_output_pos.resize(n_outputs);
+
+        // build a more compact representation of the output ids
+        for (size_t i = 0; i < n_batch(); ++i) {
+            // map an output id to a position in the batch
+            int32_t pos = output_ids[i];
+            if (pos >= 0) {
+                GGML_ASSERT(pos < n_outputs);
+                w_output_pos[pos] = i;
+            }
+        }
+
+        io.write(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs) {
+            io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
         }
     }

-    //
+    // write logits
     {
-
+        LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);

-
-
-
+        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+
+        io.write(&logits_size, sizeof(logits_size));
+
+        if (logits_size) {
+            io.write(logits, logits_size * sizeof(float));
         }
+    }

-
-
+    // write embeddings
+    {
+        LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
+
+        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
+
+        io.write(&embd_size, sizeof(embd_size));
+
+        if (embd_size) {
+            io.write(embd, embd_size * sizeof(float));
+        }
     }

-
+    LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+    kv_self->state_write(io);
+
+    return io.n_bytes();
+}
+
+size_t llama_context::state_read_data(llama_io_read_i & io) {
+    LLAMA_LOG_DEBUG("%s: reading state\n", __func__);
+
+    // read model info
     {
-
-
-        const
-
-
-
+        LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__);
+
+        const std::string cur_arch_str = llm_arch_name(model.arch);
+
+        std::string arch_str;
+        io.read_string(arch_str);
+        if (cur_arch_str != arch_str) {
+            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
         }
-
-        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
+        // TODO: add more info which needs to be identical but which is not verified otherwise
     }

-
+    // read output ids
+    {
+        LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
+
+        auto n_outputs = this->n_outputs;
+        io.read_to(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs > output_reserve(n_outputs)) {
+            throw std::runtime_error("could not reserve outputs");
+        }
+
+        std::vector<int32_t> output_pos;
+
+        if (n_outputs) {
+            output_pos.resize(n_outputs);
+            io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
+
+            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+                int32_t id = output_pos[i];
+                if ((uint32_t) id >= n_batch()) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
+                }
+                this->output_ids[id] = i;
+            }
+
+            this->n_outputs = n_outputs;
+        }
+    }
+
+    // read logits
+    {
+        LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
+
+        uint64_t logits_size;
+        io.read_to(&logits_size, sizeof(logits_size));
+
+        if (this->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }
+
+        if (logits_size) {
+            io.read_to(this->logits, logits_size * sizeof(float));
+        }
+    }
+
+    // read embeddings
+    {
+        LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
+
+        uint64_t embd_size;
+        io.read_to(&embd_size, sizeof(embd_size));
+
+        if (this->embd_size < embd_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }
+
+        if (embd_size) {
+            io.read_to(this->embd, embd_size * sizeof(float));
+        }
+    }
+
+    LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+    kv_self->state_read(io);
+
+    return io.n_bytes();
 }

-size_t
+size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
+    GGML_UNUSED(seq_id);
+
+    kv_self->state_write(io, seq_id);
+
+    return io.n_bytes();
+}
+
+size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
+    GGML_UNUSED(seq_id);
+
+    kv_self->state_read(io, seq_id);
+
+    return io.n_bytes();
+}
+
+//
+// perf
+//
+
+llama_perf_context_data llama_context::perf_get_data() const {
+    llama_perf_context_data data = {};
+
+    data.t_start_ms = 1e-3 * t_start_us;
+    data.t_load_ms = 1e-3 * t_load_us;
+    data.t_p_eval_ms = 1e-3 * t_p_eval_us;
+    data.t_eval_ms = 1e-3 * t_eval_us;
+    data.n_p_eval = std::max(1, n_p_eval);
+    data.n_eval = std::max(1, n_eval);
+
+    return data;
+}
+
+void llama_context::perf_reset() {
+    t_start_us = ggml_time_us();
+    t_eval_us = n_eval = 0;
+    t_p_eval_us = n_p_eval = 0;
+}
+
+//
+// interface implementation
+//
+
+llama_context_params llama_context_default_params() {
+    llama_context_params result = {
+        /*.n_ctx =*/ 512,
+        /*.n_batch =*/ 2048,
+        /*.n_ubatch =*/ 512,
+        /*.n_seq_max =*/ 1,
+        /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+        /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
+        /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+        /*.rope_freq_base =*/ 0.0f,
+        /*.rope_freq_scale =*/ 0.0f,
+        /*.yarn_ext_factor =*/ -1.0f,
+        /*.yarn_attn_factor =*/ 1.0f,
+        /*.yarn_beta_fast =*/ 32.0f,
+        /*.yarn_beta_slow =*/ 1.0f,
+        /*.yarn_orig_ctx =*/ 0,
+        /*.defrag_thold =*/ -1.0f,
+        /*.cb_eval =*/ nullptr,
+        /*.cb_eval_user_data =*/ nullptr,
+        /*.type_k =*/ GGML_TYPE_F16,
+        /*.type_v =*/ GGML_TYPE_F16,
+        /*.logits_all =*/ false,
+        /*.embeddings =*/ false,
+        /*.offload_kqv =*/ true,
+        /*.flash_attn =*/ false,
+        /*.no_perf =*/ true,
+        /*.abort_callback =*/ nullptr,
+        /*.abort_callback_data =*/ nullptr,
+    };
+
+    return result;
+}
+
+llama_context * llama_init_from_model(
+        llama_model * model,
+        llama_context_params params) {
+    if (!model) {
+        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
+        return nullptr;
+    }
+
+    if (params.n_batch == 0 && params.n_ubatch == 0) {
+        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
+        return nullptr;
+    }
+
+    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
+        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
+        return nullptr;
+    }
+
+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
+        return nullptr;
+    }
+
     try {
-
+        auto * ctx = new llama_context(*model, params);
+        return ctx;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what());
+    }
+
+    return nullptr;
+}
+
+// deprecated
+llama_context * llama_new_context_with_model(
+        llama_model * model,
+        llama_context_params params) {
+    return llama_init_from_model(model, params);
+}
+
+void llama_free(llama_context * ctx) {
+    delete ctx;
+}
+
+uint32_t llama_n_ctx(const llama_context * ctx) {
+    return ctx->n_ctx();
+}
+
+uint32_t llama_n_batch(const llama_context * ctx) {
+    return ctx->n_batch();
+}
+
+uint32_t llama_n_ubatch(const llama_context * ctx) {
+    return ctx->n_ubatch();
+}
+
+uint32_t llama_n_seq_max(const llama_context * ctx) {
+    return ctx->n_seq_max();
+}
+
+const llama_model * llama_get_model(const llama_context * ctx) {
+    return &ctx->get_model();
+}
+
+llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
+    return ctx->get_kv_self();
+}
+
+void llama_kv_self_update(llama_context * ctx) {
+    ctx->kv_self_update();
+}
+
+enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
+    return ctx->pooling_type();
+}
+
+void llama_attach_threadpool(
+        llama_context * ctx,
+        ggml_threadpool_t threadpool,
+        ggml_threadpool_t threadpool_batch) {
+    ctx->attach_threadpool(threadpool, threadpool_batch);
+}
+
+void llama_detach_threadpool(llama_context * ctx) {
+    ctx->detach_threadpool();
+}
+
+void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
+    ctx->set_n_threads(n_threads, n_threads_batch);
+}
+
+int32_t llama_n_threads(llama_context * ctx) {
+    return ctx->n_threads();
+}
+
+int32_t llama_n_threads_batch(llama_context * ctx) {
+    return ctx->n_threads_batch();
+}
+
+void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
+    ctx->set_abort_callback(abort_callback, abort_callback_data);
+}
+
+void llama_set_embeddings(llama_context * ctx, bool embeddings) {
+    ctx->set_embeddings(embeddings);
+}
+
+void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
+    ctx->set_causal_attn(causal_attn);
+}
+
+void llama_set_warmup(llama_context * ctx, bool warmup) {
+    ctx->set_warmup(warmup);
+}
+
+void llama_synchronize(llama_context * ctx) {
+    ctx->synchronize();
+}
+
+float * llama_get_logits(llama_context * ctx) {
+    ctx->synchronize();
+
+    return ctx->get_logits();
+}
+
+float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_logits_ith(i);
+}
+
+float * llama_get_embeddings(llama_context * ctx) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings();
+}
+
+float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_ith(i);
+}
+
+float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_seq(seq_id);
+}
+
+// llama adapter API
+
+int32_t llama_set_adapter_lora(
+        llama_context * ctx,
+        llama_adapter_lora * adapter,
+        float scale) {
+    ctx->set_adapter_lora(adapter, scale);
+
+    return 0;
+}
+
+int32_t llama_rm_adapter_lora(
+        llama_context * ctx,
+        llama_adapter_lora * adapter) {
+    bool res = ctx->rm_adapter_lora(adapter);
+
+    return res ? 0 : -1;
+}
+
+void llama_clear_adapter_lora(llama_context * ctx) {
+    ctx->clear_adapter_lora();
+}
+
+int32_t llama_apply_adapter_cvec(
+        llama_context * ctx,
+        const float * data,
+        size_t len,
+        int32_t n_embd,
+        int32_t il_start,
+        int32_t il_end) {
+    bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
+
+    return res ? 0 : -1;
+}
+
+//
+// kv cache view
+//
+
+llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
+    const auto * kv = ctx->get_kv_self();
+    if (kv == nullptr) {
+        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
+        return {};
+    }
+
+    return llama_kv_cache_view_init(*kv, n_seq_max);
+}
+
+void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
+    const auto * kv = ctx->get_kv_self();
+    if (kv == nullptr) {
+        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
+        return;
+    }
+
+    llama_kv_cache_view_update(view, kv);
+}
+
+//
+// kv cache
+//
+
+// deprecated
+int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
+    return llama_kv_self_n_tokens(ctx);
+}
+
+int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
+    return llama_kv_cache_n_tokens(ctx->get_kv_self());
+}
+
+// deprecated
+int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
+    return llama_kv_self_used_cells(ctx);
+}
+
+int32_t llama_kv_self_used_cells(const llama_context * ctx) {
+    return llama_kv_cache_used_cells(ctx->get_kv_self());
+}
+
+// deprecated
+void llama_kv_cache_clear(llama_context * ctx) {
+    llama_kv_self_clear(ctx);
+}
+
+void llama_kv_self_clear(llama_context * ctx) {
+    llama_kv_cache_clear(ctx->get_kv_self());
+}
+
+// deprecated
+bool llama_kv_cache_seq_rm(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1) {
+    return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
+}
+
+bool llama_kv_self_seq_rm(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1) {
+    return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
+}
+
+// deprecated
+void llama_kv_cache_seq_cp(
+        llama_context * ctx,
+        llama_seq_id seq_id_src,
+        llama_seq_id seq_id_dst,
+        llama_pos p0,
+        llama_pos p1) {
+    return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_self_seq_cp(
+        llama_context * ctx,
+        llama_seq_id seq_id_src,
+        llama_seq_id seq_id_dst,
+        llama_pos p0,
+        llama_pos p1) {
+    return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
+}
+
+// deprecated
+void llama_kv_cache_seq_keep(
+        llama_context * ctx,
+        llama_seq_id seq_id) {
+    return llama_kv_self_seq_keep(ctx, seq_id);
+}
+
+void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
+    return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
+}
+
+// deprecated
+void llama_kv_cache_seq_add(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        llama_pos delta) {
+    return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
+}
+
+void llama_kv_self_seq_add(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        llama_pos delta) {
+    return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
+}
+
+// deprecated
+void llama_kv_cache_seq_div(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        int d) {
+    return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
+}
+
+void llama_kv_self_seq_div(
+        llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        int d) {
+    return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
+}
+
+// deprecated
+llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
+    return llama_kv_self_seq_pos_max(ctx, seq_id);
+}
+
+llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
+    return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
+}
+
+// deprecated
+void llama_kv_cache_defrag(llama_context * ctx) {
+    return llama_kv_self_defrag(ctx);
+}
+
+void llama_kv_self_defrag(llama_context * ctx) {
+    llama_kv_cache_defrag(ctx->get_kv_self());
+}
+
+// deprecated
+bool llama_kv_cache_can_shift(const llama_context * ctx) {
+    return llama_kv_self_can_shift(ctx);
+}
+
+bool llama_kv_self_can_shift(const llama_context * ctx) {
+    return llama_kv_cache_can_shift(ctx->get_kv_self());
+}
+
+// deprecated
+void llama_kv_cache_update(llama_context * ctx) {
+    llama_kv_self_update(ctx);
+}
+
+// llama state API
+
+// deprecated
+size_t llama_get_state_size(llama_context * ctx) {
+    return llama_state_get_size(ctx);
+}
+
+// deprecated
+size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) {
+    return llama_state_get_data(ctx, dst, -1);
+}
+
+// deprecated
+size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) {
+    return llama_state_set_data(ctx, src, -1);
+}
+
+// deprecated
+bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
+}
+
+// deprecated
+bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
+}
+
+// Returns the *actual* size of the state.
+// Intended to be used when saving to state to a buffer.
+size_t llama_state_get_size(llama_context * ctx) {
+    return ctx->state_get_size();
+}
+
+size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) {
+    ctx->synchronize();
+
+    return ctx->state_get_data(dst, size);
+}
+
+// Sets the state reading from the specified source address
+size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) {
+    ctx->synchronize();
+
+    return ctx->state_set_data(src, size);
+}
+
+bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    ctx->synchronize();
+
+    try {
+        return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
+        return false;
+    }
+}
+
+bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    ctx->synchronize();
+
+    try {
+        return ctx->state_save_file(path_session, tokens, n_token_count);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
+        return false;
+    }
+}
+
+size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
+    return ctx->state_seq_get_size(seq_id);
+}
+
+size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    ctx->synchronize();
+
+    return ctx->state_seq_get_data(seq_id, dst, size);
+}
+
+size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+    ctx->synchronize();
+
+    return ctx->state_seq_set_data(seq_id, src, size);
+}
+
+size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+    ctx->synchronize();
+
+    try {
+        return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
         return 0;
     }
 }

-size_t llama_state_seq_load_file(
+size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    ctx->synchronize();
+
     try {
-        return
+        return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
         return 0;
     }
 }

-
-
-
-
+///
+
+int32_t llama_encode(
+        llama_context * ctx,
+        llama_batch batch) {
+    const int ret = ctx->encode(batch);
+    if (ret != 0) {
+        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
+int32_t llama_decode(
+        llama_context * ctx,
+        llama_batch batch) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0) {
+        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
+//
+// perf
+//
+
+llama_perf_context_data llama_perf_context(const llama_context * ctx) {
+    llama_perf_context_data data = {};
+
+    if (ctx == nullptr) {
+        return data;
+    }
+
+    data = ctx->perf_get_data();
+
+    return data;
+}
+
+void llama_perf_context_print(const llama_context * ctx) {
+    const auto data = llama_perf_context(ctx);
+
+    const double t_end_ms = 1e-3 * ggml_time_us();
+
+    LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
+    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
+    LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
+    LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+}
+
+void llama_perf_context_reset(llama_context * ctx) {
+    ctx->perf_reset();
 }