@novastera-oss/llamarn 0.2.7 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/common/arg.cpp +7 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +1 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
- package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
- package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -3
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
- package/cpp/llama.cpp/src/llama-batch.h +98 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
- package/cpp/llama.cpp/src/llama-graph.h +44 -32
- package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-hparams.h +8 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
- package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.h +18 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
- package/cpp/llama.cpp/src/llama-model.h +22 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/common.h +1 -0
- package/ios/include/llama.h +8 -3
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
|
@@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
|
33
33
|
|
|
34
34
|
GGML_ASSERT(kv_size % n_pad == 0);
|
|
35
35
|
|
|
36
|
+
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
|
|
37
|
+
auto n_layer_cache = hparams.n_layer;
|
|
38
|
+
if (model.arch == LLM_ARCH_GEMMA3N) {
|
|
39
|
+
n_layer_cache = 20;
|
|
40
|
+
}
|
|
41
|
+
|
|
36
42
|
// create a context for each buffer type
|
|
37
43
|
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
38
44
|
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
|
39
45
|
auto it = ctx_map.find(buft);
|
|
40
46
|
if (it == ctx_map.end()) {
|
|
41
47
|
ggml_init_params params = {
|
|
42
|
-
/*.mem_size =*/ size_t(2u*
|
|
48
|
+
/*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
|
|
43
49
|
/*.mem_buffer =*/ NULL,
|
|
44
50
|
/*.no_alloc =*/ true,
|
|
45
51
|
};
|
|
@@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
|
62
68
|
|
|
63
69
|
cells.resize(kv_size);
|
|
64
70
|
|
|
65
|
-
for (uint32_t il = 0; il <
|
|
71
|
+
for (uint32_t il = 0; il < n_layer_cache; il++) {
|
|
66
72
|
if (filter && !filter(il)) {
|
|
67
73
|
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
|
68
74
|
continue;
|
|
@@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
|
102
108
|
layers.push_back({ il, k, v });
|
|
103
109
|
}
|
|
104
110
|
|
|
111
|
+
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
|
|
112
|
+
if (model.arch == LLM_ARCH_GEMMA3N) {
|
|
113
|
+
LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
|
|
114
|
+
|
|
115
|
+
for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
|
|
116
|
+
if (filter && !filter(il)) {
|
|
117
|
+
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
|
118
|
+
continue;
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
const bool is_swa = hparams.is_swa(il);
|
|
122
|
+
const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
|
|
123
|
+
|
|
124
|
+
GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
|
|
125
|
+
map_layer_ids[il] = map_layer_ids[il_reuse];
|
|
126
|
+
|
|
127
|
+
LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
|
|
128
|
+
}
|
|
129
|
+
}
|
|
130
|
+
|
|
105
131
|
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
106
132
|
for (auto it : ctx_map) {
|
|
107
133
|
auto * buft = it.first;
|
|
@@ -307,18 +333,24 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
307
333
|
return cells.seq_pos_max(seq_id);
|
|
308
334
|
}
|
|
309
335
|
|
|
310
|
-
|
|
311
|
-
|
|
336
|
+
llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
|
337
|
+
llama_batch_allocr & balloc,
|
|
312
338
|
uint32_t n_ubatch,
|
|
313
339
|
bool embd_all) {
|
|
314
340
|
GGML_UNUSED(embd_all);
|
|
315
341
|
|
|
316
342
|
do {
|
|
317
|
-
|
|
343
|
+
balloc.split_reset();
|
|
318
344
|
|
|
319
345
|
std::vector<llama_ubatch> ubatches;
|
|
320
|
-
while (
|
|
321
|
-
|
|
346
|
+
while (true) {
|
|
347
|
+
auto ubatch = balloc.split_simple(n_ubatch);
|
|
348
|
+
|
|
349
|
+
if (ubatch.n_tokens == 0) {
|
|
350
|
+
break;
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
ubatches.push_back(std::move(ubatch)); // NOLINT
|
|
322
354
|
}
|
|
323
355
|
|
|
324
356
|
auto heads = prepare(ubatches);
|
|
@@ -326,18 +358,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
|
|
326
358
|
break;
|
|
327
359
|
}
|
|
328
360
|
|
|
329
|
-
return std::make_unique<
|
|
330
|
-
this, std::move(
|
|
361
|
+
return std::make_unique<llama_kv_cache_unified_context>(
|
|
362
|
+
this, std::move(heads), std::move(ubatches));
|
|
331
363
|
} while (false);
|
|
332
364
|
|
|
333
|
-
return std::make_unique<
|
|
365
|
+
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
334
366
|
}
|
|
335
367
|
|
|
336
|
-
|
|
337
|
-
return std::make_unique<
|
|
368
|
+
llama_memory_context_ptr llama_kv_cache_unified::init_full() {
|
|
369
|
+
return std::make_unique<llama_kv_cache_unified_context>(this);
|
|
338
370
|
}
|
|
339
371
|
|
|
340
|
-
|
|
372
|
+
llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
|
341
373
|
bool do_shift = get_has_shift();
|
|
342
374
|
|
|
343
375
|
defrag_info dinfo;
|
|
@@ -367,7 +399,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
|
|
|
367
399
|
}
|
|
368
400
|
}
|
|
369
401
|
|
|
370
|
-
return std::make_unique<
|
|
402
|
+
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
|
|
371
403
|
}
|
|
372
404
|
|
|
373
405
|
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
|
@@ -644,12 +676,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|
|
644
676
|
}
|
|
645
677
|
|
|
646
678
|
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
|
|
647
|
-
if (debug > 0) {
|
|
648
|
-
LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
|
|
649
|
-
LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
|
|
650
|
-
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
|
|
651
|
-
}
|
|
652
|
-
|
|
653
679
|
// keep track of the max sequence position that we would overwrite with this ubatch
|
|
654
680
|
// for non-SWA cache, this would be always empty
|
|
655
681
|
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
|
@@ -657,27 +683,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
|
|
657
683
|
seq_pos_max_rm[s] = -1;
|
|
658
684
|
}
|
|
659
685
|
|
|
660
|
-
for (uint32_t
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
if (!cells.is_empty(head_cur + idx)) {
|
|
665
|
-
assert(cells.seq_count(head_cur + idx) == 1);
|
|
686
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
|
687
|
+
if (!cells.is_empty(head_cur + i)) {
|
|
688
|
+
assert(cells.seq_count(head_cur + i) == 1);
|
|
666
689
|
|
|
667
|
-
|
|
668
|
-
|
|
690
|
+
const llama_seq_id seq_id = cells.seq_get(head_cur + i);
|
|
691
|
+
const llama_pos pos = cells.pos_get(head_cur + i);
|
|
669
692
|
|
|
670
|
-
|
|
693
|
+
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
|
671
694
|
|
|
672
|
-
|
|
673
|
-
|
|
695
|
+
cells.rm(head_cur + i);
|
|
696
|
+
}
|
|
674
697
|
|
|
675
|
-
|
|
698
|
+
cells.pos_set(head_cur + i, ubatch.pos[i]);
|
|
676
699
|
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
|
|
680
|
-
}
|
|
700
|
+
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
|
701
|
+
cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
|
|
681
702
|
}
|
|
682
703
|
}
|
|
683
704
|
|
|
@@ -696,6 +717,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
|
|
696
717
|
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
|
|
697
718
|
}
|
|
698
719
|
}
|
|
720
|
+
|
|
699
721
|
// move the head at the end of the slot
|
|
700
722
|
head = head_cur + ubatch.n_tokens;
|
|
701
723
|
}
|
|
@@ -792,9 +814,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
|
|
|
792
814
|
}
|
|
793
815
|
|
|
794
816
|
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
|
795
|
-
const uint32_t n_tokens
|
|
796
|
-
const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
797
|
-
const uint32_t n_seqs = ubatch->n_seqs;
|
|
817
|
+
const uint32_t n_tokens = ubatch->n_tokens;
|
|
798
818
|
|
|
799
819
|
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
|
800
820
|
float * data = (float *) dst->data;
|
|
@@ -814,52 +834,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|
|
814
834
|
// xxxxx-----
|
|
815
835
|
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
|
816
836
|
for (uint32_t h = 0; h < 1; ++h) {
|
|
817
|
-
for (uint32_t
|
|
818
|
-
const llama_seq_id seq_id = ubatch->seq_id[
|
|
819
|
-
|
|
820
|
-
for (uint32_t j = 0; j < n_seq_tokens; ++j) {
|
|
821
|
-
const uint32_t idx = s*n_seq_tokens + j;
|
|
837
|
+
for (uint32_t i = 0; i < n_tokens; ++i) {
|
|
838
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
|
822
839
|
|
|
823
|
-
|
|
840
|
+
const llama_pos p1 = ubatch->pos[i];
|
|
824
841
|
|
|
825
|
-
|
|
826
|
-
|
|
842
|
+
for (uint32_t j = 0; j < n_kv; ++j) {
|
|
843
|
+
float f = 0.0f;
|
|
827
844
|
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
if (cells.is_empty(i)) {
|
|
831
|
-
masked = true;
|
|
832
|
-
} else {
|
|
833
|
-
const llama_pos p0 = cells.pos_get(i);
|
|
845
|
+
bool masked = false;
|
|
834
846
|
|
|
835
|
-
|
|
836
|
-
|
|
847
|
+
if (cells.is_empty(j)) {
|
|
848
|
+
masked = true;
|
|
849
|
+
} else {
|
|
850
|
+
const llama_pos p0 = cells.pos_get(j);
|
|
837
851
|
|
|
838
|
-
|
|
839
|
-
|
|
852
|
+
// mask the token if not the same sequence
|
|
853
|
+
masked = masked || (!cells.seq_has(j, seq_id));
|
|
840
854
|
|
|
841
|
-
|
|
842
|
-
|
|
855
|
+
// mask future tokens
|
|
856
|
+
masked = masked || (causal_attn && p0 > p1);
|
|
843
857
|
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
}
|
|
847
|
-
}
|
|
858
|
+
// apply SWA if any
|
|
859
|
+
masked = masked || (is_masked_swa(p0, p1));
|
|
848
860
|
|
|
849
|
-
if (masked) {
|
|
850
|
-
f = -
|
|
861
|
+
if (!masked && hparams.use_alibi) {
|
|
862
|
+
f = -std::abs(p0 - p1);
|
|
851
863
|
}
|
|
864
|
+
}
|
|
852
865
|
|
|
853
|
-
|
|
866
|
+
if (masked) {
|
|
867
|
+
f = -INFINITY;
|
|
854
868
|
}
|
|
869
|
+
|
|
870
|
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
|
|
855
871
|
}
|
|
856
872
|
}
|
|
857
873
|
|
|
858
874
|
// mask padded tokens
|
|
859
875
|
if (data) {
|
|
860
|
-
for (uint32_t
|
|
861
|
-
for (uint32_t
|
|
862
|
-
data[h*(n_kv*n_tokens) +
|
|
876
|
+
for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
|
877
|
+
for (uint32_t j = 0; j < n_kv; ++j) {
|
|
878
|
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
|
863
879
|
}
|
|
864
880
|
}
|
|
865
881
|
}
|
|
@@ -887,12 +903,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
|
|
|
887
903
|
const int32_t n_kv = dst->ne[0];
|
|
888
904
|
|
|
889
905
|
for (int h = 0; h < 1; ++h) {
|
|
890
|
-
for (int
|
|
891
|
-
for (int
|
|
906
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
907
|
+
for (int j = 0; j < n_kv; ++j) {
|
|
892
908
|
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
|
|
893
|
-
const llama_pos p0 = cells.is_empty(
|
|
909
|
+
const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
|
|
894
910
|
|
|
895
|
-
data[h*(n_kv*n_tokens) +
|
|
911
|
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
|
|
896
912
|
}
|
|
897
913
|
}
|
|
898
914
|
}
|
|
@@ -1509,12 +1525,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
|
1509
1525
|
|
|
1510
1526
|
seq_rm(dest_seq_id, -1, -1);
|
|
1511
1527
|
|
|
1512
|
-
|
|
1513
|
-
llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
|
1528
|
+
llama_batch_allocr balloc(hparams.n_pos_per_embd());
|
|
1514
1529
|
|
|
1515
|
-
ubatch
|
|
1516
|
-
ubatch.n_seq_tokens = cell_count;
|
|
1517
|
-
ubatch.n_seqs = 1;
|
|
1530
|
+
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
|
|
1518
1531
|
|
|
1519
1532
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
1520
1533
|
llama_pos pos;
|
|
@@ -1723,18 +1736,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
|
1723
1736
|
}
|
|
1724
1737
|
|
|
1725
1738
|
//
|
|
1726
|
-
//
|
|
1739
|
+
// llama_kv_cache_unified_context
|
|
1727
1740
|
//
|
|
1728
1741
|
|
|
1729
|
-
|
|
1742
|
+
llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
|
|
1730
1743
|
|
|
1731
|
-
|
|
1744
|
+
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
|
1732
1745
|
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
|
1733
1746
|
n_kv = kv->get_size();
|
|
1734
1747
|
head = 0;
|
|
1735
1748
|
}
|
|
1736
1749
|
|
|
1737
|
-
|
|
1750
|
+
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
|
1738
1751
|
llama_kv_cache_unified * kv,
|
|
1739
1752
|
llama_context * lctx,
|
|
1740
1753
|
bool do_shift,
|
|
@@ -1744,16 +1757,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
|
|
1744
1757
|
}
|
|
1745
1758
|
}
|
|
1746
1759
|
|
|
1747
|
-
|
|
1760
|
+
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
|
1748
1761
|
llama_kv_cache_unified * kv,
|
|
1749
|
-
llama_sbatch sbatch,
|
|
1750
1762
|
llama_kv_cache_unified::ubatch_heads heads,
|
|
1751
|
-
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv),
|
|
1763
|
+
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
|
1752
1764
|
}
|
|
1753
1765
|
|
|
1754
|
-
|
|
1766
|
+
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
|
|
1755
1767
|
|
|
1756
|
-
bool
|
|
1768
|
+
bool llama_kv_cache_unified_context::next() {
|
|
1757
1769
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1758
1770
|
|
|
1759
1771
|
if (++i_next >= ubatches.size()) {
|
|
@@ -1763,7 +1775,7 @@ bool llama_kv_cache_unified_state::next() {
|
|
|
1763
1775
|
return true;
|
|
1764
1776
|
}
|
|
1765
1777
|
|
|
1766
|
-
bool
|
|
1778
|
+
bool llama_kv_cache_unified_context::apply() {
|
|
1767
1779
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1768
1780
|
|
|
1769
1781
|
// no ubatches -> this is a KV cache update
|
|
@@ -1781,51 +1793,45 @@ bool llama_kv_cache_unified_state::apply() {
|
|
|
1781
1793
|
return true;
|
|
1782
1794
|
}
|
|
1783
1795
|
|
|
1784
|
-
|
|
1785
|
-
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1786
|
-
|
|
1787
|
-
return sbatch.out_ids;
|
|
1788
|
-
}
|
|
1789
|
-
|
|
1790
|
-
llama_memory_status llama_kv_cache_unified_state::get_status() const {
|
|
1796
|
+
llama_memory_status llama_kv_cache_unified_context::get_status() const {
|
|
1791
1797
|
return status;
|
|
1792
1798
|
}
|
|
1793
1799
|
|
|
1794
|
-
const llama_ubatch &
|
|
1800
|
+
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
|
|
1795
1801
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1796
1802
|
|
|
1797
1803
|
return ubatches[i_next];
|
|
1798
1804
|
}
|
|
1799
1805
|
|
|
1800
|
-
uint32_t
|
|
1806
|
+
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
|
|
1801
1807
|
return n_kv;
|
|
1802
1808
|
}
|
|
1803
1809
|
|
|
1804
|
-
ggml_tensor *
|
|
1810
|
+
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
|
|
1805
1811
|
return kv->get_k(ctx, il, n_kv);
|
|
1806
1812
|
}
|
|
1807
1813
|
|
|
1808
|
-
ggml_tensor *
|
|
1814
|
+
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
|
|
1809
1815
|
return kv->get_v(ctx, il, n_kv);
|
|
1810
1816
|
}
|
|
1811
1817
|
|
|
1812
|
-
ggml_tensor *
|
|
1818
|
+
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
|
1813
1819
|
return kv->cpy_k(ctx, k_cur, il, head);
|
|
1814
1820
|
}
|
|
1815
1821
|
|
|
1816
|
-
ggml_tensor *
|
|
1822
|
+
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
|
1817
1823
|
return kv->cpy_v(ctx, v_cur, il, head);
|
|
1818
1824
|
}
|
|
1819
1825
|
|
|
1820
|
-
void
|
|
1826
|
+
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
|
|
1821
1827
|
kv->set_input_k_shift(dst);
|
|
1822
1828
|
}
|
|
1823
1829
|
|
|
1824
|
-
void
|
|
1830
|
+
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
|
1825
1831
|
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
|
1826
1832
|
}
|
|
1827
1833
|
|
|
1828
|
-
void
|
|
1834
|
+
void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
|
1829
1835
|
kv->set_input_pos_bucket(dst, ubatch);
|
|
1830
1836
|
}
|
|
1831
1837
|
|
|
@@ -56,14 +56,14 @@ public:
|
|
|
56
56
|
// llama_memory_i
|
|
57
57
|
//
|
|
58
58
|
|
|
59
|
-
|
|
60
|
-
|
|
59
|
+
llama_memory_context_ptr init_batch(
|
|
60
|
+
llama_batch_allocr & balloc,
|
|
61
61
|
uint32_t n_ubatch,
|
|
62
62
|
bool embd_all) override;
|
|
63
63
|
|
|
64
|
-
|
|
64
|
+
llama_memory_context_ptr init_full() override;
|
|
65
65
|
|
|
66
|
-
|
|
66
|
+
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
|
67
67
|
|
|
68
68
|
bool get_can_shift() const override;
|
|
69
69
|
|
|
@@ -208,49 +208,46 @@ private:
|
|
|
208
208
|
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
|
209
209
|
};
|
|
210
210
|
|
|
211
|
-
class
|
|
211
|
+
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
|
212
212
|
public:
|
|
213
213
|
// some shorthands
|
|
214
214
|
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
|
215
215
|
using defrag_info = llama_kv_cache_unified::defrag_info;
|
|
216
216
|
|
|
217
217
|
// used for errors
|
|
218
|
-
|
|
218
|
+
llama_kv_cache_unified_context(llama_memory_status status);
|
|
219
219
|
|
|
220
|
-
// used to create a full-cache
|
|
221
|
-
|
|
220
|
+
// used to create a full-cache context
|
|
221
|
+
llama_kv_cache_unified_context(
|
|
222
222
|
llama_kv_cache_unified * kv);
|
|
223
223
|
|
|
224
|
-
// used to create an update
|
|
225
|
-
|
|
224
|
+
// used to create an update context
|
|
225
|
+
llama_kv_cache_unified_context(
|
|
226
226
|
llama_kv_cache_unified * kv,
|
|
227
227
|
llama_context * lctx,
|
|
228
228
|
bool do_shift,
|
|
229
229
|
defrag_info dinfo);
|
|
230
230
|
|
|
231
|
-
// used to create a
|
|
232
|
-
|
|
231
|
+
// used to create a batch processing context from a batch
|
|
232
|
+
llama_kv_cache_unified_context(
|
|
233
233
|
llama_kv_cache_unified * kv,
|
|
234
|
-
llama_sbatch sbatch,
|
|
235
234
|
ubatch_heads heads,
|
|
236
235
|
std::vector<llama_ubatch> ubatches);
|
|
237
236
|
|
|
238
|
-
virtual ~
|
|
237
|
+
virtual ~llama_kv_cache_unified_context();
|
|
239
238
|
|
|
240
239
|
//
|
|
241
|
-
//
|
|
240
|
+
// llama_memory_context_i
|
|
242
241
|
//
|
|
243
242
|
|
|
244
243
|
bool next() override;
|
|
245
244
|
bool apply() override;
|
|
246
245
|
|
|
247
|
-
std::vector<int64_t> & out_ids() override;
|
|
248
|
-
|
|
249
246
|
llama_memory_status get_status() const override;
|
|
250
247
|
const llama_ubatch & get_ubatch() const override;
|
|
251
248
|
|
|
252
249
|
//
|
|
253
|
-
//
|
|
250
|
+
// llama_kv_cache_unified_context specific API
|
|
254
251
|
//
|
|
255
252
|
|
|
256
253
|
uint32_t get_n_kv() const;
|
|
@@ -275,7 +272,7 @@ private:
|
|
|
275
272
|
llama_context * lctx;
|
|
276
273
|
|
|
277
274
|
//
|
|
278
|
-
// update
|
|
275
|
+
// update context
|
|
279
276
|
//
|
|
280
277
|
|
|
281
278
|
bool do_shift = false;
|
|
@@ -283,11 +280,9 @@ private:
|
|
|
283
280
|
defrag_info dinfo;
|
|
284
281
|
|
|
285
282
|
//
|
|
286
|
-
// batch processing
|
|
283
|
+
// batch processing context
|
|
287
284
|
//
|
|
288
285
|
|
|
289
|
-
llama_sbatch sbatch;
|
|
290
|
-
|
|
291
286
|
// the index of the next ubatch to process
|
|
292
287
|
size_t i_next = 0;
|
|
293
288
|
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
#include <cassert>
|
|
8
8
|
#include <vector>
|
|
9
9
|
#include <set>
|
|
10
|
+
#include <map>
|
|
10
11
|
|
|
11
12
|
// meta information about KV cells that can be part of multiple sequences at the same time
|
|
12
13
|
// TODO: add unit tests
|
|
@@ -164,7 +165,7 @@ public:
|
|
|
164
165
|
assert(seq_id >= 0);
|
|
165
166
|
|
|
166
167
|
seq[i].reset(seq_id);
|
|
167
|
-
|
|
168
|
+
seq_pos_dec(seq_id, pos[i]);
|
|
168
169
|
|
|
169
170
|
if (seq[i].none()) {
|
|
170
171
|
pos[i] = -1;
|
|
@@ -187,7 +188,7 @@ public:
|
|
|
187
188
|
seq[i].reset();
|
|
188
189
|
|
|
189
190
|
seq[i].set(seq_id);
|
|
190
|
-
|
|
191
|
+
seq_pos_inc(seq_id, pos[i]);
|
|
191
192
|
|
|
192
193
|
return false;
|
|
193
194
|
}
|
|
@@ -232,7 +233,7 @@ public:
|
|
|
232
233
|
assert(!seq[i].test(seq_id));
|
|
233
234
|
|
|
234
235
|
seq[i].set(seq_id);
|
|
235
|
-
|
|
236
|
+
seq_pos_inc(seq_id, pos[i]);
|
|
236
237
|
}
|
|
237
238
|
|
|
238
239
|
// return the sequence id of this cell
|
|
@@ -259,7 +260,9 @@ public:
|
|
|
259
260
|
return -1;
|
|
260
261
|
}
|
|
261
262
|
|
|
262
|
-
|
|
263
|
+
assert(seq_pos[seq_id].begin()->second > 0);
|
|
264
|
+
|
|
265
|
+
return seq_pos[seq_id].begin()->first;
|
|
263
266
|
}
|
|
264
267
|
|
|
265
268
|
// the maximum position of sequence seq_id currently present in any of the cells
|
|
@@ -272,7 +275,9 @@ public:
|
|
|
272
275
|
return -1;
|
|
273
276
|
}
|
|
274
277
|
|
|
275
|
-
|
|
278
|
+
assert(seq_pos[seq_id].rbegin()->second > 0);
|
|
279
|
+
|
|
280
|
+
return seq_pos[seq_id].rbegin()->first;
|
|
276
281
|
}
|
|
277
282
|
|
|
278
283
|
// note: call only if the cell is not empty
|
|
@@ -384,22 +389,41 @@ private:
|
|
|
384
389
|
//
|
|
385
390
|
std::vector<llama_pos> shift;
|
|
386
391
|
|
|
387
|
-
using
|
|
392
|
+
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
|
|
388
393
|
|
|
389
394
|
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
|
390
|
-
std::vector<
|
|
395
|
+
std::vector<seq_set_t> seq;
|
|
391
396
|
|
|
392
|
-
// the set seq_pos[s] tells us
|
|
397
|
+
// the map seq_pos[s][p] tells us how many times the position p is currently present for sequence s
|
|
398
|
+
// if the position p is not present, seq_pos[s][p] is not set
|
|
393
399
|
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
|
394
|
-
|
|
400
|
+
//
|
|
401
|
+
// note that we cannot use a std::set because in some cases a position can occur more than once for the same seq:
|
|
402
|
+
// - while performing a cache reuse via (rm + add)
|
|
403
|
+
// - some vision models have input embeddings with repeating positions
|
|
404
|
+
//
|
|
405
|
+
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
|
|
395
406
|
|
|
396
407
|
// helper functions for updating `seq_pos`, one cell at a time:
|
|
397
408
|
|
|
409
|
+
void seq_pos_dec(llama_seq_id s, llama_pos p) {
|
|
410
|
+
auto it = seq_pos[s].find(p);
|
|
411
|
+
assert(it != seq_pos[s].end());
|
|
412
|
+
|
|
413
|
+
if (--it->second == 0) {
|
|
414
|
+
seq_pos[s].erase(it);
|
|
415
|
+
}
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
void seq_pos_inc(llama_seq_id s, llama_pos p) {
|
|
419
|
+
seq_pos[s][p]++;
|
|
420
|
+
}
|
|
421
|
+
|
|
398
422
|
// remove cell i
|
|
399
423
|
void seq_pos_rm(uint32_t i) {
|
|
400
424
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
401
425
|
if (seq[i].test(s)) {
|
|
402
|
-
|
|
426
|
+
seq_pos_dec(s, pos[i]);
|
|
403
427
|
}
|
|
404
428
|
}
|
|
405
429
|
}
|
|
@@ -408,7 +432,7 @@ private:
|
|
|
408
432
|
void seq_pos_add(uint32_t i) {
|
|
409
433
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
410
434
|
if (seq[i].test(s)) {
|
|
411
|
-
|
|
435
|
+
seq_pos_inc(s, pos[i]);
|
|
412
436
|
}
|
|
413
437
|
}
|
|
414
438
|
}
|