cui-llama.rn 1.6.1 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +6 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +38 -5
- package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
- package/android/src/main/jni.cpp +153 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/chat.cpp +128 -106
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +41 -76
- package/cpp/common.h +23 -19
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/cpp/ggml-cpu/ggml-cpu.c +5 -13
- package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
- package/cpp/ggml-cpu/ops.cpp +107 -13
- package/cpp/ggml-cpu/vec.cpp +0 -6
- package/cpp/ggml-cpu/vec.h +16 -0
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +321 -132
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +93 -38
- package/cpp/ggml.h +21 -7
- package/cpp/gguf.cpp +33 -33
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +3 -0
- package/cpp/llama-batch.cpp +3 -1
- package/cpp/llama-chat.cpp +8 -6
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +349 -135
- package/cpp/llama-context.h +30 -3
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +150 -234
- package/cpp/llama-graph.h +52 -7
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +34 -5
- package/cpp/llama-kv-cache.cpp +662 -321
- package/cpp/llama-kv-cache.h +203 -93
- package/cpp/llama-memory.h +3 -2
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +536 -132
- package/cpp/llama-model.h +7 -1
- package/cpp/llama-sampling.cpp +18 -6
- package/cpp/llama-vocab.cpp +46 -8
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +72 -131
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +611 -47
- package/cpp/rn-llama.h +33 -3
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +7 -0
- package/ios/RNLlama.mm +77 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +105 -10
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +50 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +67 -4
- package/src/index.ts +212 -38
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
package/cpp/llama-kv-cache.cpp
CHANGED
Summary of the changes in this file:

llama_kv_cache_unified

- The constructor is rewritten. It now accepts a per-layer filter callback and explicit sliding-window-attention (SWA) parameters, and initializes its members in the initializer list:

      llama_kv_cache_unified::llama_kv_cache_unified(
              const llama_model &  model,
                layer_filter_cb && filter,
                      lm_ggml_type type_k,
                      lm_ggml_type type_v,
                              bool v_trans,
                              bool offload,
                          uint32_t kv_size,
                          uint32_t n_seq_max,
                          uint32_t n_pad,
                          uint32_t n_swa,
                    llama_swa_type swa_type) :
          model(model), hparams(model.hparams), v_trans(v_trans),
          n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

          LM_GGML_ASSERT(kv_size % n_pad == 0);

  The old "kv_size must be a multiple of padding" assert and the type_k/type_v member assignments are gone. head, size, used and cells.resize(kv_size) remain, but the K and V tensors are now created per layer: layers rejected by the filter are skipped with a "layer %3d: skipped" debug log; for every kept layer a K and a V tensor are created with lm_ggml_new_tensor_2d using that layer's n_embd_k_gqa/n_embd_v_gqa, named cache_k_l%d / cache_v_l%d, and registered via map_layer_ids[il] = layers.size() and layers.push_back({ il, k, v }). The ggml context size is computed from hparams.n_layer, each backend buffer is cleared after allocation with lm_ggml_backend_buffer_clear(buf, 0), and the final log line reports the total size together with the number of cells, layers and sequences.

- clear() now iterates the cells with an explicit for (uint32_t i = 0; i < size; ++i) loop; seq_rm() and seq_add() only gain blank lines.

- seq_pos_min() is added: it returns the smallest position stored for a sequence, or -1 when the sequence is not present. seq_pos_max() is rewritten in the same style.

- The pending-ranges bookkeeping is replaced by a recovery map of saved cells. find_slot() records the original state of every cell it is about to overwrite in recovery.cells, restore() writes those cells back (adjusting the used counter), and commit() clears the recovery data, warning when it is empty ("might indicate a bug", ref: https://github.com/ggml-org/llama.cpp/pull/13194).

- defrag_sched() computes the fragmentation as n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f, i.e. the padding term is now n_pad.

- set_full() additionally sets head = 0; the head only affects the offsets of the K/V views, not the shapes in the compute graph (ref: https://github.com/ggml-org/llama.cpp/issues/13359).

- sbatch_init(), ubatch_next() and find_slot() are collapsed to single-line signatures. find_slot() gains an optional (disabled by default) FIND_SLOT_DEBUG dump of cell occupancy, drops the n_seqs/n_seq_tokens locals, and computes the attended range as n = std::min(size, std::max(n_pad, LM_GGML_PAD(cell_max(), n_pad))).

- New accessors and graph helpers: get_can_shift(), get_n(), get_size(), get_k(), get_v(), cpy_k() and cpy_v() build the per-layer K/V views, handling the transposed V layout that is used when flash attention is off.

- New prune_swa(seq_id, pmin, pmax): removes cells that fall outside the sliding window for the given sequence and warns when a partial SWA cache is detected ("possible loss of information").

- The input setters move into the cache: set_input_kq_mask() fills the causal or non-causal mask including ALiBi and SWA masking (see https://github.com/ggml-org/llama.cpp/pull/12615 for a visualization of the mask), set_input_k_shift() writes the per-cell deltas, and set_input_pos_bucket() fills the relative-position buckets. llm_graph_input_k_shift::set_input() now simply delegates to kv_self->set_input_k_shift().

- New is_masked_swa(p0, p1):

      bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
          if (p0 < 0) {
              return true;
          }

          switch (swa_type) {
              case LLAMA_SWA_TYPE_NONE:
                  {
                  } break;
              case LLAMA_SWA_TYPE_STANDARD:
                  {
                      if (p1 - p0 >= (int32_t) n_swa) {
                          return true;
                      }
                  } break;
              case LLAMA_SWA_TYPE_CHUNKED:
                  {
                      const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;

                      if (p0 < pos_chunk_start) {
                          return true;
                      }
                  } break;
          }

          return false;
      }

- size_k_bytes(), size_v_bytes(), defrag_prepare(), build_graph_shift(), build_graph_defrag() and the state read/write functions now iterate the layers vector and address tensors as layer.k / layer.v. build_graph_shift() obtains the per-layer RoPE parameters from model.get_rope_freq_base(cparams, il), model.get_rope_freq_scale(cparams, il) and model.get_rope_factors(cparams, il).

- state_read_meta() sets batch.n_seq_id[i] = 1 and batch.seq_id[i] = &dest_seq_id for every restored cell instead of using a single shared sequence entry, and validates seq_id against n_seq_max; state_read_data() checks the stored layer count against layers.size().

llama_kv_cache_unified_iswa (new)

- A new cache type that manages two unified caches: a full-size base cache for the non-SWA layers and a smaller cache for the SWA layers, split with per-layer filters on model.hparams.is_swa(il). The SWA cache size is std::min(kv_size, LM_GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad)); when swa_full is set, the SWA cache is made as large as the base cache and pruning is disabled (ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055).

- clear(), seq_rm(), seq_cp(), seq_keep(), seq_add(), seq_div(), restore() and commit() forward to both caches; seq_pos_min() and seq_pos_max() query the SWA cache (the base cache is a superset of it). commit() additionally slides the attention window: when pruning is enabled it calls kv_swa->prune_swa(seq_id, pmin, pmax) for every pending sequence.
kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
|
1724
|
+
}
|
1725
|
+
|
1726
|
+
}
|
1727
|
+
|
1728
|
+
pending.clear();
|
1729
|
+
}
|
1730
|
+
|
1731
|
+
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
|
1732
|
+
bool res = true;
|
1733
|
+
|
1734
|
+
res = res & kv_base->update(lctx);
|
1735
|
+
res = res & kv_swa ->update(lctx);
|
1736
|
+
|
1737
|
+
return res;
|
1738
|
+
}
|
1739
|
+
|
1740
|
+
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
|
1741
|
+
kv_base->defrag_sched(thold);
|
1742
|
+
kv_swa ->defrag_sched(thold);
|
1743
|
+
}
|
1744
|
+
|
1745
|
+
void llama_kv_cache_unified_iswa::set_full() {
|
1746
|
+
kv_base->set_full();
|
1747
|
+
kv_swa ->set_full();
|
1748
|
+
}
|
1749
|
+
|
1750
|
+
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
|
1751
|
+
pending.clear();
|
1752
|
+
|
1753
|
+
if (do_prune) {
|
1754
|
+
for (int i = 0; i < batch.n_tokens; ++i) {
|
1755
|
+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
1756
|
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
1757
|
+
const llama_pos pos = batch.pos[i];
|
1758
|
+
|
1759
|
+
if (pending.pos.find(seq_id) == pending.pos.end()) {
|
1760
|
+
pending.pos[seq_id].pmin = pos;
|
1761
|
+
pending.pos[seq_id].pmax = pos;
|
1762
|
+
} else {
|
1763
|
+
pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
|
1764
|
+
pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
|
1765
|
+
}
|
1766
|
+
}
|
1767
|
+
}
|
1768
|
+
}
|
1769
|
+
|
1770
|
+
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
1771
|
+
}
|
1772
|
+
|
1773
|
+
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
|
1774
|
+
LM_GGML_UNUSED(embd_pooled);
|
1775
|
+
return sbatch.split_simple(n_ubatch);
|
1776
|
+
}
|
1777
|
+
|
1778
|
+
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
|
1779
|
+
bool res = true;
|
1780
|
+
|
1781
|
+
res = res & kv_base->find_slot(batch);
|
1782
|
+
res = res & kv_swa ->find_slot(batch);
|
1783
|
+
|
1784
|
+
return res;
|
1785
|
+
}
|
1786
|
+
|
1787
|
+
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
1788
|
+
return kv_base->get_size() == kv_swa->get_size();
|
1789
|
+
}
|
1790
|
+
|
1791
|
+
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
1792
|
+
kv_base->state_write(io, seq_id);
|
1793
|
+
kv_swa ->state_write(io, seq_id);
|
1794
|
+
}
|
1795
|
+
|
1796
|
+
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
1797
|
+
kv_base->state_read(io, seq_id);
|
1798
|
+
kv_swa ->state_read(io, seq_id);
|
1799
|
+
}
|
1800
|
+
|
1801
|
+
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
|
1802
|
+
return kv_base.get();
|
1803
|
+
}
|
1804
|
+
|
1805
|
+
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
|
1806
|
+
return kv_swa.get();
|
1807
|
+
}
|
1808
|
+
|
1360
1809
|
//
|
1361
1810
|
// llama_kv_cache_recurrent
|
1362
1811
|
//
|
@@ -1366,19 +1815,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
1366
1815
|
lm_ggml_type type_k,
|
1367
1816
|
lm_ggml_type type_v,
|
1368
1817
|
bool offload,
|
1369
|
-
uint32_t kv_size
|
1818
|
+
uint32_t kv_size,
|
1819
|
+
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
1370
1820
|
const int32_t n_layer = hparams.n_layer;
|
1371
1821
|
|
1372
|
-
LLAMA_LOG_INFO("%s: kv_size = %
|
1373
|
-
__func__, kv_size, lm_ggml_type_name(type_k), lm_ggml_type_name(type_v), n_layer);
|
1822
|
+
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
|
1823
|
+
__func__, kv_size, n_seq_max, lm_ggml_type_name(type_k), lm_ggml_type_name(type_v), n_layer);
|
1374
1824
|
|
1375
1825
|
head = 0;
|
1376
1826
|
size = kv_size;
|
1377
1827
|
used = 0;
|
1378
1828
|
|
1379
|
-
this->type_k = type_k;
|
1380
|
-
this->type_v = type_v;
|
1381
|
-
|
1382
1829
|
cells.clear();
|
1383
1830
|
cells.resize(kv_size);
|
1384
1831
|
|
@@ -1676,8 +2123,24 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
     }
 }

+llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    llama_pos result = std::numeric_limits<llama_pos>::max();
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id)) {
+            result = std::min(result, cells[i].pos);
+        }
+    }
+
+    if (result == std::numeric_limits<llama_pos>::max()) {
+        result = -1;
+    }
+
+    return result;
+}
+
 llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
-    llama_pos result =
+    llama_pos result = -1;

     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].has_seq_id(seq_id)) {
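Both position scans above use -1 as the "sequence not present" sentinel: seq_pos_min starts from the largest representable position and collapses to -1 when no cell belongs to the sequence, while seq_pos_max now starts from -1 directly. A self-contained sketch of the same convention (the type alias and helper are illustrative only):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <limits>
    #include <vector>

    using llama_pos = int32_t; // matches the llama_pos typedef in llama.h

    // Illustrative only: same sentinel convention as seq_pos_min() above --
    // scan every position that belongs to a sequence and return -1 when none do.
    static llama_pos pos_min_of(const std::vector<llama_pos> & positions) {
        llama_pos result = std::numeric_limits<llama_pos>::max();
        for (const llama_pos p : positions) {
            result = std::min(result, p);
        }
        return result == std::numeric_limits<llama_pos>::max() ? -1 : result;
    }

    int main() {
        std::printf("%d\n", pos_min_of({7, 3, 12})); // prints 3
        std::printf("%d\n", pos_min_of({}));         // prints -1: sequence has no cells
        return 0;
    }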
@@ -1700,8 +2163,8 @@ void llama_kv_cache_recurrent::commit() {
     pending.ranges.clear();
 }

-bool llama_kv_cache_recurrent::update(llama_context &
-    LM_GGML_UNUSED(
+bool llama_kv_cache_recurrent::update(llama_context & ctx) {
+    LM_GGML_UNUSED(ctx);
     return false;
 }

@@ -1712,6 +2175,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {

 void llama_kv_cache_recurrent::set_full() {
     n = size;
+    head = 0;
 }

 llama_sbatch llama_kv_cache_recurrent::sbatch_init(
@@ -1761,7 +2225,7 @@ bool llama_kv_cache_recurrent::find_slot(
             if (seq_id < 0 || (uint32_t) seq_id >= size) {
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
                 return false;
             }
             if (j > 0) {
@@ -1904,29 +2368,6 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }

-int32_t llama_kv_cache_recurrent::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_recurrent::get_used_cells() const {
-    return used;
-}
-
-llama_pos llama_kv_cache_recurrent::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 bool llama_kv_cache_recurrent::get_can_shift() const {
     return false;
 }
@@ -2055,6 +2496,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
     io.read_to(&cell_count, sizeof(cell_count));

     bool res = true;
+
     res = res && state_read_meta(io, cell_count, seq_id);
     res = res && state_read_data(io, cell_count);

@@ -2383,104 +2825,3 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     return true;
 }
-
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
-    llama_kv_cache_view result = {
-        /*.n_cells = */ 0,
-        /*.n_seq_max = */ n_seq_max,
-        /*.token_count = */ 0,
-        /*.used_cells = */ kv.get_used_cells(),
-        /*.max_contiguous = */ 0,
-        /*.max_contiguous_idx = */ -1,
-        /*.cells = */ nullptr,
-        /*.cells_sequences = */ nullptr,
-    };
-
-    return result;
-}
-
-void llama_kv_cache_view_free(llama_kv_cache_view * view) {
-    if (view->cells != nullptr) {
-        free(view->cells);
-        view->cells = nullptr;
-    }
-    if (view->cells_sequences != nullptr) {
-        free(view->cells_sequences);
-        view->cells_sequences = nullptr;
-    }
-}
-
-void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
-    // TODO: rework this in the future, for now quick hack
-    const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
-    if (kvu == nullptr) {
-        LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
-        return;
-    }
-
-    if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
-        view->n_cells = int32_t(kvu->size);
-        void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
-        LM_GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
-        view->cells = (llama_kv_cache_view_cell *)p;
-        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
-        LM_GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
-        view->cells_sequences = (llama_seq_id *)p;
-    }
-
-    const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
-    llama_kv_cache_view_cell * c_curr = view->cells;
-    llama_seq_id * cs_curr = view->cells_sequences;
-    int32_t used_cells = 0;
-    int32_t token_count = 0;
-    int32_t curr_contig_idx = -1;
-    uint32_t max_contig = 0;
-    int32_t max_contig_idx = -1;
-
-    for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
-        const size_t curr_size = kv_cells[i].seq_id.size();
-        token_count += curr_size;
-        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
-
-        if (curr_size > 0) {
-            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
-                max_contig = i - curr_contig_idx;
-                max_contig_idx = curr_contig_idx;
-            }
-            curr_contig_idx = -1;
-        } else if (curr_contig_idx < 0) {
-            curr_contig_idx = i;
-        }
-
-        int seq_idx = 0;
-        for (const llama_seq_id it : kv_cells[i].seq_id) {
-            if (seq_idx >= view->n_seq_max) {
-                break;
-            }
-            cs_curr[seq_idx] = it;
-            seq_idx++;
-        }
-        if (seq_idx != 0) {
-            used_cells++;
-        }
-        for (; seq_idx < view->n_seq_max; seq_idx++) {
-            cs_curr[seq_idx] = -1;
-        }
-    }
-    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
-        max_contig_idx = curr_contig_idx;
-        max_contig = kv_cells.size() - curr_contig_idx;
-    }
-    view->max_contiguous = max_contig;
-    view->max_contiguous_idx = max_contig_idx;
-    view->token_count = token_count;
-    view->used_cells = used_cells;
-    if (uint32_t(used_cells) != kvu->used) {
-        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-                __func__, kvu->used, used_cells);
-    }
-}