@novastera-oss/llamarn 0.2.5 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/RNLlamaCpp.podspec +3 -2
- package/android/CMakeLists.txt +6 -3
- package/android/src/main/cpp/include/llama.h +140 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +48 -67
- package/cpp/LlamaCppModel.h +8 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +33 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
- package/cpp/llama.cpp/common/arg.cpp +38 -12
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
- package/cpp/llama.cpp/common/chat-parser.h +4 -1
- package/cpp/llama.cpp/common/chat.cpp +16 -13
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +52 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/json-partial.cpp +5 -4
- package/cpp/llama.cpp/common/json-partial.h +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
- package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +140 -38
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
- package/cpp/llama.cpp/src/llama-batch.h +47 -17
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +488 -313
- package/cpp/llama.cpp/src/llama-context.h +38 -17
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
- package/cpp/llama.cpp/src/llama-graph.h +109 -52
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
- package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +89 -4
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +735 -143
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
- package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
- package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
- package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
- package/cpp/rn-completion.cpp +65 -10
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common/minja/chat-template.hpp +1 -1
- package/ios/include/common/minja/minja.hpp +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/json-schema-to-grammar.h +4 -4
- package/ios/include/llama.h +140 -38
- package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
- package/ios/libs/llama.xcframework/Info.plist +20 -20
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -17,10 +17,12 @@ struct ggml_tensor;
|
|
|
17
17
|
struct llama_ubatch;
|
|
18
18
|
struct llama_cparams;
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class
|
|
23
|
-
class
|
|
20
|
+
struct llama_memory_state_i;
|
|
21
|
+
|
|
22
|
+
class llama_kv_cache_unified_state;
|
|
23
|
+
class llama_kv_cache_unified_iswa_state;
|
|
24
|
+
class llama_memory_recurrent_state;
|
|
25
|
+
class llama_memory_hybrid_state;
|
|
24
26
|
|
|
25
27
|
// certain models (typically multi-modal) can produce different types of graphs
|
|
26
28
|
enum llm_graph_type {
|
|
@@ -35,6 +37,7 @@ enum llm_ffn_op_type {
|
|
|
35
37
|
LLM_FFN_RELU,
|
|
36
38
|
LLM_FFN_RELU_SQR,
|
|
37
39
|
LLM_FFN_SWIGLU,
|
|
40
|
+
LLM_FFN_GEGLU,
|
|
38
41
|
};
|
|
39
42
|
|
|
40
43
|
enum llm_ffn_gate_type {
|
|
@@ -133,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
|
|
133
136
|
public:
|
|
134
137
|
llm_graph_input_pos_bucket_kv(
|
|
135
138
|
const llama_hparams & hparams,
|
|
136
|
-
const
|
|
139
|
+
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
|
|
137
140
|
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
|
138
141
|
|
|
139
142
|
void set_input(const llama_ubatch * ubatch) override;
|
|
@@ -141,7 +144,7 @@ public:
|
|
|
141
144
|
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
|
142
145
|
|
|
143
146
|
const llama_hparams & hparams;
|
|
144
|
-
const
|
|
147
|
+
const llama_kv_cache_unified_state * kv_state;
|
|
145
148
|
};
|
|
146
149
|
|
|
147
150
|
class llm_graph_input_out_ids : public llm_graph_input_i {
|
|
@@ -186,28 +189,16 @@ public:
|
|
|
186
189
|
const llama_cparams & cparams;
|
|
187
190
|
};
|
|
188
191
|
|
|
189
|
-
class
|
|
192
|
+
class llm_graph_input_rs : public llm_graph_input_i {
|
|
190
193
|
public:
|
|
191
|
-
|
|
192
|
-
virtual ~
|
|
194
|
+
llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
|
|
195
|
+
virtual ~llm_graph_input_rs() = default;
|
|
193
196
|
|
|
194
197
|
void set_input(const llama_ubatch * ubatch) override;
|
|
195
198
|
|
|
196
199
|
ggml_tensor * s_copy; // I32 [kv_size]
|
|
197
200
|
|
|
198
|
-
const
|
|
199
|
-
};
|
|
200
|
-
|
|
201
|
-
class llm_graph_input_s_mask : public llm_graph_input_i {
|
|
202
|
-
public:
|
|
203
|
-
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
|
204
|
-
virtual ~llm_graph_input_s_mask() = default;
|
|
205
|
-
|
|
206
|
-
void set_input(const llama_ubatch * ubatch) override;
|
|
207
|
-
|
|
208
|
-
ggml_tensor * s_mask; // F32 [1, n_kv]
|
|
209
|
-
|
|
210
|
-
const llama_kv_cache_recurrent * kv_self;
|
|
201
|
+
const llama_memory_recurrent_state * mem_state;
|
|
211
202
|
};
|
|
212
203
|
|
|
213
204
|
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
|
@@ -247,10 +238,10 @@ public:
|
|
|
247
238
|
llm_graph_input_attn_kv_unified(
|
|
248
239
|
const llama_hparams & hparams,
|
|
249
240
|
const llama_cparams & cparams,
|
|
250
|
-
const
|
|
241
|
+
const llama_kv_cache_unified_state * kv_state) :
|
|
251
242
|
hparams(hparams),
|
|
252
243
|
cparams(cparams),
|
|
253
|
-
|
|
244
|
+
kv_state(kv_state) {
|
|
254
245
|
}
|
|
255
246
|
~llm_graph_input_attn_kv_unified() = default;
|
|
256
247
|
|
|
@@ -264,7 +255,7 @@ public:
|
|
|
264
255
|
const llama_hparams & hparams;
|
|
265
256
|
const llama_cparams & cparams;
|
|
266
257
|
|
|
267
|
-
const
|
|
258
|
+
const llama_kv_cache_unified_state * kv_state;
|
|
268
259
|
};
|
|
269
260
|
|
|
270
261
|
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
|
@@ -272,10 +263,10 @@ public:
|
|
|
272
263
|
llm_graph_input_attn_kv_unified_iswa(
|
|
273
264
|
const llama_hparams & hparams,
|
|
274
265
|
const llama_cparams & cparams,
|
|
275
|
-
const
|
|
266
|
+
const llama_kv_cache_unified_iswa_state * kv_state) :
|
|
276
267
|
hparams(hparams),
|
|
277
268
|
cparams(cparams),
|
|
278
|
-
|
|
269
|
+
kv_state(kv_state) {
|
|
279
270
|
}
|
|
280
271
|
~llm_graph_input_attn_kv_unified_iswa() = default;
|
|
281
272
|
|
|
@@ -292,7 +283,7 @@ public:
|
|
|
292
283
|
const llama_hparams & hparams;
|
|
293
284
|
const llama_cparams & cparams;
|
|
294
285
|
|
|
295
|
-
const
|
|
286
|
+
const llama_kv_cache_unified_iswa_state * kv_state;
|
|
296
287
|
};
|
|
297
288
|
|
|
298
289
|
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
|
@@ -310,6 +301,33 @@ public:
|
|
|
310
301
|
const llama_cross * cross = nullptr;
|
|
311
302
|
};
|
|
312
303
|
|
|
304
|
+
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
|
|
305
|
+
public:
|
|
306
|
+
llm_graph_input_mem_hybrid(
|
|
307
|
+
const llama_hparams & hparams,
|
|
308
|
+
const llama_cparams & cparams,
|
|
309
|
+
const llama_memory_hybrid_state * mem_state) :
|
|
310
|
+
hparams(hparams),
|
|
311
|
+
cparams(cparams),
|
|
312
|
+
mem_state(mem_state) {
|
|
313
|
+
}
|
|
314
|
+
virtual ~llm_graph_input_mem_hybrid() = default;
|
|
315
|
+
|
|
316
|
+
void set_input(const llama_ubatch * ubatch) override;
|
|
317
|
+
|
|
318
|
+
ggml_tensor * s_copy; // I32 [kv_size]
|
|
319
|
+
|
|
320
|
+
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
|
321
|
+
|
|
322
|
+
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
|
323
|
+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
|
324
|
+
|
|
325
|
+
const llama_hparams & hparams;
|
|
326
|
+
const llama_cparams & cparams;
|
|
327
|
+
|
|
328
|
+
const llama_memory_hybrid_state * mem_state;
|
|
329
|
+
};
|
|
330
|
+
|
|
313
331
|
//
|
|
314
332
|
// llm_graph_result
|
|
315
333
|
//
|
|
@@ -383,12 +401,12 @@ struct llm_graph_params {
|
|
|
383
401
|
ggml_backend_sched_t sched;
|
|
384
402
|
ggml_backend_t backend_cpu;
|
|
385
403
|
|
|
386
|
-
const llama_adapter_cvec
|
|
387
|
-
const llama_adapter_loras
|
|
388
|
-
const
|
|
389
|
-
const llama_cross
|
|
404
|
+
const llama_adapter_cvec * cvec;
|
|
405
|
+
const llama_adapter_loras * loras;
|
|
406
|
+
const llama_memory_state_i * mstate;
|
|
407
|
+
const llama_cross * cross;
|
|
390
408
|
|
|
391
|
-
|
|
409
|
+
uint32_t n_outputs;
|
|
392
410
|
|
|
393
411
|
const llm_graph_cb & cb;
|
|
394
412
|
};
|
|
@@ -422,8 +440,8 @@ struct llm_graph_context {
|
|
|
422
440
|
const float norm_eps;
|
|
423
441
|
const float norm_rms_eps;
|
|
424
442
|
|
|
425
|
-
const
|
|
426
|
-
const
|
|
443
|
+
const int64_t n_tokens;
|
|
444
|
+
const int64_t n_outputs;
|
|
427
445
|
const int32_t n_ctx_orig; // yarn
|
|
428
446
|
|
|
429
447
|
const enum llama_pooling_type pooling_type;
|
|
@@ -435,10 +453,10 @@ struct llm_graph_context {
|
|
|
435
453
|
|
|
436
454
|
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
|
437
455
|
|
|
438
|
-
const llama_adapter_cvec
|
|
439
|
-
const llama_adapter_loras
|
|
440
|
-
const
|
|
441
|
-
const llama_cross
|
|
456
|
+
const llama_adapter_cvec * cvec;
|
|
457
|
+
const llama_adapter_loras * loras;
|
|
458
|
+
const llama_memory_state_i * mstate;
|
|
459
|
+
const llama_cross * cross;
|
|
442
460
|
|
|
443
461
|
const llm_graph_cb & cb_func;
|
|
444
462
|
|
|
@@ -518,14 +536,14 @@ struct llm_graph_context {
|
|
|
518
536
|
ggml_tensor * build_inp_out_ids() const;
|
|
519
537
|
ggml_tensor * build_inp_mean() const;
|
|
520
538
|
ggml_tensor * build_inp_cls() const;
|
|
521
|
-
ggml_tensor * build_inp_s_copy() const;
|
|
522
|
-
ggml_tensor * build_inp_s_mask() const;
|
|
523
539
|
|
|
524
540
|
ggml_tensor * build_inp_cross_embd() const;
|
|
525
541
|
ggml_tensor * build_inp_pos_bucket_enc() const;
|
|
526
542
|
ggml_tensor * build_inp_pos_bucket_dec() const;
|
|
527
543
|
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
|
|
528
544
|
|
|
545
|
+
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
|
|
546
|
+
|
|
529
547
|
//
|
|
530
548
|
// attention
|
|
531
549
|
//
|
|
@@ -600,23 +618,62 @@ struct llm_graph_context {
|
|
|
600
618
|
float kq_scale,
|
|
601
619
|
int il) const;
|
|
602
620
|
|
|
621
|
+
ggml_tensor * build_attn(
|
|
622
|
+
llm_graph_input_mem_hybrid * inp,
|
|
623
|
+
ggml_cgraph * gf,
|
|
624
|
+
ggml_tensor * wo,
|
|
625
|
+
ggml_tensor * wo_b,
|
|
626
|
+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
627
|
+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
|
628
|
+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
|
629
|
+
ggml_tensor * kq_b,
|
|
630
|
+
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
|
631
|
+
float kq_scale,
|
|
632
|
+
int il) const;
|
|
603
633
|
//
|
|
604
634
|
// recurrent
|
|
605
635
|
//
|
|
606
636
|
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
637
|
+
// TODO: avoid notion of "kv"
|
|
638
|
+
// TODO: move this implementation to llama_memory_recurrent.
|
|
639
|
+
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
|
|
640
|
+
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
|
|
641
|
+
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
|
|
642
|
+
// `llama_memory_recurrent`
|
|
643
|
+
ggml_tensor * build_rs(
|
|
644
|
+
ggml_cgraph * gf,
|
|
645
|
+
ggml_tensor * s,
|
|
646
|
+
ggml_tensor * state_copy,
|
|
647
|
+
int32_t state_size,
|
|
648
|
+
int32_t n_seqs,
|
|
649
|
+
uint32_t n_kv,
|
|
650
|
+
uint32_t kv_head,
|
|
651
|
+
uint32_t kv_size,
|
|
652
|
+
int32_t rs_zero,
|
|
653
|
+
bool avoid_copies = false) const;
|
|
654
|
+
|
|
655
|
+
llm_graph_input_rs * build_rs_inp() const;
|
|
656
|
+
|
|
657
|
+
ggml_tensor * build_rs(
|
|
658
|
+
llm_graph_input_rs * inp,
|
|
659
|
+
ggml_cgraph * gf,
|
|
660
|
+
ggml_tensor * s,
|
|
661
|
+
int32_t state_size,
|
|
662
|
+
int32_t n_seqs,
|
|
663
|
+
bool avoid_copies = false) const;
|
|
664
|
+
|
|
665
|
+
ggml_tensor * build_rs(
|
|
666
|
+
llm_graph_input_mem_hybrid * inp,
|
|
667
|
+
ggml_cgraph * gf,
|
|
668
|
+
ggml_tensor * s,
|
|
669
|
+
int32_t state_size,
|
|
670
|
+
int32_t n_seqs,
|
|
671
|
+
bool avoid_copies = false) const;
|
|
614
672
|
|
|
615
673
|
ggml_tensor * build_rwkv_token_shift_load(
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
const llama_ubatch & ubatch,
|
|
674
|
+
llm_graph_input_rs * inp,
|
|
675
|
+
ggml_cgraph * gf,
|
|
676
|
+
const llama_ubatch & ubatch,
|
|
620
677
|
int il) const;
|
|
621
678
|
|
|
622
679
|
ggml_tensor * build_rwkv_token_shift_store(
|
|
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
|
|
65
65
|
return n_embd_head_v * n_head_kv;
|
|
66
66
|
}
|
|
67
67
|
|
|
68
|
-
uint32_t llama_hparams::
|
|
68
|
+
uint32_t llama_hparams::n_embd_r() const {
|
|
69
69
|
if (wkv_head_size != 0) {
|
|
70
70
|
// for RWKV models
|
|
71
71
|
return token_shift_count * n_embd;
|
|
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
|
|
|
76
76
|
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
|
77
77
|
}
|
|
78
78
|
|
|
79
|
-
uint32_t llama_hparams::
|
|
79
|
+
uint32_t llama_hparams::n_embd_s() const {
|
|
80
80
|
if (wkv_head_size != 0) {
|
|
81
81
|
// corresponds to RWKV's wkv_states size
|
|
82
82
|
return n_embd * wkv_head_size;
|
|
@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
|
|
|
86
86
|
return ssm_d_state * ssm_d_inner;
|
|
87
87
|
}
|
|
88
88
|
|
|
89
|
+
bool llama_hparams::is_recurrent(uint32_t il) const {
|
|
90
|
+
return recurrent_layer_arr[il];
|
|
91
|
+
}
|
|
92
|
+
|
|
89
93
|
bool llama_hparams::is_swa(uint32_t il) const {
|
|
90
94
|
if (il < n_layer) {
|
|
91
95
|
return swa_layers[il];
|
|
@@ -115,6 +115,9 @@ struct llama_hparams {
|
|
|
115
115
|
uint32_t ssm_d_state = 0;
|
|
116
116
|
uint32_t ssm_dt_rank = 0;
|
|
117
117
|
|
|
118
|
+
// for hybrid state space models
|
|
119
|
+
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
|
|
120
|
+
|
|
118
121
|
bool ssm_dt_b_c_rms = false;
|
|
119
122
|
|
|
120
123
|
float f_clamp_kqv = 0.0f;
|
|
@@ -181,10 +184,13 @@ struct llama_hparams {
|
|
|
181
184
|
|
|
182
185
|
// dimension of the rolling state embeddings
|
|
183
186
|
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
|
184
|
-
uint32_t
|
|
187
|
+
uint32_t n_embd_r() const;
|
|
185
188
|
|
|
186
189
|
// dimension of the recurrent state embeddings
|
|
187
|
-
uint32_t
|
|
190
|
+
uint32_t n_embd_s() const;
|
|
191
|
+
|
|
192
|
+
// whether or not the given layer is recurrent (for hybrid models)
|
|
193
|
+
bool is_recurrent(uint32_t il) const;
|
|
188
194
|
|
|
189
195
|
bool is_swa(uint32_t il) const;
|
|
190
196
|
};
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
#include "llama-kv-cache-unified-iswa.h"
|
|
2
|
+
|
|
3
|
+
#include "llama-impl.h"
|
|
4
|
+
#include "llama-batch.h"
|
|
5
|
+
#include "llama-model.h"
|
|
6
|
+
|
|
7
|
+
#include <algorithm>
|
|
8
|
+
#include <cassert>
|
|
9
|
+
|
|
10
|
+
//
|
|
11
|
+
// llama_kv_cache_unified_iswa
|
|
12
|
+
//
|
|
13
|
+
|
|
14
|
+
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|
15
|
+
const llama_model & model,
|
|
16
|
+
ggml_type type_k,
|
|
17
|
+
ggml_type type_v,
|
|
18
|
+
bool v_trans,
|
|
19
|
+
bool offload,
|
|
20
|
+
bool swa_full,
|
|
21
|
+
uint32_t kv_size,
|
|
22
|
+
uint32_t n_seq_max,
|
|
23
|
+
uint32_t n_ubatch,
|
|
24
|
+
uint32_t n_pad) : hparams(model.hparams) {
|
|
25
|
+
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
|
26
|
+
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
|
27
|
+
|
|
28
|
+
const uint32_t size_base = kv_size;
|
|
29
|
+
|
|
30
|
+
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
|
31
|
+
|
|
32
|
+
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
|
33
|
+
if (swa_full) {
|
|
34
|
+
LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
|
|
35
|
+
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
|
36
|
+
|
|
37
|
+
size_swa = size_base;
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
|
41
|
+
|
|
42
|
+
kv_base = std::make_unique<llama_kv_cache_unified>(
|
|
43
|
+
model, std::move(filter_base), type_k, type_v,
|
|
44
|
+
v_trans, offload, size_base, n_seq_max, n_pad,
|
|
45
|
+
0, LLAMA_SWA_TYPE_NONE);
|
|
46
|
+
|
|
47
|
+
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
|
48
|
+
|
|
49
|
+
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
|
50
|
+
model, std::move(filter_swa), type_k, type_v,
|
|
51
|
+
v_trans, offload, size_swa, n_seq_max, n_pad,
|
|
52
|
+
hparams.n_swa, hparams.swa_type);
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
void llama_kv_cache_unified_iswa::clear(bool data) {
|
|
56
|
+
kv_base->clear(data);
|
|
57
|
+
kv_swa ->clear(data);
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
61
|
+
bool res = true;
|
|
62
|
+
|
|
63
|
+
res = res & kv_base->seq_rm(seq_id, p0, p1);
|
|
64
|
+
res = res & kv_swa ->seq_rm(seq_id, p0, p1);
|
|
65
|
+
|
|
66
|
+
return res;
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
|
70
|
+
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
|
71
|
+
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
|
75
|
+
kv_base->seq_keep(seq_id);
|
|
76
|
+
kv_swa ->seq_keep(seq_id);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
|
80
|
+
kv_base->seq_add(seq_id, p0, p1, shift);
|
|
81
|
+
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
// Apply the same position division to the sequence in both caches.
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_base->seq_div(seq_id, p0, p1, d);
    kv_swa->seq_div(seq_id, p0, p1, d);
}
|
|
88
|
+
|
|
89
|
+
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
    // the base cache is a superset of the SWA cache, so querying the SWA cache suffices
    return kv_swa->seq_pos_min(seq_id);
}
|
|
93
|
+
|
|
94
|
+
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
    // same reasoning as seq_pos_min(): the SWA cache alone answers this query
    return kv_swa->seq_pos_max(seq_id);
}
|
|
97
|
+
|
|
98
|
+
// Split `batch` into ubatches and reserve room for them in BOTH caches.
// Tries a "simple" split first; if either cache cannot prepare the resulting
// ubatches, retries with an "equal" split. Returns a combined state on success,
// or a state carrying LLAMA_MEMORY_STATUS_FAILED_PREPARE if both attempts fail.
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
    GGML_UNUSED(embd_all);

    // first try simple split
    do {
        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);

        std::vector<llama_ubatch> ubatches;

        // consume the sbatch into ubatches of at most n_ubatch tokens
        while (sbatch.n_tokens > 0) {
            auto ubatch = sbatch.split_simple(n_ubatch);

            ubatches.push_back(ubatch);
        }

        // both caches must be able to accommodate the full set of ubatches
        auto heads_base = kv_base->prepare(ubatches);
        if (heads_base.empty()) {
            break;
        }

        auto heads_swa = kv_swa->prepare(ubatches);
        if (heads_swa.empty()) {
            break;
        }

        assert(heads_base.size() == heads_swa.size());

        return std::make_unique<llama_kv_cache_unified_iswa_state>(
                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
    } while (false);

    // if it fails, try equal split
    do {
        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);

        std::vector<llama_ubatch> ubatches;

        while (sbatch.n_tokens > 0) {
            auto ubatch = sbatch.split_equal(n_ubatch);

            ubatches.push_back(ubatch);
        }

        auto heads_base = kv_base->prepare(ubatches);
        if (heads_base.empty()) {
            break;
        }

        auto heads_swa = kv_swa->prepare(ubatches);
        if (heads_swa.empty()) {
            break;
        }

        assert(heads_base.size() == heads_swa.size());

        return std::make_unique<llama_kv_cache_unified_iswa_state>(
                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
    } while (false);

    // TODO: if we fail again, we should attempt different splitting strategies
    //       but to do that properly, we first have to refactor the batches to be more flexible

    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
|
|
162
|
+
|
|
163
|
+
// Build a state spanning the full capacity of both underlying caches.
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
}
|
|
166
|
+
|
|
167
|
+
// Build a state that performs an update (e.g. defrag/shift) on both caches.
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
}
|
|
170
|
+
|
|
171
|
+
bool llama_kv_cache_unified_iswa::get_can_shift() const {
    // NOTE(review): shifting is reported as supported only when both caches have
    // equal size — presumably so a position shift remains consistent across the
    // base and SWA caches; confirm against the unified cache's shift logic.
    return kv_base->get_size() == kv_swa->get_size();
}
|
|
174
|
+
|
|
175
|
+
// Serialize both caches, base first, in a fixed order matching state_read().
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    kv_base->state_write(io, seq_id);
    kv_swa->state_write(io, seq_id);
}
|
|
179
|
+
|
|
180
|
+
// Deserialize both caches in the same order state_write() produced them.
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    kv_base->state_read(io, seq_id);
    kv_swa->state_read(io, seq_id);
}
|
|
184
|
+
|
|
185
|
+
// Non-owning accessor for the non-SWA (base) cache.
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
    return kv_base.get();
}
|
|
188
|
+
|
|
189
|
+
// Non-owning accessor for the SWA cache.
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
    return kv_swa.get();
}
|
|
192
|
+
|
|
193
|
+
//
|
|
194
|
+
// llama_kv_cache_unified_iswa_state
|
|
195
|
+
//
|
|
196
|
+
|
|
197
|
+
// Status-only state, used to report failures (e.g. LLAMA_MEMORY_STATUS_FAILED_PREPARE).
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
|
198
|
+
|
|
199
|
+
// Full-cache state: wraps init_full() states of both underlying caches and
// combines their statuses. Note: `status` is initialized from state_base/state_swa,
// so it must be declared after them in the class for the init order to be valid.
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_kv_cache_unified_iswa * kv) :
    state_base(kv->get_base()->init_full()),
    state_swa (kv->get_swa ()->init_full()),
    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
}
|
|
205
|
+
|
|
206
|
+
// Update state: wraps init_update() states of both underlying caches and
// combines their statuses (same declaration-order caveat as the full-cache ctor).
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_kv_cache_unified_iswa * kv,
        llama_context * lctx,
        bool optimize) :
    state_base(kv->get_base()->init_update(lctx, optimize)),
    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
}
|
|
214
|
+
|
|
215
|
+
// Batch state: takes ownership of the sbatch/ubatches produced by init_batch()
// plus the head positions prepared by each cache, and builds one sub-state per
// cache. The init list deliberately moves `ubatches` into the member BEFORE the
// sub-states are built, and then passes `this->ubatches` (the member, not the
// already-moved-from parameter) to them.
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_kv_cache_unified_iswa * kv,
        llama_sbatch sbatch,
        std::vector<uint32_t> heads_base,
        std::vector<uint32_t> heads_swa,
        std::vector<llama_ubatch> ubatches) :
    sbatch(std::move(sbatch)),
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
}
|
|
228
|
+
|
|
229
|
+
// NOTE(review): defaulted out-of-line — presumably so the smart-pointer members
// are destroyed where their pointees are complete types; confirm against the header.
llama_kv_cache_unified_iswa_state::~llama_kv_cache_unified_iswa_state() = default;
|
|
230
|
+
|
|
231
|
+
bool llama_kv_cache_unified_iswa_state::next() {
|
|
232
|
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
233
|
+
|
|
234
|
+
state_base->next();
|
|
235
|
+
state_swa ->next();
|
|
236
|
+
|
|
237
|
+
if (++i_next >= ubatches.size()) {
|
|
238
|
+
return false;
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
return true;
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
// Apply the current ubatch to both sub-states. Both calls are always made
// (no short-circuit); the result is true only if both succeed.
bool llama_kv_cache_unified_iswa_state::apply() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    const bool ok_base = state_base->apply();
    const bool ok_swa  = state_swa->apply();

    return ok_base && ok_swa;
}
|
|
254
|
+
|
|
255
|
+
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
|
|
256
|
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
257
|
+
|
|
258
|
+
return sbatch.out_ids;
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
// Combined status of the two sub-states (set once at construction).
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
    return status;
}
|
|
264
|
+
|
|
265
|
+
// Current ubatch being processed.
//
// Defect fixed: next() increments i_next and can leave it equal to
// ubatches.size() when iteration is exhausted; a stray get_ubatch() call after
// that would index out of bounds. Guard with a debug-time bounds assert so the
// misuse is caught immediately instead of reading past the vector.
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
    assert(i_next < ubatches.size());

    return ubatches[i_next];
}
|
|
270
|
+
|
|
271
|
+
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    // NOTE(review): static downcast assumes state_base always holds a
    // llama_kv_cache_unified_state — confirm against the constructors.
    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
}
|
|
276
|
+
|
|
277
|
+
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    // NOTE(review): static downcast assumes state_swa always holds a
    // llama_kv_cache_unified_state — confirm against the constructors.
    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
}
|