@novastera-oss/llamarn 0.2.5 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/RNLlamaCpp.podspec +3 -2
- package/android/CMakeLists.txt +6 -3
- package/android/src/main/cpp/include/llama.h +12 -8
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +46 -65
- package/cpp/LlamaCppModel.h +5 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/README.md +1 -0
- package/cpp/llama.cpp/common/CMakeLists.txt +5 -8
- package/cpp/llama.cpp/common/arg.cpp +8 -6
- package/cpp/llama.cpp/common/chat-parser.cpp +4 -3
- package/cpp/llama.cpp/common/chat-parser.h +2 -1
- package/cpp/llama.cpp/common/chat.cpp +4 -4
- package/cpp/llama.cpp/common/common.cpp +2 -0
- package/cpp/llama.cpp/common/json-partial.cpp +5 -4
- package/cpp/llama.cpp/common/json-partial.h +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +31 -28
- package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +23 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +19 -8
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml.c +9 -2
- package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
- package/cpp/llama.cpp/include/llama.h +12 -8
- package/cpp/llama.cpp/src/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +19 -12
- package/cpp/llama.cpp/src/llama-batch.h +15 -10
- package/cpp/llama.cpp/src/llama-context.cpp +226 -151
- package/cpp/llama.cpp/src/llama-context.h +25 -8
- package/cpp/llama.cpp/src/llama-graph.cpp +50 -47
- package/cpp/llama.cpp/src/llama-graph.h +25 -24
- package/cpp/llama.cpp/src/llama-kv-cache-recurrent.cpp +1132 -0
- package/cpp/llama.cpp/src/llama-kv-cache-recurrent.h +191 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +249 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +136 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1717 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +278 -0
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2746
- package/cpp/llama.cpp/src/llama-kv-cache.h +14 -472
- package/cpp/llama.cpp/src/llama-kv-cells.h +37 -6
- package/cpp/llama.cpp/src/llama-memory.h +44 -0
- package/cpp/llama.cpp/src/llama-model.cpp +23 -16
- package/cpp/llama.cpp/src/llama-vocab.cpp +7 -2
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
- package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
- package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
- package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
- package/cpp/rn-completion.cpp +63 -8
- package/cpp/rn-utils.hpp +8 -1
- package/ios/include/common/minja/chat-template.hpp +1 -1
- package/ios/include/common/minja/minja.hpp +1 -1
- package/ios/include/json-schema-to-grammar.h +4 -4
- package/ios/include/llama.h +12 -8
- package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
- package/ios/libs/llama.xcframework/Info.plist +22 -22
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4617
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3557
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3624 -3559
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4616
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4637
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3556
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4725 -4653
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4746 -4674
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3652 -3587
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
|
@@ -18,6 +18,9 @@ struct llama_kv_cache;
|
|
|
18
18
|
class llama_io_read_i;
|
|
19
19
|
class llama_io_write_i;
|
|
20
20
|
|
|
21
|
+
class llama_memory_i;
|
|
22
|
+
class llama_memory_state_i;
|
|
23
|
+
|
|
21
24
|
struct llama_context {
|
|
22
25
|
// init scheduler and compute buffers, reserve worst-case graphs
|
|
23
26
|
llama_context(
|
|
@@ -47,7 +50,9 @@ struct llama_context {
|
|
|
47
50
|
llama_kv_cache * get_kv_self();
|
|
48
51
|
const llama_kv_cache * get_kv_self() const;
|
|
49
52
|
|
|
50
|
-
|
|
53
|
+
// return true if the KV cache was updated
|
|
54
|
+
// TODO: remove
|
|
55
|
+
bool kv_self_update();
|
|
51
56
|
|
|
52
57
|
enum llama_pooling_type pooling_type() const;
|
|
53
58
|
|
|
@@ -88,6 +93,16 @@ struct llama_context {
|
|
|
88
93
|
int32_t il_start,
|
|
89
94
|
int32_t il_end);
|
|
90
95
|
|
|
96
|
+
// process a single ubatch with a specific graph type
|
|
97
|
+
// if memory_state is provided, it will be applied first to the context's memory
|
|
98
|
+
// ret contains the status of the graph computation
|
|
99
|
+
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
|
100
|
+
llm_graph_result_ptr process_ubatch(
|
|
101
|
+
const llama_ubatch & ubatch,
|
|
102
|
+
llm_graph_type gtype,
|
|
103
|
+
llama_memory_state_i * mstate,
|
|
104
|
+
ggml_status & ret);
|
|
105
|
+
|
|
91
106
|
int encode(llama_batch & inp_batch);
|
|
92
107
|
int decode(llama_batch & inp_batch);
|
|
93
108
|
|
|
@@ -180,16 +195,18 @@ public:
|
|
|
180
195
|
ggml_cgraph * graph_init();
|
|
181
196
|
|
|
182
197
|
// returns the result of ggml_backend_sched_graph_compute_async execution
|
|
183
|
-
ggml_status graph_compute(
|
|
184
|
-
|
|
185
|
-
|
|
198
|
+
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
|
199
|
+
|
|
200
|
+
// reserve a graph with a dummy ubatch of the specified size
|
|
201
|
+
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
|
|
186
202
|
|
|
187
203
|
private:
|
|
188
204
|
llm_graph_result_ptr graph_build(
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
205
|
+
ggml_context * ctx,
|
|
206
|
+
ggml_cgraph * gf,
|
|
207
|
+
const llama_ubatch & ubatch,
|
|
208
|
+
llm_graph_type gtype,
|
|
209
|
+
const llama_memory_state_i * mstate);
|
|
193
210
|
|
|
194
211
|
llm_graph_cb graph_get_cb() const;
|
|
195
212
|
|
|
@@ -3,7 +3,10 @@
|
|
|
3
3
|
#include "llama-impl.h"
|
|
4
4
|
#include "llama-batch.h"
|
|
5
5
|
#include "llama-cparams.h"
|
|
6
|
-
|
|
6
|
+
|
|
7
|
+
#include "llama-kv-cache-unified.h"
|
|
8
|
+
#include "llama-kv-cache-unified-iswa.h"
|
|
9
|
+
#include "llama-kv-cache-recurrent.h"
|
|
7
10
|
|
|
8
11
|
#include <cassert>
|
|
9
12
|
#include <cmath>
|
|
@@ -83,7 +86,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
|
83
86
|
|
|
84
87
|
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
|
85
88
|
if (pos_bucket) {
|
|
86
|
-
|
|
89
|
+
kv_state->set_input_pos_bucket(pos_bucket, ubatch);
|
|
87
90
|
}
|
|
88
91
|
}
|
|
89
92
|
|
|
@@ -234,7 +237,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
|
234
237
|
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
|
235
238
|
GGML_UNUSED(ubatch);
|
|
236
239
|
|
|
237
|
-
const int64_t n_kv =
|
|
240
|
+
const int64_t n_kv = kv_state->get_n_kv();
|
|
238
241
|
|
|
239
242
|
if (s_copy) {
|
|
240
243
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
@@ -242,7 +245,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
|
|
242
245
|
|
|
243
246
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
244
247
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
|
245
|
-
data[i] =
|
|
248
|
+
data[i] = kv_state->s_copy(i);
|
|
246
249
|
}
|
|
247
250
|
}
|
|
248
251
|
}
|
|
@@ -250,7 +253,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
|
|
250
253
|
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
|
251
254
|
GGML_UNUSED(ubatch);
|
|
252
255
|
|
|
253
|
-
const int64_t n_kv =
|
|
256
|
+
const int64_t n_kv = kv_state->get_n_kv();
|
|
254
257
|
|
|
255
258
|
if (s_mask) {
|
|
256
259
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
|
@@ -258,7 +261,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
|
|
258
261
|
|
|
259
262
|
// clear unused states
|
|
260
263
|
for (int i = 0; i < n_kv; ++i) {
|
|
261
|
-
data[i] =
|
|
264
|
+
data[i] = kv_state->s_mask(i);
|
|
262
265
|
}
|
|
263
266
|
}
|
|
264
267
|
}
|
|
@@ -362,17 +365,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
362
365
|
|
|
363
366
|
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
|
364
367
|
if (self_kq_mask) {
|
|
365
|
-
|
|
368
|
+
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
|
366
369
|
}
|
|
367
370
|
}
|
|
368
371
|
|
|
369
372
|
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
|
370
373
|
if (self_kq_mask) {
|
|
371
|
-
|
|
374
|
+
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
|
372
375
|
}
|
|
373
376
|
|
|
374
377
|
if (self_kq_mask_swa) {
|
|
375
|
-
|
|
378
|
+
kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
|
376
379
|
}
|
|
377
380
|
}
|
|
378
381
|
|
|
@@ -448,7 +451,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
|
448
451
|
backend_cpu (params.backend_cpu),
|
|
449
452
|
cvec (params.cvec),
|
|
450
453
|
loras (params.loras),
|
|
451
|
-
|
|
454
|
+
mstate (params.mstate),
|
|
452
455
|
cross (params.cross),
|
|
453
456
|
cb_func (params.cb),
|
|
454
457
|
res (std::make_unique<llm_graph_result>()) {
|
|
@@ -954,11 +957,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
|
|
954
957
|
}
|
|
955
958
|
|
|
956
959
|
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
|
957
|
-
const
|
|
960
|
+
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
958
961
|
|
|
959
|
-
auto inp = std::make_unique<llm_graph_input_s_copy>(
|
|
962
|
+
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
|
|
960
963
|
|
|
961
|
-
const auto n_kv =
|
|
964
|
+
const auto n_kv = kv_state->get_n_kv();
|
|
962
965
|
|
|
963
966
|
auto & cur = inp->s_copy;
|
|
964
967
|
|
|
@@ -971,11 +974,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
|
|
971
974
|
}
|
|
972
975
|
|
|
973
976
|
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
|
974
|
-
const
|
|
977
|
+
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
975
978
|
|
|
976
|
-
auto inp = std::make_unique<llm_graph_input_s_mask>(
|
|
979
|
+
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
|
|
977
980
|
|
|
978
|
-
const auto n_kv =
|
|
981
|
+
const auto n_kv = kv_state->get_n_kv();
|
|
979
982
|
|
|
980
983
|
auto & cur = inp->s_mask;
|
|
981
984
|
|
|
@@ -1025,11 +1028,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|
|
1025
1028
|
}
|
|
1026
1029
|
|
|
1027
1030
|
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
|
1028
|
-
const
|
|
1031
|
+
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
|
1029
1032
|
|
|
1030
|
-
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams,
|
|
1033
|
+
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
|
|
1031
1034
|
|
|
1032
|
-
const auto n_kv =
|
|
1035
|
+
const auto n_kv = kv_state->get_n_kv();
|
|
1033
1036
|
|
|
1034
1037
|
auto & cur = inp->pos_bucket;
|
|
1035
1038
|
|
|
@@ -1231,14 +1234,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1231
1234
|
}
|
|
1232
1235
|
|
|
1233
1236
|
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
|
1234
|
-
const
|
|
1237
|
+
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
|
1235
1238
|
|
|
1236
|
-
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams,
|
|
1239
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
|
|
1237
1240
|
|
|
1238
1241
|
{
|
|
1239
1242
|
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
|
1240
1243
|
|
|
1241
|
-
const auto n_kv =
|
|
1244
|
+
const auto n_kv = kv_state->get_n_kv();
|
|
1242
1245
|
|
|
1243
1246
|
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
1244
1247
|
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
@@ -1268,19 +1271,19 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1268
1271
|
ggml_build_forward_expand(gf, k_cur);
|
|
1269
1272
|
ggml_build_forward_expand(gf, v_cur);
|
|
1270
1273
|
|
|
1271
|
-
const
|
|
1274
|
+
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
|
1272
1275
|
|
|
1273
1276
|
// store to KV cache
|
|
1274
1277
|
{
|
|
1275
|
-
ggml_build_forward_expand(gf,
|
|
1276
|
-
ggml_build_forward_expand(gf,
|
|
1278
|
+
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
|
1279
|
+
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
|
1277
1280
|
}
|
|
1278
1281
|
|
|
1279
1282
|
const auto & kq_mask = inp->get_kq_mask();
|
|
1280
1283
|
|
|
1281
1284
|
ggml_tensor * q = q_cur;
|
|
1282
|
-
ggml_tensor * k =
|
|
1283
|
-
ggml_tensor * v =
|
|
1285
|
+
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
|
1286
|
+
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
|
1284
1287
|
|
|
1285
1288
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1286
1289
|
cb(cur, "kqv_out", il);
|
|
@@ -1301,12 +1304,12 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1301
1304
|
}
|
|
1302
1305
|
|
|
1303
1306
|
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
|
1304
|
-
const
|
|
1307
|
+
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
|
1305
1308
|
|
|
1306
|
-
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams,
|
|
1309
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
|
|
1307
1310
|
|
|
1308
1311
|
{
|
|
1309
|
-
const auto n_kv =
|
|
1312
|
+
const auto n_kv = kv_state->get_base()->get_n_kv();
|
|
1310
1313
|
|
|
1311
1314
|
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
1312
1315
|
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
@@ -1318,7 +1321,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
|
1318
1321
|
{
|
|
1319
1322
|
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
|
1320
1323
|
|
|
1321
|
-
const auto n_kv =
|
|
1324
|
+
const auto n_kv = kv_state->get_swa()->get_n_kv();
|
|
1322
1325
|
|
|
1323
1326
|
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
1324
1327
|
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
|
@@ -1348,23 +1351,23 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1348
1351
|
ggml_build_forward_expand(gf, k_cur);
|
|
1349
1352
|
ggml_build_forward_expand(gf, v_cur);
|
|
1350
1353
|
|
|
1351
|
-
const
|
|
1354
|
+
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
|
1352
1355
|
|
|
1353
|
-
const
|
|
1356
|
+
const bool is_swa = hparams.is_swa(il);
|
|
1354
1357
|
|
|
1355
|
-
const auto *
|
|
1358
|
+
const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
|
|
1356
1359
|
|
|
1357
1360
|
// store to KV cache
|
|
1358
1361
|
{
|
|
1359
|
-
ggml_build_forward_expand(gf,
|
|
1360
|
-
ggml_build_forward_expand(gf,
|
|
1362
|
+
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
|
1363
|
+
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
|
1361
1364
|
}
|
|
1362
1365
|
|
|
1363
1366
|
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
|
1364
1367
|
|
|
1365
1368
|
ggml_tensor * q = q_cur;
|
|
1366
|
-
ggml_tensor * k =
|
|
1367
|
-
ggml_tensor * v =
|
|
1369
|
+
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
|
1370
|
+
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
|
1368
1371
|
|
|
1369
1372
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1370
1373
|
cb(cur, "kqv_out", il);
|
|
@@ -1446,12 +1449,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
|
|
|
1446
1449
|
ggml_tensor * state_mask,
|
|
1447
1450
|
int32_t n_state,
|
|
1448
1451
|
int32_t n_seqs) const {
|
|
1449
|
-
const
|
|
1452
|
+
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
1450
1453
|
|
|
1451
|
-
const auto n_kv =
|
|
1452
|
-
const auto kv_head =
|
|
1454
|
+
const auto n_kv = kv_state->get_n_kv();
|
|
1455
|
+
const auto kv_head = kv_state->get_head();
|
|
1453
1456
|
|
|
1454
|
-
ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state,
|
|
1457
|
+
ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
|
|
1455
1458
|
|
|
1456
1459
|
// copy states
|
|
1457
1460
|
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
|
@@ -1478,13 +1481,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
|
1478
1481
|
ggml_tensor * state_mask,
|
|
1479
1482
|
const llama_ubatch & ubatch,
|
|
1480
1483
|
int il) const {
|
|
1481
|
-
const
|
|
1484
|
+
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
1482
1485
|
|
|
1483
1486
|
const auto token_shift_count = hparams.token_shift_count;
|
|
1484
1487
|
|
|
1485
1488
|
const int64_t n_seqs = ubatch.n_seqs;
|
|
1486
1489
|
|
|
1487
|
-
ggml_tensor * token_shift_all =
|
|
1490
|
+
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
|
|
1488
1491
|
|
|
1489
1492
|
ggml_tensor * token_shift = build_copy_mask_state(
|
|
1490
1493
|
gf, token_shift_all, state_copy, state_mask,
|
|
@@ -1499,19 +1502,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
|
1499
1502
|
ggml_tensor * token_shift,
|
|
1500
1503
|
const llama_ubatch & ubatch,
|
|
1501
1504
|
int il) const {
|
|
1502
|
-
const
|
|
1505
|
+
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
1503
1506
|
|
|
1504
1507
|
const auto token_shift_count = hparams.token_shift_count;
|
|
1505
1508
|
const auto n_embd = hparams.n_embd;
|
|
1506
1509
|
|
|
1507
1510
|
const int64_t n_seqs = ubatch.n_seqs;
|
|
1508
1511
|
|
|
1509
|
-
const auto kv_head =
|
|
1512
|
+
const auto kv_head = kv_state->get_head();
|
|
1510
1513
|
|
|
1511
1514
|
return ggml_cpy(
|
|
1512
1515
|
ctx0,
|
|
1513
1516
|
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
|
1514
|
-
ggml_view_1d(ctx0,
|
|
1517
|
+
ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
|
|
1515
1518
|
);
|
|
1516
1519
|
}
|
|
1517
1520
|
|
|
@@ -17,10 +17,11 @@ struct ggml_tensor;
|
|
|
17
17
|
struct llama_ubatch;
|
|
18
18
|
struct llama_cparams;
|
|
19
19
|
|
|
20
|
-
class
|
|
21
|
-
|
|
22
|
-
class
|
|
23
|
-
class
|
|
20
|
+
class llama_memory_state_i;
|
|
21
|
+
|
|
22
|
+
class llama_kv_cache_unified_state;
|
|
23
|
+
class llama_kv_cache_unified_iswa_state;
|
|
24
|
+
class llama_kv_cache_recurrent_state;
|
|
24
25
|
|
|
25
26
|
// certain models (typically multi-modal) can produce different types of graphs
|
|
26
27
|
enum llm_graph_type {
|
|
@@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
|
|
133
134
|
public:
|
|
134
135
|
llm_graph_input_pos_bucket_kv(
|
|
135
136
|
const llama_hparams & hparams,
|
|
136
|
-
const
|
|
137
|
+
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
|
|
137
138
|
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
|
138
139
|
|
|
139
140
|
void set_input(const llama_ubatch * ubatch) override;
|
|
@@ -141,7 +142,7 @@ public:
|
|
|
141
142
|
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
|
142
143
|
|
|
143
144
|
const llama_hparams & hparams;
|
|
144
|
-
const
|
|
145
|
+
const llama_kv_cache_unified_state * kv_state;
|
|
145
146
|
};
|
|
146
147
|
|
|
147
148
|
class llm_graph_input_out_ids : public llm_graph_input_i {
|
|
@@ -188,26 +189,26 @@ public:
|
|
|
188
189
|
|
|
189
190
|
class llm_graph_input_s_copy : public llm_graph_input_i {
|
|
190
191
|
public:
|
|
191
|
-
llm_graph_input_s_copy(const
|
|
192
|
+
llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
|
192
193
|
virtual ~llm_graph_input_s_copy() = default;
|
|
193
194
|
|
|
194
195
|
void set_input(const llama_ubatch * ubatch) override;
|
|
195
196
|
|
|
196
197
|
ggml_tensor * s_copy; // I32 [kv_size]
|
|
197
198
|
|
|
198
|
-
const
|
|
199
|
+
const llama_kv_cache_recurrent_state * kv_state;
|
|
199
200
|
};
|
|
200
201
|
|
|
201
202
|
class llm_graph_input_s_mask : public llm_graph_input_i {
|
|
202
203
|
public:
|
|
203
|
-
llm_graph_input_s_mask(const
|
|
204
|
+
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
|
204
205
|
virtual ~llm_graph_input_s_mask() = default;
|
|
205
206
|
|
|
206
207
|
void set_input(const llama_ubatch * ubatch) override;
|
|
207
208
|
|
|
208
209
|
ggml_tensor * s_mask; // F32 [1, n_kv]
|
|
209
210
|
|
|
210
|
-
const
|
|
211
|
+
const llama_kv_cache_recurrent_state * kv_state;
|
|
211
212
|
};
|
|
212
213
|
|
|
213
214
|
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
|
@@ -247,10 +248,10 @@ public:
|
|
|
247
248
|
llm_graph_input_attn_kv_unified(
|
|
248
249
|
const llama_hparams & hparams,
|
|
249
250
|
const llama_cparams & cparams,
|
|
250
|
-
const
|
|
251
|
+
const llama_kv_cache_unified_state * kv_state) :
|
|
251
252
|
hparams(hparams),
|
|
252
253
|
cparams(cparams),
|
|
253
|
-
|
|
254
|
+
kv_state(kv_state) {
|
|
254
255
|
}
|
|
255
256
|
~llm_graph_input_attn_kv_unified() = default;
|
|
256
257
|
|
|
@@ -264,7 +265,7 @@ public:
|
|
|
264
265
|
const llama_hparams & hparams;
|
|
265
266
|
const llama_cparams & cparams;
|
|
266
267
|
|
|
267
|
-
const
|
|
268
|
+
const llama_kv_cache_unified_state * kv_state;
|
|
268
269
|
};
|
|
269
270
|
|
|
270
271
|
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
|
@@ -272,10 +273,10 @@ public:
|
|
|
272
273
|
llm_graph_input_attn_kv_unified_iswa(
|
|
273
274
|
const llama_hparams & hparams,
|
|
274
275
|
const llama_cparams & cparams,
|
|
275
|
-
const
|
|
276
|
+
const llama_kv_cache_unified_iswa_state * kv_state) :
|
|
276
277
|
hparams(hparams),
|
|
277
278
|
cparams(cparams),
|
|
278
|
-
|
|
279
|
+
kv_state(kv_state) {
|
|
279
280
|
}
|
|
280
281
|
~llm_graph_input_attn_kv_unified_iswa() = default;
|
|
281
282
|
|
|
@@ -292,7 +293,7 @@ public:
|
|
|
292
293
|
const llama_hparams & hparams;
|
|
293
294
|
const llama_cparams & cparams;
|
|
294
295
|
|
|
295
|
-
const
|
|
296
|
+
const llama_kv_cache_unified_iswa_state * kv_state;
|
|
296
297
|
};
|
|
297
298
|
|
|
298
299
|
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
|
@@ -383,10 +384,10 @@ struct llm_graph_params {
|
|
|
383
384
|
ggml_backend_sched_t sched;
|
|
384
385
|
ggml_backend_t backend_cpu;
|
|
385
386
|
|
|
386
|
-
const llama_adapter_cvec
|
|
387
|
-
const llama_adapter_loras
|
|
388
|
-
const
|
|
389
|
-
const llama_cross
|
|
387
|
+
const llama_adapter_cvec * cvec;
|
|
388
|
+
const llama_adapter_loras * loras;
|
|
389
|
+
const llama_memory_state_i * mstate;
|
|
390
|
+
const llama_cross * cross;
|
|
390
391
|
|
|
391
392
|
int32_t n_outputs;
|
|
392
393
|
|
|
@@ -435,10 +436,10 @@ struct llm_graph_context {
|
|
|
435
436
|
|
|
436
437
|
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
|
437
438
|
|
|
438
|
-
const llama_adapter_cvec
|
|
439
|
-
const llama_adapter_loras
|
|
440
|
-
const
|
|
441
|
-
const llama_cross
|
|
439
|
+
const llama_adapter_cvec * cvec;
|
|
440
|
+
const llama_adapter_loras * loras;
|
|
441
|
+
const llama_memory_state_i * mstate;
|
|
442
|
+
const llama_cross * cross;
|
|
442
443
|
|
|
443
444
|
const llm_graph_cb & cb_func;
|
|
444
445
|
|