@fugood/llama.node 0.3.17 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/src/llama-context.h

@@ -7,6 +7,7 @@
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
+#include "ggml-opt.h"
 
 #include <map>
 #include <vector>
@@ -27,7 +28,12 @@ struct llama_context {
 
     void synchronize();
 
-    const llama_model & get_model() const;
+    const llama_model   & get_model()   const;
+    const llama_cparams & get_cparams() const;
+
+    ggml_backend_sched_t get_sched() const;
+
+    ggml_context * get_ctx_compute() const;
 
     uint32_t n_ctx() const;
     uint32_t n_ctx_per_seq() const;
@@ -128,6 +134,32 @@ struct llama_context {
     llama_perf_context_data perf_get_data() const;
     void perf_reset();
 
+    //
+    // training
+    //
+
+    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
+
+    void opt_epoch(
+            ggml_opt_dataset_t      dataset,
+            ggml_opt_result_t       result_train,
+            ggml_opt_result_t       result_eval,
+            int64_t                 idata_split,
+            ggml_opt_epoch_callback callback_train,
+            ggml_opt_epoch_callback callback_eval);
+
+    void opt_epoch_iter(
+            ggml_opt_dataset_t               dataset,
+            ggml_opt_result_t                result,
+            const std::vector<llama_token> & tokens,
+            const std::vector<llama_token> & labels_sparse,
+            llama_batch                    & batch,
+            ggml_opt_epoch_callback          callback,
+            bool                             train,
+            int64_t                          idata_in_loop,
+            int64_t                          ndata_in_loop,
+            int64_t                          t_loop_start);
+
 private:
     //
     // output
@@ -137,49 +169,30 @@ private:
     // Returns max number of outputs for which space was reserved.
     int32_t output_reserve(int32_t n_outputs);
 
-    // make the outputs have the same order they had in the user-provided batch
-    // TODO: maybe remove this
-    void output_reorder();
-
     //
     // graph
     //
 
+public:
     int32_t graph_max_nodes() const;
 
     // zero-out inputs and create the ctx_compute for the compute graph
     ggml_cgraph * graph_init();
 
+    // returns the result of ggml_backend_sched_graph_compute_async execution
+    ggml_status graph_compute(
+            ggml_cgraph * gf,
+            bool          batched);
+
+private:
     llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph  * gf,
             const llama_ubatch & ubatch,
             llm_graph_type gtype);
 
-    // returns the result of ggml_backend_sched_graph_compute_async execution
-    ggml_status graph_compute(
-            ggml_cgraph * gf,
-            bool          batched);
-
     llm_graph_cb graph_get_cb() const;
 
-    // used by kv_self_update()
-    ggml_tensor * build_rope_shift(
-            ggml_context * ctx0,
-            ggml_tensor  * cur,
-            ggml_tensor  * shift,
-            ggml_tensor  * factors,
-            float          freq_base,
-            float          freq_scale) const;
-
-    llm_graph_result_ptr build_kv_self_shift(
-            ggml_context * ctx0,
-            ggml_cgraph  * gf) const;
-
-    llm_graph_result_ptr build_kv_self_defrag(
-            ggml_context * ctx0,
-            ggml_cgraph  * gf) const;
-
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);
@@ -196,14 +209,10 @@ private:
     llama_cparams       cparams;
     llama_adapter_cvec  cvec;
     llama_adapter_loras loras;
-    llama_sbatch sbatch;
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
-    std::unique_ptr<llama_kv_cache_unified> kv_self;
-
-    // TODO: remove
-    bool logits_all = false;
+    std::unique_ptr<llama_memory_i> memory;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t logits_size = 0; // capacity (of floats) for logits
@@ -230,6 +239,9 @@ private:
 
     ggml_context_ptr ctx_compute;
 
+    // training
+    ggml_opt_context_t opt_ctx = nullptr;
+
     ggml_threadpool_t threadpool       = nullptr;
     ggml_threadpool_t threadpool_batch = nullptr;
 
package/src/llama.cpp/src/llama-graph.cpp

@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
             // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
             for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t cell_id = i + kv_self->head;
-
-                //////////////////////////////////////////////
-                // TODO: this should not mutate the KV cache !
-                llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
-                    kv_cell.src = cell_id;
-                }
-
-                data[i] = kv_cell.src;
-
-                // TODO: do not mutate the KV cache
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
+                data[i] = kv_self->s_copy(i);
             }
         }
     }
@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
             // clear unused states
             for (int i = 0; i < n_kv; ++i) {
-                const uint32_t cell_id = i + kv_self->head;
-
-                //////////////////////////////////////////////
-                // TODO: this should not mutate the KV cache !
-                llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
+                data[i] = kv_self->s_mask(i);
             }
         }
     }
@@ -810,7 +782,7 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
     }
 
-    if (type_gate == LLM_FFN_PAR) {
+    if (gate && type_gate == LLM_FFN_PAR) {
         cur = ggml_mul(ctx0, cur, tmp);
         cb(cur, "ffn_gate_par", il);
     }
@@ -999,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
         //cb(inp->tokens, "inp_tokens", -1);
         ggml_set_input(inp->tokens);
+        res->t_tokens = inp->tokens;
 
         cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
@@ -1105,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -1122,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
@@ -1255,8 +1228,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
             cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
             cur = ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
+#endif
         }
 
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
@@ -1436,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        GGML_ASSERT(!kv_self->recurrent);
-
         const auto kv_head = kv_self->head;
 
         GGML_ASSERT(kv_self->size == n_ctx);
@@ -1587,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
         int32_t   n_state,
         int32_t   n_seqs) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv    = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1619,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int   il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1640,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
        const llama_ubatch & ubatch,
        int   il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
package/src/llama.cpp/src/llama-graph.h

@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -186,26 +187,26 @@ public:
 
 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
-    llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_mask() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_mask; // F32 [1, n_kv]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -297,6 +298,7 @@ class llm_graph_result_i {
 public:
     virtual ~llm_graph_result_i() = default;
 
+    virtual ggml_tensor * get_tokens() = 0;
     virtual ggml_tensor * get_logits() = 0;
     virtual ggml_tensor * get_embd() = 0;
     virtual ggml_tensor * get_embd_pooled() = 0;
@@ -311,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i {
 public:
     virtual ~llm_graph_result() = default;
 
+    ggml_tensor * get_tokens() override { return t_tokens; }
     ggml_tensor * get_logits() override { return t_logits; }
     ggml_tensor * get_embd() override { return t_embd; }
     ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -327,6 +330,7 @@ public:
     }
 
     // important graph nodes
+    ggml_tensor * t_tokens = nullptr;
     ggml_tensor * t_logits = nullptr;
     ggml_tensor * t_embd = nullptr;
     ggml_tensor * t_embd_pooled = nullptr;
@@ -350,8 +354,8 @@ struct llm_graph_params {
     const llama_cparams & cparams;
     const llama_ubatch & ubatch;
 
-    ggml_backend_sched * sched;
-    ggml_backend * backend_cpu;
+    ggml_backend_sched_t sched;
+    ggml_backend_t backend_cpu;
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -402,9 +406,9 @@ struct llm_graph_context {
 
     ggml_context * ctx0 = nullptr;
 
-    ggml_backend_sched * sched;
+    ggml_backend_sched_t sched;
 
-    ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;