@novastera-oss/llamarn 0.2.7 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/common/arg.cpp +7 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +1 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
- package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
- package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -3
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
- package/cpp/llama.cpp/src/llama-batch.h +98 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
- package/cpp/llama.cpp/src/llama-graph.h +44 -32
- package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-hparams.h +8 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
- package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.h +18 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
- package/cpp/llama.cpp/src/llama-model.h +22 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/common.h +1 -0
- package/ios/include/llama.h +8 -3
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
|
@@ -32,7 +32,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
|
|
32
32
|
mem_attn(new llama_kv_cache_unified(
|
|
33
33
|
model,
|
|
34
34
|
filter_attn == nullptr ?
|
|
35
|
-
[&](int32_t il) { return !
|
|
35
|
+
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
|
36
36
|
: filter_attn,
|
|
37
37
|
type_k,
|
|
38
38
|
type_v,
|
|
@@ -47,7 +47,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
|
|
47
47
|
mem_recr(new llama_memory_recurrent(
|
|
48
48
|
model,
|
|
49
49
|
filter_recr == nullptr ?
|
|
50
|
-
[&](int32_t il) { return
|
|
50
|
+
[&](int32_t il) { return hparams.is_recurrent(il); }
|
|
51
51
|
: filter_recr,
|
|
52
52
|
type_r,
|
|
53
53
|
type_s,
|
|
@@ -56,50 +56,57 @@ llama_memory_hybrid::llama_memory_hybrid(
|
|
|
56
56
|
n_seq_max
|
|
57
57
|
)) {}
|
|
58
58
|
|
|
59
|
-
|
|
59
|
+
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
|
60
|
+
do {
|
|
61
|
+
balloc.split_reset();
|
|
60
62
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
+
// follow the recurrent pattern for creating the ubatch splits
|
|
64
|
+
std::vector<llama_ubatch> ubatches;
|
|
63
65
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
while (sbatch.n_tokens > 0) {
|
|
67
|
-
llama_ubatch ubatch;
|
|
66
|
+
while (true) {
|
|
67
|
+
llama_ubatch ubatch;
|
|
68
68
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
69
|
+
if (embd_all) {
|
|
70
|
+
// if all tokens are output, split by sequence
|
|
71
|
+
ubatch = balloc.split_seq(n_ubatch);
|
|
72
|
+
} else {
|
|
73
|
+
ubatch = balloc.split_equal(n_ubatch);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
if (ubatch.n_tokens == 0) {
|
|
77
|
+
break;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
ubatches.push_back(std::move(ubatch)); // NOLINT
|
|
74
81
|
}
|
|
75
82
|
|
|
76
|
-
|
|
77
|
-
|
|
83
|
+
// prepare the recurrent batches first
|
|
84
|
+
if (!mem_recr->prepare(ubatches)) {
|
|
85
|
+
// TODO: will the recurrent cache be in an undefined context at this point?
|
|
86
|
+
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
|
87
|
+
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
88
|
+
}
|
|
78
89
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
90
|
+
// prepare the attention cache
|
|
91
|
+
auto heads_attn = mem_attn->prepare(ubatches);
|
|
92
|
+
if (heads_attn.empty()) {
|
|
93
|
+
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
|
|
94
|
+
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
95
|
+
}
|
|
85
96
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
|
|
90
|
-
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
91
|
-
}
|
|
97
|
+
return std::make_unique<llama_memory_hybrid_context>(
|
|
98
|
+
this, std::move(heads_attn), std::move(ubatches));
|
|
99
|
+
} while(false);
|
|
92
100
|
|
|
93
|
-
return std::make_unique<
|
|
94
|
-
this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
|
|
101
|
+
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
95
102
|
}
|
|
96
103
|
|
|
97
|
-
|
|
98
|
-
return std::make_unique<
|
|
104
|
+
llama_memory_context_ptr llama_memory_hybrid::init_full() {
|
|
105
|
+
return std::make_unique<llama_memory_hybrid_context>(this);
|
|
99
106
|
}
|
|
100
107
|
|
|
101
|
-
|
|
102
|
-
return std::make_unique<
|
|
108
|
+
llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
|
|
109
|
+
return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
|
|
103
110
|
}
|
|
104
111
|
|
|
105
112
|
bool llama_memory_hybrid::get_can_shift() const {
|
|
@@ -169,41 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
|
|
|
169
176
|
return mem_recr.get();
|
|
170
177
|
}
|
|
171
178
|
|
|
172
|
-
|
|
179
|
+
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
|
|
173
180
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
status(llama_memory_status_combine(
|
|
181
|
+
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
|
|
182
|
+
ctx_attn(mem->get_mem_attn()->init_full()),
|
|
183
|
+
ctx_recr(mem->get_mem_recr()->init_full()),
|
|
184
|
+
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
|
178
185
|
}
|
|
179
186
|
|
|
180
|
-
|
|
187
|
+
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
|
181
188
|
llama_memory_hybrid * mem,
|
|
182
189
|
llama_context * lctx,
|
|
183
190
|
bool optimize) :
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
status(llama_memory_status_combine(
|
|
191
|
+
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
|
|
192
|
+
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
|
|
193
|
+
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
|
187
194
|
}
|
|
188
195
|
|
|
189
|
-
|
|
196
|
+
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
|
190
197
|
llama_memory_hybrid * mem,
|
|
191
|
-
llama_sbatch sbatch,
|
|
192
198
|
std::vector<uint32_t> heads_attn,
|
|
193
199
|
std::vector<llama_ubatch> ubatches) :
|
|
194
|
-
sbatch(std::move(sbatch)),
|
|
195
200
|
ubatches(std::move(ubatches)),
|
|
196
201
|
// note: here we copy the ubatches. not sure if this is ideal
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
status(
|
|
202
|
+
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
|
|
203
|
+
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
|
204
|
+
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
|
200
205
|
}
|
|
201
206
|
|
|
202
|
-
bool
|
|
207
|
+
bool llama_memory_hybrid_context::next() {
|
|
203
208
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
204
209
|
|
|
205
|
-
|
|
206
|
-
|
|
210
|
+
ctx_attn->next();
|
|
211
|
+
ctx_recr->next();
|
|
207
212
|
|
|
208
213
|
if (++i_next >= ubatches.size()) {
|
|
209
214
|
return false;
|
|
@@ -212,36 +217,30 @@ bool llama_memory_hybrid_state::next() {
|
|
|
212
217
|
return true;
|
|
213
218
|
}
|
|
214
219
|
|
|
215
|
-
bool
|
|
220
|
+
bool llama_memory_hybrid_context::apply() {
|
|
216
221
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
217
222
|
|
|
218
223
|
bool res = true;
|
|
219
224
|
|
|
220
|
-
res = res &
|
|
221
|
-
res = res &
|
|
225
|
+
res = res & ctx_attn->apply();
|
|
226
|
+
res = res & ctx_recr->apply();
|
|
222
227
|
|
|
223
228
|
return res;
|
|
224
229
|
}
|
|
225
230
|
|
|
226
|
-
|
|
227
|
-
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
228
|
-
|
|
229
|
-
return sbatch.out_ids;
|
|
230
|
-
}
|
|
231
|
-
|
|
232
|
-
llama_memory_status llama_memory_hybrid_state::get_status() const {
|
|
231
|
+
llama_memory_status llama_memory_hybrid_context::get_status() const {
|
|
233
232
|
return status;
|
|
234
233
|
}
|
|
235
234
|
|
|
236
|
-
const llama_ubatch &
|
|
235
|
+
const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
|
|
237
236
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
238
237
|
return ubatches[i_next];
|
|
239
238
|
}
|
|
240
239
|
|
|
241
|
-
const
|
|
242
|
-
return static_cast<const
|
|
240
|
+
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
|
|
241
|
+
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
|
|
243
242
|
}
|
|
244
243
|
|
|
245
|
-
const
|
|
246
|
-
return static_cast<const
|
|
244
|
+
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
|
|
245
|
+
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
|
|
247
246
|
}
|
|
@@ -49,14 +49,14 @@ public:
|
|
|
49
49
|
// llama_memory_i
|
|
50
50
|
//
|
|
51
51
|
|
|
52
|
-
|
|
53
|
-
|
|
52
|
+
llama_memory_context_ptr init_batch(
|
|
53
|
+
llama_batch_allocr & balloc,
|
|
54
54
|
uint32_t n_ubatch,
|
|
55
|
-
bool
|
|
55
|
+
bool embd_all) override;
|
|
56
56
|
|
|
57
|
-
|
|
57
|
+
llama_memory_context_ptr init_full() override;
|
|
58
58
|
|
|
59
|
-
|
|
59
|
+
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
|
60
60
|
|
|
61
61
|
bool get_can_shift() const override;
|
|
62
62
|
|
|
@@ -90,54 +90,49 @@ private:
|
|
|
90
90
|
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
|
91
91
|
};
|
|
92
92
|
|
|
93
|
-
class
|
|
93
|
+
class llama_memory_hybrid_context : public llama_memory_context_i {
|
|
94
94
|
public:
|
|
95
95
|
// init failure
|
|
96
|
-
explicit
|
|
96
|
+
explicit llama_memory_hybrid_context(llama_memory_status status);
|
|
97
97
|
|
|
98
98
|
// init full
|
|
99
|
-
explicit
|
|
99
|
+
explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
|
|
100
100
|
|
|
101
101
|
// init update
|
|
102
|
-
explicit
|
|
102
|
+
explicit llama_memory_hybrid_context(
|
|
103
103
|
llama_memory_hybrid * mem,
|
|
104
104
|
llama_context * lctx,
|
|
105
105
|
bool optimize);
|
|
106
106
|
|
|
107
107
|
// init success
|
|
108
|
-
|
|
108
|
+
llama_memory_hybrid_context(
|
|
109
109
|
llama_memory_hybrid * mem,
|
|
110
|
-
llama_sbatch sbatch,
|
|
111
110
|
std::vector<uint32_t> heads_attn,
|
|
112
111
|
std::vector<llama_ubatch> ubatches);
|
|
113
112
|
|
|
114
|
-
~
|
|
113
|
+
~llama_memory_hybrid_context() = default;
|
|
115
114
|
|
|
116
115
|
bool next() override;
|
|
117
116
|
bool apply() override;
|
|
118
117
|
|
|
119
|
-
std::vector<int64_t> & out_ids() override;
|
|
120
|
-
|
|
121
118
|
llama_memory_status get_status() const override;
|
|
122
119
|
const llama_ubatch & get_ubatch() const override;
|
|
123
120
|
|
|
124
121
|
//
|
|
125
|
-
//
|
|
122
|
+
// llama_memory_hybrid_context
|
|
126
123
|
//
|
|
127
124
|
|
|
128
|
-
const
|
|
129
|
-
const
|
|
125
|
+
const llama_kv_cache_unified_context * get_attn() const;
|
|
126
|
+
const llama_memory_recurrent_context * get_recr() const;
|
|
130
127
|
|
|
131
128
|
private:
|
|
132
|
-
llama_sbatch sbatch;
|
|
133
|
-
|
|
134
129
|
// the index of the next ubatch to process
|
|
135
130
|
size_t i_next = 0;
|
|
136
131
|
|
|
137
132
|
std::vector<llama_ubatch> ubatches;
|
|
138
133
|
|
|
139
|
-
const
|
|
140
|
-
const
|
|
134
|
+
const llama_memory_context_ptr ctx_attn;
|
|
135
|
+
const llama_memory_context_ptr ctx_recr;
|
|
141
136
|
|
|
142
137
|
const llama_memory_status status;
|
|
143
138
|
};
|
|
@@ -362,40 +362,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
362
362
|
return result;
|
|
363
363
|
}
|
|
364
364
|
|
|
365
|
-
|
|
366
|
-
|
|
365
|
+
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
|
366
|
+
do {
|
|
367
|
+
balloc.split_reset();
|
|
367
368
|
|
|
368
|
-
|
|
369
|
+
std::vector<llama_ubatch> ubatches;
|
|
370
|
+
while (true) {
|
|
371
|
+
llama_ubatch ubatch;
|
|
369
372
|
|
|
370
|
-
|
|
371
|
-
|
|
373
|
+
if (embd_all) {
|
|
374
|
+
// if all tokens are output, split by sequence
|
|
375
|
+
ubatch = balloc.split_seq(n_ubatch);
|
|
376
|
+
} else {
|
|
377
|
+
ubatch = balloc.split_equal(n_ubatch);
|
|
378
|
+
}
|
|
372
379
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
380
|
+
if (ubatch.n_tokens == 0) {
|
|
381
|
+
break;
|
|
382
|
+
}
|
|
383
|
+
|
|
384
|
+
ubatches.push_back(std::move(ubatch)); // NOLINT
|
|
378
385
|
}
|
|
379
386
|
|
|
380
|
-
ubatches
|
|
381
|
-
|
|
387
|
+
if (!prepare(ubatches)) {
|
|
388
|
+
break;
|
|
389
|
+
}
|
|
382
390
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
}
|
|
391
|
+
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
|
392
|
+
} while (false);
|
|
386
393
|
|
|
387
|
-
return std::make_unique<
|
|
394
|
+
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
388
395
|
}
|
|
389
396
|
|
|
390
|
-
|
|
391
|
-
return std::make_unique<
|
|
397
|
+
llama_memory_context_ptr llama_memory_recurrent::init_full() {
|
|
398
|
+
return std::make_unique<llama_memory_recurrent_context>(this);
|
|
392
399
|
}
|
|
393
400
|
|
|
394
|
-
|
|
401
|
+
llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
|
|
395
402
|
GGML_UNUSED(lctx);
|
|
396
403
|
GGML_UNUSED(optimize);
|
|
397
404
|
|
|
398
|
-
return std::make_unique<
|
|
405
|
+
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
|
399
406
|
}
|
|
400
407
|
|
|
401
408
|
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
|
@@ -423,9 +430,8 @@ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches)
|
|
|
423
430
|
}
|
|
424
431
|
|
|
425
432
|
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
426
|
-
const uint32_t n_seqs = ubatch.n_seqs;
|
|
427
|
-
|
|
428
433
|
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
434
|
+
const uint32_t n_seqs = ubatch.n_seqs;
|
|
429
435
|
|
|
430
436
|
// if we have enough unused cells before the current head ->
|
|
431
437
|
// better to start searching from the beginning of the cache, hoping to fill it
|
|
@@ -445,9 +451,11 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
445
451
|
|
|
446
452
|
// everything should fit if all seq_ids are smaller than the max
|
|
447
453
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
448
|
-
const uint32_t
|
|
454
|
+
const uint32_t i = s*n_seq_tokens; // first token of sequence set s
|
|
455
|
+
const uint32_t n_seq_id = ubatch.n_seq_id[i];
|
|
456
|
+
|
|
449
457
|
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
|
450
|
-
const llama_seq_id seq_id = ubatch.seq_id[
|
|
458
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][j];
|
|
451
459
|
|
|
452
460
|
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
|
453
461
|
// too big seq_id
|
|
@@ -506,7 +514,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
506
514
|
|
|
507
515
|
// find usable cell range
|
|
508
516
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
509
|
-
const
|
|
517
|
+
const uint32_t i = s*n_seq_tokens;
|
|
518
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
510
519
|
auto & seq_meta = cells[seq_id];
|
|
511
520
|
bool has_cell = false;
|
|
512
521
|
if (seq_meta.tail >= 0) {
|
|
@@ -530,7 +539,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
530
539
|
seq_meta.tail = next_empty_cell;
|
|
531
540
|
// find next empty cell
|
|
532
541
|
if (s + 1 < n_seqs) {
|
|
533
|
-
for (uint32_t
|
|
542
|
+
for (uint32_t j = 0; j < size; ++j) {
|
|
534
543
|
next_empty_cell += 1;
|
|
535
544
|
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
|
536
545
|
auto & cell = cells[next_empty_cell];
|
|
@@ -544,8 +553,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
544
553
|
|
|
545
554
|
// gather and re-order
|
|
546
555
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
556
|
+
const uint32_t i = s*n_seq_tokens;
|
|
547
557
|
const int32_t dst_id = s + min;
|
|
548
|
-
const int32_t src_id = cells[ubatch.seq_id[
|
|
558
|
+
const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
|
|
549
559
|
if (dst_id != src_id) {
|
|
550
560
|
auto & dst_cell = cells[dst_id];
|
|
551
561
|
auto & src_cell = cells[src_id];
|
|
@@ -555,8 +565,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
555
565
|
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
|
556
566
|
|
|
557
567
|
// swap tails
|
|
558
|
-
for (uint32_t
|
|
559
|
-
int32_t & tail = cells[
|
|
568
|
+
for (uint32_t j = 0; j < size; ++j) {
|
|
569
|
+
int32_t & tail = cells[j].tail;
|
|
560
570
|
if (tail == src_id) {
|
|
561
571
|
tail = dst_id;
|
|
562
572
|
} else if (tail == dst_id) {
|
|
@@ -568,7 +578,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
568
578
|
|
|
569
579
|
// update the pos of the used seqs
|
|
570
580
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
571
|
-
const
|
|
581
|
+
const uint32_t i = s*n_seq_tokens;
|
|
582
|
+
const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
|
|
572
583
|
const int32_t cell_id = s + min;
|
|
573
584
|
auto & cell = cells[cell_id];
|
|
574
585
|
|
|
@@ -576,12 +587,12 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
576
587
|
// What should happen when the pos backtracks or skips a value?
|
|
577
588
|
// Clearing the state mid-batch would require special-casing which isn't done.
|
|
578
589
|
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
|
579
|
-
__func__, last_pos, cell.pos, ubatch.seq_id[
|
|
590
|
+
__func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
|
|
580
591
|
}
|
|
581
592
|
cell.pos = last_pos;
|
|
582
593
|
cell.seq_id.clear();
|
|
583
|
-
for (int32_t j = 0; j < ubatch.n_seq_id[
|
|
584
|
-
const llama_seq_id seq_id = ubatch.seq_id[
|
|
594
|
+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
|
|
595
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][j];
|
|
585
596
|
cell.seq_id.insert(seq_id);
|
|
586
597
|
cells[seq_id].tail = cell_id;
|
|
587
598
|
}
|
|
@@ -827,12 +838,9 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
|
827
838
|
|
|
828
839
|
seq_rm(dest_seq_id, -1, -1);
|
|
829
840
|
|
|
830
|
-
|
|
831
|
-
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
|
841
|
+
llama_batch_allocr balloc(hparams.n_pos_per_embd());
|
|
832
842
|
|
|
833
|
-
|
|
834
|
-
batch.n_seq_tokens = cell_count;
|
|
835
|
-
batch.n_seqs = 1;
|
|
843
|
+
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
|
|
836
844
|
|
|
837
845
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
838
846
|
llama_pos pos;
|
|
@@ -846,12 +854,12 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
|
846
854
|
return false;
|
|
847
855
|
}
|
|
848
856
|
|
|
849
|
-
|
|
857
|
+
ubatch.pos[i] = pos;
|
|
850
858
|
}
|
|
851
|
-
|
|
852
|
-
|
|
859
|
+
ubatch.n_seq_id[0] = 1;
|
|
860
|
+
ubatch.seq_id[0] = &dest_seq_id;
|
|
853
861
|
|
|
854
|
-
if (!find_slot(
|
|
862
|
+
if (!find_slot(ubatch)) {
|
|
855
863
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
|
856
864
|
return false;
|
|
857
865
|
}
|
|
@@ -859,8 +867,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
|
859
867
|
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
|
860
868
|
// Assume that this is one contiguous block of cells
|
|
861
869
|
GGML_ASSERT(head + cell_count <= size);
|
|
862
|
-
GGML_ASSERT(cells[head].pos ==
|
|
863
|
-
GGML_ASSERT(cells[head + cell_count - 1].pos ==
|
|
870
|
+
GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
|
|
871
|
+
GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
|
|
864
872
|
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
|
865
873
|
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
|
866
874
|
} else {
|
|
@@ -1037,23 +1045,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
|
1037
1045
|
}
|
|
1038
1046
|
|
|
1039
1047
|
//
|
|
1040
|
-
//
|
|
1048
|
+
// llama_memory_recurrent_context
|
|
1041
1049
|
//
|
|
1042
1050
|
|
|
1043
|
-
|
|
1051
|
+
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
|
|
1044
1052
|
|
|
1045
|
-
|
|
1053
|
+
llama_memory_recurrent_context::llama_memory_recurrent_context(
|
|
1046
1054
|
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
|
|
1047
1055
|
}
|
|
1048
1056
|
|
|
1049
|
-
|
|
1057
|
+
llama_memory_recurrent_context::llama_memory_recurrent_context(
|
|
1050
1058
|
llama_memory_recurrent * mem,
|
|
1051
|
-
|
|
1052
|
-
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
|
|
1059
|
+
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
|
|
1053
1060
|
|
|
1054
|
-
|
|
1061
|
+
llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
|
|
1055
1062
|
|
|
1056
|
-
bool
|
|
1063
|
+
bool llama_memory_recurrent_context::next() {
|
|
1057
1064
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1058
1065
|
|
|
1059
1066
|
if (++i_next >= ubatches.size()) {
|
|
@@ -1063,7 +1070,7 @@ bool llama_memory_recurrent_state::next() {
|
|
|
1063
1070
|
return true;
|
|
1064
1071
|
}
|
|
1065
1072
|
|
|
1066
|
-
bool
|
|
1073
|
+
bool llama_memory_recurrent_context::apply() {
|
|
1067
1074
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1068
1075
|
|
|
1069
1076
|
mem->find_slot(ubatches[i_next]);
|
|
@@ -1071,46 +1078,40 @@ bool llama_memory_recurrent_state::apply() {
|
|
|
1071
1078
|
return true;
|
|
1072
1079
|
}
|
|
1073
1080
|
|
|
1074
|
-
|
|
1075
|
-
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1076
|
-
|
|
1077
|
-
return sbatch.out_ids;
|
|
1078
|
-
}
|
|
1079
|
-
|
|
1080
|
-
llama_memory_status llama_memory_recurrent_state::get_status() const {
|
|
1081
|
+
llama_memory_status llama_memory_recurrent_context::get_status() const {
|
|
1081
1082
|
return status;
|
|
1082
1083
|
}
|
|
1083
1084
|
|
|
1084
|
-
const llama_ubatch &
|
|
1085
|
+
const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
|
|
1085
1086
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1086
1087
|
|
|
1087
1088
|
return ubatches[i_next];
|
|
1088
1089
|
}
|
|
1089
1090
|
|
|
1090
|
-
uint32_t
|
|
1091
|
+
uint32_t llama_memory_recurrent_context::get_n_rs() const {
|
|
1091
1092
|
return is_full ? mem->size : mem->n;
|
|
1092
1093
|
}
|
|
1093
1094
|
|
|
1094
|
-
uint32_t
|
|
1095
|
+
uint32_t llama_memory_recurrent_context::get_head() const {
|
|
1095
1096
|
return is_full ? 0 : mem->head;
|
|
1096
1097
|
}
|
|
1097
1098
|
|
|
1098
|
-
int32_t
|
|
1099
|
+
int32_t llama_memory_recurrent_context::get_rs_z() const {
|
|
1099
1100
|
return is_full ? 0 : mem->rs_z;
|
|
1100
1101
|
}
|
|
1101
1102
|
|
|
1102
|
-
uint32_t
|
|
1103
|
+
uint32_t llama_memory_recurrent_context::get_size() const {
|
|
1103
1104
|
return mem->size;
|
|
1104
1105
|
}
|
|
1105
1106
|
|
|
1106
|
-
ggml_tensor *
|
|
1107
|
+
ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
|
|
1107
1108
|
return mem->r_l[il];
|
|
1108
1109
|
}
|
|
1109
1110
|
|
|
1110
|
-
ggml_tensor *
|
|
1111
|
+
ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
|
|
1111
1112
|
return mem->s_l[il];
|
|
1112
1113
|
}
|
|
1113
1114
|
|
|
1114
|
-
int32_t
|
|
1115
|
+
int32_t llama_memory_recurrent_context::s_copy(int i) const {
|
|
1115
1116
|
return mem->cells[i + mem->head].src0;
|
|
1116
1117
|
}
|