@novastera-oss/llamarn 0.2.5 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/RNLlamaCpp.podspec +3 -2
- package/android/CMakeLists.txt +6 -3
- package/android/src/main/cpp/include/llama.h +140 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +48 -67
- package/cpp/LlamaCppModel.h +8 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +33 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
- package/cpp/llama.cpp/common/arg.cpp +38 -12
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
- package/cpp/llama.cpp/common/chat-parser.h +4 -1
- package/cpp/llama.cpp/common/chat.cpp +16 -13
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +52 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/json-partial.cpp +5 -4
- package/cpp/llama.cpp/common/json-partial.h +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
- package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +140 -38
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
- package/cpp/llama.cpp/src/llama-batch.h +47 -17
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +488 -313
- package/cpp/llama.cpp/src/llama-context.h +38 -17
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
- package/cpp/llama.cpp/src/llama-graph.h +109 -52
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
- package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +89 -4
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +735 -143
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
- package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
- package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
- package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
- package/cpp/rn-completion.cpp +65 -10
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common/minja/chat-template.hpp +1 -1
- package/ios/include/common/minja/minja.hpp +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/json-schema-to-grammar.h +4 -4
- package/ios/include/llama.h +140 -38
- package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
- package/ios/libs/llama.xcframework/Info.plist +20 -20
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
#include "llama-context.h"
|
|
2
2
|
|
|
3
3
|
#include "llama-impl.h"
|
|
4
|
+
#include "llama-batch.h"
|
|
4
5
|
#include "llama-io.h"
|
|
6
|
+
#include "llama-memory.h"
|
|
5
7
|
#include "llama-mmap.h"
|
|
6
8
|
#include "llama-model.h"
|
|
7
|
-
#include "llama-kv-cache.h"
|
|
8
9
|
|
|
10
|
+
#include <cinttypes>
|
|
9
11
|
#include <cstring>
|
|
12
|
+
#include <limits>
|
|
10
13
|
#include <stdexcept>
|
|
11
|
-
#include <cinttypes>
|
|
12
14
|
|
|
13
15
|
//
|
|
14
16
|
// llama_context
|
|
@@ -17,7 +19,8 @@
|
|
|
17
19
|
llama_context::llama_context(
|
|
18
20
|
const llama_model & model,
|
|
19
21
|
llama_context_params params) :
|
|
20
|
-
model(model)
|
|
22
|
+
model(model),
|
|
23
|
+
batch_allocr(std::make_unique<llama_batch_allocr>()) {
|
|
21
24
|
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
|
22
25
|
|
|
23
26
|
t_start_us = model.t_start_us;
|
|
@@ -26,8 +29,8 @@ llama_context::llama_context(
|
|
|
26
29
|
const auto & hparams = model.hparams;
|
|
27
30
|
|
|
28
31
|
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
|
29
|
-
if (cparams.n_seq_max >
|
|
30
|
-
throw std::runtime_error("n_seq_max must be <= " + std::to_string(
|
|
32
|
+
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
|
|
33
|
+
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
|
31
34
|
}
|
|
32
35
|
|
|
33
36
|
cparams.n_threads = params.n_threads;
|
|
@@ -122,6 +125,11 @@ llama_context::llama_context(
|
|
|
122
125
|
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
|
123
126
|
}
|
|
124
127
|
|
|
128
|
+
if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
|
|
129
|
+
LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
|
|
130
|
+
__func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
|
|
131
|
+
}
|
|
132
|
+
|
|
125
133
|
if (!hparams.vocab_only) {
|
|
126
134
|
// GPU backends
|
|
127
135
|
for (auto * dev : model.devices) {
|
|
@@ -259,15 +267,9 @@ llama_context::llama_context(
|
|
|
259
267
|
|
|
260
268
|
// reserve worst-case graph
|
|
261
269
|
if (!hparams.vocab_only && memory) {
|
|
262
|
-
const uint32_t n_seqs =
|
|
270
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
|
263
271
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
264
272
|
|
|
265
|
-
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
|
266
|
-
|
|
267
|
-
// restore later
|
|
268
|
-
// TODO: something cleaner
|
|
269
|
-
const auto n_outputs_save = n_outputs;
|
|
270
|
-
|
|
271
273
|
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
272
274
|
|
|
273
275
|
int n_splits_pp = -1;
|
|
@@ -277,25 +279,18 @@ llama_context::llama_context(
|
|
|
277
279
|
int n_nodes_tg = -1;
|
|
278
280
|
|
|
279
281
|
// simulate full KV cache
|
|
280
|
-
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
281
282
|
|
|
282
|
-
|
|
283
|
+
const auto mstate = memory->init_full();
|
|
284
|
+
if (!mstate) {
|
|
285
|
+
throw std::runtime_error("failed to initialize KV cache");
|
|
286
|
+
}
|
|
283
287
|
|
|
284
288
|
cross.v_embd.clear();
|
|
285
289
|
|
|
286
290
|
// reserve pp graph first so that buffers are only allocated once
|
|
287
291
|
{
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
// max number of outputs
|
|
291
|
-
n_outputs = ubatch_pp.n_tokens;
|
|
292
|
-
|
|
293
|
-
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
|
294
|
-
|
|
295
|
-
auto * gf = graph_init();
|
|
296
|
-
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
|
297
|
-
|
|
298
|
-
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
292
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
|
293
|
+
if (!gf) {
|
|
299
294
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
300
295
|
}
|
|
301
296
|
|
|
@@ -305,16 +300,8 @@ llama_context::llama_context(
|
|
|
305
300
|
|
|
306
301
|
// reserve with tg graph to get the number of splits and nodes
|
|
307
302
|
{
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
n_outputs = ubatch_tg.n_tokens;
|
|
311
|
-
|
|
312
|
-
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
|
|
313
|
-
|
|
314
|
-
auto * gf = graph_init();
|
|
315
|
-
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
|
316
|
-
|
|
317
|
-
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
303
|
+
auto * gf = graph_reserve(1, 1, 1, mstate.get());
|
|
304
|
+
if (!gf) {
|
|
318
305
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
|
319
306
|
}
|
|
320
307
|
|
|
@@ -324,22 +311,12 @@ llama_context::llama_context(
|
|
|
324
311
|
|
|
325
312
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
|
326
313
|
{
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
n_outputs = ubatch_pp.n_tokens;
|
|
330
|
-
|
|
331
|
-
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
|
332
|
-
|
|
333
|
-
auto * gf = graph_init();
|
|
334
|
-
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
|
335
|
-
|
|
336
|
-
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
314
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
|
315
|
+
if (!gf) {
|
|
337
316
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
338
317
|
}
|
|
339
318
|
}
|
|
340
319
|
|
|
341
|
-
n_outputs = n_outputs_save;
|
|
342
|
-
|
|
343
320
|
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
344
321
|
ggml_backend_t backend = backend_ptrs[i];
|
|
345
322
|
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
@@ -443,46 +420,71 @@ uint32_t llama_context::n_threads_batch() const {
|
|
|
443
420
|
return cparams.n_threads_batch;
|
|
444
421
|
}
|
|
445
422
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
return kv_self;
|
|
449
|
-
}
|
|
450
|
-
|
|
451
|
-
const llama_kv_cache * llama_context::get_kv_self() const {
|
|
452
|
-
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
453
|
-
return kv_self;
|
|
423
|
+
llama_memory_t llama_context::get_memory() const {
|
|
424
|
+
return memory.get();
|
|
454
425
|
}
|
|
455
426
|
|
|
456
|
-
|
|
457
|
-
|
|
427
|
+
// deprecated
|
|
428
|
+
void llama_context::kv_self_defrag_sched() {
|
|
429
|
+
if (!memory) {
|
|
430
|
+
return;
|
|
431
|
+
}
|
|
458
432
|
|
|
459
|
-
|
|
433
|
+
memory_force_optimize = true;
|
|
434
|
+
}
|
|
460
435
|
|
|
461
|
-
|
|
436
|
+
// deprecated
|
|
437
|
+
bool llama_context::kv_self_update(bool optimize) {
|
|
438
|
+
if (!memory) {
|
|
439
|
+
return false;
|
|
440
|
+
}
|
|
462
441
|
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
442
|
+
{
|
|
443
|
+
// TODO: remove in the future
|
|
444
|
+
optimize |= memory_force_optimize;
|
|
445
|
+
memory_force_optimize = false;
|
|
466
446
|
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
447
|
+
const auto mstate = memory->init_update(this, optimize);
|
|
448
|
+
switch (mstate->get_status()) {
|
|
449
|
+
case LLAMA_MEMORY_STATUS_SUCCESS:
|
|
450
|
+
{
|
|
451
|
+
// noop
|
|
452
|
+
} break;
|
|
453
|
+
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
|
454
|
+
{
|
|
455
|
+
// no updates need to be performed
|
|
456
|
+
return false;
|
|
457
|
+
}
|
|
458
|
+
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
|
459
|
+
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
|
460
|
+
{
|
|
461
|
+
LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
|
|
462
|
+
return false;
|
|
463
|
+
}
|
|
464
|
+
}
|
|
470
465
|
|
|
471
|
-
|
|
472
|
-
|
|
466
|
+
if (!mstate->apply()) {
|
|
467
|
+
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
|
468
|
+
}
|
|
469
|
+
}
|
|
473
470
|
|
|
474
|
-
|
|
475
|
-
|
|
471
|
+
// if the memory module did any computation, we have to reserve a new worst-case graph
|
|
472
|
+
{
|
|
473
|
+
const auto mstate = memory->init_full();
|
|
474
|
+
if (!mstate) {
|
|
475
|
+
throw std::runtime_error("failed to initialize memory state");
|
|
476
|
+
}
|
|
476
477
|
|
|
477
|
-
|
|
478
|
-
|
|
478
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
|
479
|
+
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
479
480
|
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
481
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
|
482
|
+
if (!gf) {
|
|
483
|
+
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
|
484
484
|
}
|
|
485
485
|
}
|
|
486
|
+
|
|
487
|
+
return true;
|
|
486
488
|
}
|
|
487
489
|
|
|
488
490
|
enum llama_pooling_type llama_context::pooling_type() const {
|
|
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
|
|
|
494
496
|
}
|
|
495
497
|
|
|
496
498
|
float * llama_context::get_logits_ith(int32_t i) {
|
|
497
|
-
|
|
499
|
+
int64_t j = -1;
|
|
498
500
|
|
|
499
501
|
try {
|
|
500
502
|
if (logits == nullptr) {
|
|
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
|
517
519
|
}
|
|
518
520
|
if (j >= n_outputs) {
|
|
519
521
|
// This should not happen
|
|
520
|
-
throw std::runtime_error(format("corrupt output buffer (j=%
|
|
522
|
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
521
523
|
}
|
|
522
524
|
|
|
523
525
|
return logits + j*model.vocab.n_tokens();
|
|
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
|
|
|
536
538
|
}
|
|
537
539
|
|
|
538
540
|
float * llama_context::get_embeddings_ith(int32_t i) {
|
|
539
|
-
|
|
541
|
+
int64_t j = -1;
|
|
540
542
|
|
|
541
543
|
try {
|
|
542
544
|
if (embd == nullptr) {
|
|
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
|
|
559
561
|
}
|
|
560
562
|
if (j >= n_outputs) {
|
|
561
563
|
// This should not happen
|
|
562
|
-
throw std::runtime_error(format("corrupt output buffer (j=%
|
|
564
|
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
563
565
|
}
|
|
564
566
|
|
|
565
567
|
return embd + j*model.hparams.n_embd;
|
|
@@ -676,52 +678,84 @@ bool llama_context::apply_adapter_cvec(
|
|
|
676
678
|
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
|
677
679
|
}
|
|
678
680
|
|
|
679
|
-
|
|
680
|
-
if (
|
|
681
|
+
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
|
|
682
|
+
if (mstate && !mstate->apply()) {
|
|
683
|
+
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
|
|
684
|
+
ret = GGML_STATUS_FAILED;
|
|
685
|
+
return nullptr;
|
|
686
|
+
}
|
|
687
|
+
|
|
688
|
+
auto * gf = graph_init();
|
|
689
|
+
if (!gf) {
|
|
690
|
+
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
|
|
691
|
+
ret = GGML_STATUS_FAILED;
|
|
692
|
+
return nullptr;
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
|
|
696
|
+
if (!res) {
|
|
697
|
+
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
|
698
|
+
ret = GGML_STATUS_FAILED;
|
|
699
|
+
return nullptr;
|
|
700
|
+
}
|
|
701
|
+
|
|
702
|
+
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
|
703
|
+
|
|
704
|
+
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
|
|
705
|
+
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
|
|
706
|
+
ret = GGML_STATUS_ALLOC_FAILED;
|
|
707
|
+
return nullptr;
|
|
708
|
+
}
|
|
709
|
+
|
|
710
|
+
res->set_inputs(&ubatch);
|
|
711
|
+
|
|
712
|
+
const auto status = graph_compute(gf, ubatch.n_tokens > 1);
|
|
713
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
714
|
+
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
|
|
715
|
+
ret = status;
|
|
716
|
+
return nullptr;
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
ret = GGML_STATUS_SUCCESS;
|
|
720
|
+
|
|
721
|
+
return res;
|
|
722
|
+
}
|
|
723
|
+
|
|
724
|
+
int llama_context::encode(const llama_batch & batch_inp) {
|
|
725
|
+
if (batch_inp.n_tokens == 0) {
|
|
681
726
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
682
727
|
return -1;
|
|
683
728
|
}
|
|
684
729
|
|
|
685
|
-
// temporary allocate memory for the input batch if needed
|
|
686
730
|
// note: during encode, we always pass the full sequence starting from pos = 0
|
|
687
|
-
|
|
731
|
+
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
|
|
732
|
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
|
733
|
+
return -1;
|
|
734
|
+
}
|
|
688
735
|
|
|
689
|
-
const llama_batch & batch = batch_allocr
|
|
690
|
-
const int32_t n_tokens = batch.n_tokens;
|
|
736
|
+
const llama_batch & batch = batch_allocr->get_batch();
|
|
691
737
|
|
|
692
|
-
const
|
|
738
|
+
const uint32_t n_tokens = batch.n_tokens;
|
|
693
739
|
|
|
694
740
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
695
741
|
|
|
696
|
-
// TODO: move the validation to the llama_batch_allocr
|
|
697
|
-
if (batch.token) {
|
|
698
|
-
for (int32_t i = 0; i < n_tokens; ++i) {
|
|
699
|
-
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
|
700
|
-
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
|
701
|
-
return -1;
|
|
702
|
-
}
|
|
703
|
-
|
|
704
|
-
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
|
705
|
-
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
706
|
-
throw -1;
|
|
707
|
-
}
|
|
708
|
-
}
|
|
709
|
-
}
|
|
710
|
-
|
|
711
742
|
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
|
712
|
-
GGML_ASSERT(cparams.n_ubatch >=
|
|
743
|
+
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
|
|
713
744
|
|
|
714
745
|
if (t_compute_start_us == 0) {
|
|
715
746
|
t_compute_start_us = ggml_time_us();
|
|
716
747
|
}
|
|
717
748
|
|
|
749
|
+
// TODO: this clear of the buffer can easily be forgotten - need something better
|
|
718
750
|
embd_seq.clear();
|
|
719
751
|
|
|
720
752
|
n_queued_tokens += n_tokens;
|
|
721
753
|
|
|
754
|
+
const auto & hparams = model.hparams;
|
|
755
|
+
|
|
722
756
|
const int64_t n_embd = hparams.n_embd;
|
|
723
757
|
|
|
724
|
-
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true
|
|
758
|
+
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
|
|
725
759
|
|
|
726
760
|
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
|
727
761
|
|
|
@@ -731,14 +765,12 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
731
765
|
return -2;
|
|
732
766
|
};
|
|
733
767
|
|
|
734
|
-
for (
|
|
768
|
+
for (uint32_t i = 0; i < n_tokens; ++i) {
|
|
735
769
|
output_ids[i] = i;
|
|
736
770
|
}
|
|
737
771
|
|
|
738
772
|
n_outputs = n_tokens;
|
|
739
773
|
|
|
740
|
-
//batch_manager->prepare(ubatch);
|
|
741
|
-
|
|
742
774
|
ggml_backend_sched_reset(sched.get());
|
|
743
775
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
|
744
776
|
|
|
@@ -749,26 +781,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
749
781
|
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
|
750
782
|
cparams.causal_attn = false;
|
|
751
783
|
|
|
752
|
-
|
|
753
|
-
auto res =
|
|
754
|
-
|
|
755
|
-
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
756
|
-
|
|
757
|
-
res->set_inputs(&ubatch);
|
|
784
|
+
ggml_status status;
|
|
785
|
+
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
|
|
758
786
|
|
|
759
787
|
cparams.causal_attn = causal_attn_org;
|
|
760
788
|
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
return -2;
|
|
769
|
-
case GGML_STATUS_FAILED:
|
|
770
|
-
default:
|
|
771
|
-
return -3;
|
|
789
|
+
if (!res) {
|
|
790
|
+
switch (status) {
|
|
791
|
+
case GGML_STATUS_ABORTED: return 2;
|
|
792
|
+
case GGML_STATUS_ALLOC_FAILED: return -2;
|
|
793
|
+
case GGML_STATUS_FAILED: return -3;
|
|
794
|
+
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
|
|
795
|
+
}
|
|
772
796
|
}
|
|
773
797
|
|
|
774
798
|
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
|
@@ -797,7 +821,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
797
821
|
|
|
798
822
|
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
|
799
823
|
|
|
800
|
-
|
|
824
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
825
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
801
826
|
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
802
827
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
803
828
|
continue;
|
|
@@ -808,16 +833,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
808
833
|
} break;
|
|
809
834
|
case LLAMA_POOLING_TYPE_RANK:
|
|
810
835
|
{
|
|
811
|
-
// extract the rerank score -
|
|
836
|
+
// extract the rerank score - n_cls_out floats per sequence
|
|
812
837
|
auto & embd_seq_out = embd_seq;
|
|
838
|
+
const uint32_t n_cls_out = hparams.n_cls_out;
|
|
813
839
|
|
|
840
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
814
841
|
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
|
815
842
|
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
816
843
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
817
844
|
continue;
|
|
818
845
|
}
|
|
819
|
-
embd_seq_out[seq_id].resize(
|
|
820
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
|
846
|
+
embd_seq_out[seq_id].resize(n_cls_out);
|
|
847
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
|
|
821
848
|
}
|
|
822
849
|
} break;
|
|
823
850
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
@@ -844,10 +871,10 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
844
871
|
|
|
845
872
|
// remember the sequence ids used during the encoding - needed for cross attention later
|
|
846
873
|
cross.seq_ids_enc.resize(n_tokens);
|
|
847
|
-
for (
|
|
874
|
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
|
848
875
|
cross.seq_ids_enc[i].clear();
|
|
849
|
-
for (int s = 0; s <
|
|
850
|
-
llama_seq_id seq_id =
|
|
876
|
+
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
|
877
|
+
llama_seq_id seq_id = batch.seq_id[i][s];
|
|
851
878
|
cross.seq_ids_enc[i].insert(seq_id);
|
|
852
879
|
}
|
|
853
880
|
}
|
|
@@ -856,55 +883,45 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
856
883
|
return 0;
|
|
857
884
|
}
|
|
858
885
|
|
|
859
|
-
int llama_context::decode(llama_batch &
|
|
886
|
+
int llama_context::decode(const llama_batch & batch_inp) {
|
|
860
887
|
if (!memory) {
|
|
861
888
|
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
|
862
|
-
return encode(
|
|
889
|
+
return encode(batch_inp);
|
|
863
890
|
}
|
|
864
891
|
|
|
865
|
-
if (
|
|
892
|
+
if (batch_inp.n_tokens == 0) {
|
|
866
893
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
867
894
|
return -1;
|
|
868
895
|
}
|
|
869
896
|
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
|
|
873
|
-
return -1;
|
|
874
|
-
}
|
|
875
|
-
}
|
|
876
|
-
|
|
877
|
-
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
897
|
+
// when computing embeddings, all tokens are output
|
|
898
|
+
const bool embd_all = cparams.embeddings;
|
|
878
899
|
|
|
879
|
-
|
|
880
|
-
|
|
900
|
+
if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
|
|
901
|
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
|
902
|
+
return -1;
|
|
903
|
+
}
|
|
881
904
|
|
|
882
|
-
const llama_batch & batch = batch_allocr
|
|
905
|
+
const llama_batch & batch = batch_allocr->get_batch();
|
|
883
906
|
|
|
884
907
|
const auto & vocab = model.vocab;
|
|
885
908
|
const auto & hparams = model.hparams;
|
|
886
909
|
|
|
887
910
|
const int32_t n_vocab = vocab.n_tokens();
|
|
911
|
+
const int64_t n_embd = hparams.n_embd;
|
|
888
912
|
|
|
889
|
-
const
|
|
890
|
-
const int64_t n_embd = hparams.n_embd;
|
|
891
|
-
|
|
892
|
-
llama_kv_cache_guard kv_guard(kv_self);
|
|
913
|
+
const uint32_t n_tokens_all = batch.n_tokens;
|
|
893
914
|
|
|
894
915
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
895
916
|
|
|
896
|
-
|
|
897
|
-
if (batch.token) {
|
|
898
|
-
for (int64_t i = 0; i < n_tokens_all; ++i) {
|
|
899
|
-
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
|
900
|
-
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
|
|
901
|
-
return -1;
|
|
902
|
-
}
|
|
917
|
+
const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
|
|
903
918
|
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
919
|
+
if (embd_all) {
|
|
920
|
+
// require that all tokens are output
|
|
921
|
+
if (n_outputs_all != n_tokens_all) {
|
|
922
|
+
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
|
|
923
|
+
__func__, n_outputs_all, n_tokens_all);
|
|
924
|
+
return -1;
|
|
908
925
|
}
|
|
909
926
|
}
|
|
910
927
|
|
|
@@ -917,42 +934,71 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
917
934
|
}
|
|
918
935
|
n_queued_tokens += n_tokens_all;
|
|
919
936
|
|
|
920
|
-
// this
|
|
921
|
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
|
922
|
-
|
|
937
|
+
// TODO: this clear of the buffer can easily be forgotten - need something better
|
|
923
938
|
embd_seq.clear();
|
|
924
939
|
|
|
925
|
-
|
|
940
|
+
bool did_optimize = false;
|
|
941
|
+
|
|
942
|
+
// handle any pending defrags/shifts
|
|
943
|
+
kv_self_update(false);
|
|
944
|
+
|
|
945
|
+
llama_memory_state_ptr mstate;
|
|
926
946
|
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
947
|
+
while (true) {
|
|
948
|
+
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
|
|
949
|
+
if (!mstate) {
|
|
950
|
+
return -2;
|
|
951
|
+
}
|
|
952
|
+
|
|
953
|
+
switch (mstate->get_status()) {
|
|
954
|
+
case LLAMA_MEMORY_STATUS_SUCCESS:
|
|
955
|
+
{
|
|
956
|
+
} break;
|
|
957
|
+
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
|
958
|
+
{
|
|
959
|
+
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
|
|
960
|
+
|
|
961
|
+
return -2;
|
|
962
|
+
}
|
|
963
|
+
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
|
964
|
+
{
|
|
965
|
+
if (!did_optimize) {
|
|
966
|
+
did_optimize = true;
|
|
967
|
+
|
|
968
|
+
if (kv_self_update(true)) {
|
|
969
|
+
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
|
|
970
|
+
|
|
971
|
+
continue;
|
|
972
|
+
}
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
|
|
976
|
+
|
|
977
|
+
return 1;
|
|
978
|
+
}
|
|
979
|
+
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
|
980
|
+
{
|
|
981
|
+
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
|
|
982
|
+
|
|
983
|
+
return -2;
|
|
984
|
+
}
|
|
931
985
|
}
|
|
932
|
-
} else if (embd_pooled) {
|
|
933
|
-
n_outputs_all = n_tokens_all;
|
|
934
|
-
} else {
|
|
935
|
-
// keep last output only
|
|
936
|
-
n_outputs_all = 1;
|
|
937
|
-
}
|
|
938
986
|
|
|
939
|
-
|
|
987
|
+
break;
|
|
988
|
+
}
|
|
940
989
|
|
|
941
990
|
// reserve output buffer
|
|
942
991
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
943
|
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
|
992
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
|
944
993
|
return -2;
|
|
945
994
|
};
|
|
946
995
|
|
|
947
|
-
// handle any pending defrags/shifts
|
|
948
|
-
kv_self_update();
|
|
949
|
-
|
|
950
996
|
int64_t n_outputs_prev = 0;
|
|
951
997
|
|
|
952
|
-
|
|
953
|
-
|
|
998
|
+
do {
|
|
999
|
+
const auto & ubatch = mstate->get_ubatch();
|
|
954
1000
|
|
|
955
|
-
// count the outputs in this
|
|
1001
|
+
// count the outputs in this ubatch
|
|
956
1002
|
{
|
|
957
1003
|
int32_t n_outputs_new = 0;
|
|
958
1004
|
|
|
@@ -969,33 +1015,41 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
969
1015
|
n_outputs = n_outputs_new;
|
|
970
1016
|
}
|
|
971
1017
|
|
|
972
|
-
// find KV slot
|
|
973
|
-
if (!kv_self->find_slot(ubatch)) {
|
|
974
|
-
return 1;
|
|
975
|
-
}
|
|
976
|
-
|
|
977
1018
|
ggml_backend_sched_reset(sched.get());
|
|
978
1019
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
|
979
1020
|
|
|
980
|
-
|
|
981
|
-
auto res =
|
|
1021
|
+
ggml_status status;
|
|
1022
|
+
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
|
|
982
1023
|
|
|
983
|
-
|
|
1024
|
+
if (!res) {
|
|
1025
|
+
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
|
1026
|
+
llama_pos pos_min[LLAMA_MAX_SEQ];
|
|
1027
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
1028
|
+
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
|
1029
|
+
}
|
|
984
1030
|
|
|
985
|
-
|
|
1031
|
+
// TODO: fix sequence indexing
|
|
1032
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
|
1033
|
+
const auto & seq_id = ubatch.seq_id[i][0];
|
|
986
1034
|
|
|
987
|
-
|
|
1035
|
+
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
|
1036
|
+
}
|
|
988
1037
|
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
1038
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
1039
|
+
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
|
1040
|
+
continue;
|
|
1041
|
+
}
|
|
1042
|
+
|
|
1043
|
+
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
|
1044
|
+
|
|
1045
|
+
memory->seq_rm(s, pos_min[s], -1);
|
|
1046
|
+
}
|
|
1047
|
+
|
|
1048
|
+
switch (status) {
|
|
1049
|
+
case GGML_STATUS_ABORTED: return 2;
|
|
1050
|
+
case GGML_STATUS_ALLOC_FAILED: return -2;
|
|
1051
|
+
case GGML_STATUS_FAILED: return -3;
|
|
1052
|
+
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
|
|
999
1053
|
}
|
|
1000
1054
|
}
|
|
1001
1055
|
|
|
@@ -1004,7 +1058,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1004
1058
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
|
1005
1059
|
//}
|
|
1006
1060
|
|
|
1007
|
-
auto * t_logits =
|
|
1061
|
+
auto * t_logits = res->get_logits();
|
|
1008
1062
|
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
|
1009
1063
|
|
|
1010
1064
|
if (t_embd && res->get_embd_pooled()) {
|
|
@@ -1082,23 +1136,20 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1082
1136
|
}
|
|
1083
1137
|
|
|
1084
1138
|
n_outputs_prev += n_outputs;
|
|
1085
|
-
}
|
|
1086
|
-
|
|
1087
|
-
// finalize the batch processing
|
|
1088
|
-
kv_guard.commit();
|
|
1139
|
+
} while (mstate->next());
|
|
1089
1140
|
|
|
1090
1141
|
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
|
1091
1142
|
n_outputs = n_outputs_all;
|
|
1092
1143
|
|
|
1093
1144
|
// set output mappings
|
|
1094
|
-
{
|
|
1145
|
+
if (n_outputs > 0) {
|
|
1095
1146
|
bool sorted_output = true;
|
|
1096
1147
|
|
|
1097
|
-
auto & out_ids =
|
|
1148
|
+
auto & out_ids = mstate->out_ids();
|
|
1098
1149
|
|
|
1099
|
-
GGML_ASSERT(out_ids.size() == (size_t)
|
|
1150
|
+
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
|
|
1100
1151
|
|
|
1101
|
-
for (int64_t i = 0; i <
|
|
1152
|
+
for (int64_t i = 0; i < n_outputs; ++i) {
|
|
1102
1153
|
int64_t out_id = out_ids[i];
|
|
1103
1154
|
output_ids[out_id] = i;
|
|
1104
1155
|
if (out_id != i) {
|
|
@@ -1110,20 +1161,22 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1110
1161
|
// note: this is mostly relevant for recurrent models atm
|
|
1111
1162
|
if (!sorted_output) {
|
|
1112
1163
|
const uint32_t n_vocab = model.vocab.n_tokens();
|
|
1113
|
-
const
|
|
1164
|
+
const uint64_t n_embd = model.hparams.n_embd;
|
|
1114
1165
|
|
|
1115
1166
|
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
1116
1167
|
|
|
1117
1168
|
// TODO: is there something more efficient which also minimizes swaps?
|
|
1118
1169
|
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
|
1119
|
-
for (
|
|
1120
|
-
|
|
1121
|
-
for (
|
|
1170
|
+
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
|
|
1171
|
+
uint32_t j_min = i;
|
|
1172
|
+
for (uint32_t j = i + 1; j < n_outputs; ++j) {
|
|
1122
1173
|
if (out_ids[j] < out_ids[j_min]) {
|
|
1123
1174
|
j_min = j;
|
|
1124
1175
|
}
|
|
1125
1176
|
}
|
|
1126
|
-
if (j_min == i) {
|
|
1177
|
+
if (j_min == i) {
|
|
1178
|
+
continue;
|
|
1179
|
+
}
|
|
1127
1180
|
std::swap(out_ids[i], out_ids[j_min]);
|
|
1128
1181
|
if (logits_size > 0) {
|
|
1129
1182
|
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
@@ -1136,8 +1189,10 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1136
1189
|
}
|
|
1137
1190
|
}
|
|
1138
1191
|
}
|
|
1192
|
+
|
|
1139
1193
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1140
|
-
|
|
1194
|
+
|
|
1195
|
+
for (uint32_t i = 0; i < n_outputs; ++i) {
|
|
1141
1196
|
output_ids[out_ids[i]] = i;
|
|
1142
1197
|
}
|
|
1143
1198
|
}
|
|
@@ -1146,11 +1201,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1146
1201
|
// wait for the computation to finish (automatically done when obtaining the model output)
|
|
1147
1202
|
//synchronize();
|
|
1148
1203
|
|
|
1149
|
-
// decide if we need to defrag the kv cache
|
|
1150
|
-
if (cparams.defrag_thold > 0.0f) {
|
|
1151
|
-
kv_self->defrag_sched(cparams.defrag_thold);
|
|
1152
|
-
}
|
|
1153
|
-
|
|
1154
1204
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
|
1155
1205
|
// overlap with device computation.
|
|
1156
1206
|
ggml_backend_sched_reset(sched.get());
|
|
@@ -1162,7 +1212,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1162
1212
|
// output
|
|
1163
1213
|
//
|
|
1164
1214
|
|
|
1165
|
-
|
|
1215
|
+
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
1166
1216
|
const auto & hparams = model.hparams;
|
|
1167
1217
|
const auto & vocab = model.vocab;
|
|
1168
1218
|
|
|
@@ -1172,9 +1222,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1172
1222
|
const auto n_vocab = vocab.n_tokens();
|
|
1173
1223
|
const auto n_embd = hparams.n_embd;
|
|
1174
1224
|
|
|
1175
|
-
|
|
1176
|
-
bool
|
|
1177
|
-
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
|
1225
|
+
bool has_logits = true;
|
|
1226
|
+
bool has_embd = cparams.embeddings;
|
|
1178
1227
|
|
|
1179
1228
|
// TODO: hacky enc-dec support
|
|
1180
1229
|
if (model.arch == LLM_ARCH_T5) {
|
|
@@ -1228,8 +1277,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1228
1277
|
// set all ids as invalid (negative)
|
|
1229
1278
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1230
1279
|
|
|
1231
|
-
this->n_outputs
|
|
1232
|
-
this->n_outputs_max = n_outputs_max;
|
|
1280
|
+
this->n_outputs = 0;
|
|
1233
1281
|
|
|
1234
1282
|
return n_outputs_max;
|
|
1235
1283
|
}
|
|
@@ -1254,11 +1302,52 @@ ggml_cgraph * llama_context::graph_init() {
|
|
|
1254
1302
|
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
|
1255
1303
|
}
|
|
1256
1304
|
|
|
1305
|
+
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
|
|
1306
|
+
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
1307
|
+
|
|
1308
|
+
if (n_tokens % n_seqs != 0) {
|
|
1309
|
+
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
|
|
1310
|
+
n_outputs = std::min(n_outputs, n_tokens);
|
|
1311
|
+
|
|
1312
|
+
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
1313
|
+
}
|
|
1314
|
+
|
|
1315
|
+
// store the n_outputs as it is, and restore it afterwards
|
|
1316
|
+
// TODO: not sure if needed, might simplify in the future by removing this
|
|
1317
|
+
const auto save_n_outputs = this->n_outputs;
|
|
1318
|
+
|
|
1319
|
+
this->n_outputs = n_outputs;
|
|
1320
|
+
|
|
1321
|
+
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
|
1322
|
+
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
1323
|
+
|
|
1324
|
+
auto * gf = graph_init();
|
|
1325
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
|
|
1326
|
+
|
|
1327
|
+
this->n_outputs = save_n_outputs;
|
|
1328
|
+
|
|
1329
|
+
if (!res) {
|
|
1330
|
+
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
|
|
1331
|
+
return nullptr;
|
|
1332
|
+
}
|
|
1333
|
+
|
|
1334
|
+
ggml_backend_sched_reset(sched.get());
|
|
1335
|
+
|
|
1336
|
+
// initialize scheduler with the specified graph
|
|
1337
|
+
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
1338
|
+
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
1339
|
+
return nullptr;
|
|
1340
|
+
}
|
|
1341
|
+
|
|
1342
|
+
return gf;
|
|
1343
|
+
}
|
|
1344
|
+
|
|
1257
1345
|
llm_graph_result_ptr llama_context::graph_build(
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1346
|
+
ggml_context * ctx,
|
|
1347
|
+
ggml_cgraph * gf,
|
|
1348
|
+
const llama_ubatch & ubatch,
|
|
1349
|
+
llm_graph_type gtype,
|
|
1350
|
+
const llama_memory_state_i * mstate) {
|
|
1262
1351
|
return model.build_graph(
|
|
1263
1352
|
{
|
|
1264
1353
|
/*.ctx =*/ ctx,
|
|
@@ -1270,7 +1359,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
|
|
1270
1359
|
/*.backend_cpu =*/ backend_cpu,
|
|
1271
1360
|
/*.cvec =*/ &cvec,
|
|
1272
1361
|
/*.loras =*/ &loras,
|
|
1273
|
-
/*.
|
|
1362
|
+
/*.mstate =*/ mstate,
|
|
1274
1363
|
/*.cross =*/ &cross,
|
|
1275
1364
|
/*.n_outputs =*/ n_outputs,
|
|
1276
1365
|
/*.cb =*/ graph_get_cb(),
|
|
@@ -1679,14 +1768,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
|
1679
1768
|
|
|
1680
1769
|
std::vector<int32_t> w_output_pos;
|
|
1681
1770
|
|
|
1682
|
-
GGML_ASSERT(n_outputs <= n_outputs_max);
|
|
1683
|
-
|
|
1684
1771
|
w_output_pos.resize(n_outputs);
|
|
1685
1772
|
|
|
1686
1773
|
// build a more compact representation of the output ids
|
|
1687
1774
|
for (size_t i = 0; i < n_batch(); ++i) {
|
|
1688
1775
|
// map an output id to a position in the batch
|
|
1689
|
-
|
|
1776
|
+
int64_t pos = output_ids[i];
|
|
1690
1777
|
if (pos >= 0) {
|
|
1691
1778
|
GGML_ASSERT(pos < n_outputs);
|
|
1692
1779
|
w_output_pos[pos] = i;
|
|
@@ -1726,11 +1813,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
|
1726
1813
|
}
|
|
1727
1814
|
}
|
|
1728
1815
|
|
|
1729
|
-
|
|
1730
|
-
|
|
1731
|
-
if (kv_self != nullptr) {
|
|
1816
|
+
if (memory != nullptr) {
|
|
1732
1817
|
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
|
1733
|
-
|
|
1818
|
+
memory->state_write(io);
|
|
1734
1819
|
}
|
|
1735
1820
|
|
|
1736
1821
|
return io.n_bytes();
|
|
@@ -1817,9 +1902,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
|
1817
1902
|
if (memory) {
|
|
1818
1903
|
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
|
1819
1904
|
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
kv_self->state_read(io);
|
|
1905
|
+
memory->state_read(io);
|
|
1823
1906
|
}
|
|
1824
1907
|
|
|
1825
1908
|
return io.n_bytes();
|
|
@@ -1829,9 +1912,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
|
|
|
1829
1912
|
GGML_UNUSED(seq_id);
|
|
1830
1913
|
|
|
1831
1914
|
if (memory) {
|
|
1832
|
-
|
|
1833
|
-
|
|
1834
|
-
kv_self->state_write(io, seq_id);
|
|
1915
|
+
memory->state_write(io, seq_id);
|
|
1835
1916
|
}
|
|
1836
1917
|
|
|
1837
1918
|
return io.n_bytes();
|
|
@@ -1841,9 +1922,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
|
|
|
1841
1922
|
GGML_UNUSED(seq_id);
|
|
1842
1923
|
|
|
1843
1924
|
if (memory) {
|
|
1844
|
-
|
|
1845
|
-
|
|
1846
|
-
kv_self->state_read(io, seq_id);
|
|
1925
|
+
memory->state_read(io, seq_id);
|
|
1847
1926
|
}
|
|
1848
1927
|
|
|
1849
1928
|
return io.n_bytes();
|
|
@@ -1948,10 +2027,7 @@ void llama_context::opt_epoch_iter(
|
|
|
1948
2027
|
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
|
1949
2028
|
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
|
1950
2029
|
|
|
1951
|
-
|
|
1952
|
-
|
|
1953
|
-
kv_self->clear();
|
|
1954
|
-
llama_kv_cache_guard kv_guard(kv_self);
|
|
2030
|
+
memory->clear(true);
|
|
1955
2031
|
|
|
1956
2032
|
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
|
1957
2033
|
batch.n_tokens = n_batch;
|
|
@@ -1967,35 +2043,35 @@ void llama_context::opt_epoch_iter(
|
|
|
1967
2043
|
|
|
1968
2044
|
n_queued_tokens += n_tokens_all;
|
|
1969
2045
|
|
|
1970
|
-
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
|
1971
|
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
|
1972
|
-
|
|
1973
2046
|
embd_seq.clear();
|
|
1974
2047
|
|
|
1975
|
-
|
|
2048
|
+
uint32_t n_outputs_all = n_tokens_all;
|
|
1976
2049
|
|
|
1977
|
-
|
|
2050
|
+
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
|
|
2051
|
+
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
|
2052
|
+
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
|
2053
|
+
break;
|
|
2054
|
+
}
|
|
1978
2055
|
|
|
1979
2056
|
// reserve output buffer
|
|
1980
2057
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
1981
|
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
|
2058
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
|
1982
2059
|
GGML_ABORT("TODO: handle this error");
|
|
1983
2060
|
};
|
|
1984
2061
|
|
|
1985
|
-
|
|
1986
|
-
|
|
2062
|
+
uint32_t pos_batch = 0;
|
|
2063
|
+
do {
|
|
2064
|
+
const auto & ubatch = mstate->get_ubatch();
|
|
1987
2065
|
|
|
1988
2066
|
n_outputs = ubatch.n_tokens;
|
|
1989
2067
|
|
|
1990
|
-
|
|
1991
|
-
|
|
1992
|
-
|
|
1993
|
-
|
|
1994
|
-
GGML_ABORT("TODO: handle this error");
|
|
2068
|
+
if (!mstate->apply()) {
|
|
2069
|
+
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
|
2070
|
+
break;
|
|
1995
2071
|
}
|
|
1996
2072
|
|
|
1997
2073
|
auto * gf = graph_init();
|
|
1998
|
-
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
|
2074
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
|
|
1999
2075
|
|
|
2000
2076
|
struct ggml_context * ctx_compute_opt;
|
|
2001
2077
|
{
|
|
@@ -2010,6 +2086,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2010
2086
|
}
|
|
2011
2087
|
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
|
2012
2088
|
ggml_opt_alloc(opt_ctx, train);
|
|
2089
|
+
|
|
2013
2090
|
res->set_inputs(&ubatch);
|
|
2014
2091
|
{
|
|
2015
2092
|
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
|
@@ -2027,10 +2104,10 @@ void llama_context::opt_epoch_iter(
|
|
|
2027
2104
|
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
|
2028
2105
|
}
|
|
2029
2106
|
ggml_free(ctx_compute_opt);
|
|
2030
|
-
}
|
|
2031
|
-
}
|
|
2032
2107
|
|
|
2033
|
-
|
|
2108
|
+
pos_batch += ubatch.n_tokens;
|
|
2109
|
+
} while (mstate->next());
|
|
2110
|
+
}
|
|
2034
2111
|
}
|
|
2035
2112
|
|
|
2036
2113
|
void llama_context::opt_epoch(
|
|
@@ -2190,12 +2267,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
|
|
2190
2267
|
return &ctx->get_model();
|
|
2191
2268
|
}
|
|
2192
2269
|
|
|
2270
|
+
// deprecated
|
|
2193
2271
|
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
|
2194
|
-
return ctx->
|
|
2272
|
+
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
|
|
2195
2273
|
}
|
|
2196
2274
|
|
|
2275
|
+
// deprecated
|
|
2197
2276
|
void llama_kv_self_update(llama_context * ctx) {
|
|
2198
|
-
ctx->kv_self_update();
|
|
2277
|
+
ctx->kv_self_update(false);
|
|
2199
2278
|
}
|
|
2200
2279
|
|
|
2201
2280
|
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
|
@@ -2310,13 +2389,118 @@ int32_t llama_apply_adapter_cvec(
|
|
|
2310
2389
|
return res ? 0 : -1;
|
|
2311
2390
|
}
|
|
2312
2391
|
|
|
2392
|
+
//
|
|
2393
|
+
// memory
|
|
2394
|
+
//
|
|
2395
|
+
|
|
2396
|
+
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
|
2397
|
+
return ctx->get_memory();
|
|
2398
|
+
}
|
|
2399
|
+
|
|
2400
|
+
void llama_memory_clear(llama_memory_t mem, bool data) {
|
|
2401
|
+
if (!mem) {
|
|
2402
|
+
return;
|
|
2403
|
+
}
|
|
2404
|
+
|
|
2405
|
+
mem->clear(data);
|
|
2406
|
+
}
|
|
2407
|
+
|
|
2408
|
+
bool llama_memory_seq_rm(
|
|
2409
|
+
llama_memory_t mem,
|
|
2410
|
+
llama_seq_id seq_id,
|
|
2411
|
+
llama_pos p0,
|
|
2412
|
+
llama_pos p1) {
|
|
2413
|
+
if (!mem) {
|
|
2414
|
+
return true;
|
|
2415
|
+
}
|
|
2416
|
+
|
|
2417
|
+
return mem->seq_rm(seq_id, p0, p1);
|
|
2418
|
+
}
|
|
2419
|
+
|
|
2420
|
+
void llama_memory_seq_cp(
|
|
2421
|
+
llama_memory_t mem,
|
|
2422
|
+
llama_seq_id seq_id_src,
|
|
2423
|
+
llama_seq_id seq_id_dst,
|
|
2424
|
+
llama_pos p0,
|
|
2425
|
+
llama_pos p1) {
|
|
2426
|
+
if (!mem) {
|
|
2427
|
+
return;
|
|
2428
|
+
}
|
|
2429
|
+
|
|
2430
|
+
mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
|
2431
|
+
}
|
|
2432
|
+
|
|
2433
|
+
void llama_memory_seq_keep(
|
|
2434
|
+
llama_memory_t mem,
|
|
2435
|
+
llama_seq_id seq_id) {
|
|
2436
|
+
if (!mem) {
|
|
2437
|
+
return;
|
|
2438
|
+
}
|
|
2439
|
+
|
|
2440
|
+
mem->seq_keep(seq_id);
|
|
2441
|
+
}
|
|
2442
|
+
|
|
2443
|
+
void llama_memory_seq_add(
|
|
2444
|
+
llama_memory_t mem,
|
|
2445
|
+
llama_seq_id seq_id,
|
|
2446
|
+
llama_pos p0,
|
|
2447
|
+
llama_pos p1,
|
|
2448
|
+
llama_pos delta) {
|
|
2449
|
+
if (!mem) {
|
|
2450
|
+
return;
|
|
2451
|
+
}
|
|
2452
|
+
|
|
2453
|
+
mem->seq_add(seq_id, p0, p1, delta);
|
|
2454
|
+
}
|
|
2455
|
+
|
|
2456
|
+
void llama_memory_seq_div(
|
|
2457
|
+
llama_memory_t mem,
|
|
2458
|
+
llama_seq_id seq_id,
|
|
2459
|
+
llama_pos p0,
|
|
2460
|
+
llama_pos p1,
|
|
2461
|
+
int d) {
|
|
2462
|
+
if (!mem) {
|
|
2463
|
+
return;
|
|
2464
|
+
}
|
|
2465
|
+
|
|
2466
|
+
mem->seq_div(seq_id, p0, p1, d);
|
|
2467
|
+
}
|
|
2468
|
+
|
|
2469
|
+
llama_pos llama_memory_seq_pos_min(
|
|
2470
|
+
llama_memory_t mem,
|
|
2471
|
+
llama_seq_id seq_id) {
|
|
2472
|
+
if (!mem) {
|
|
2473
|
+
return -1;
|
|
2474
|
+
}
|
|
2475
|
+
|
|
2476
|
+
return mem->seq_pos_min(seq_id);
|
|
2477
|
+
}
|
|
2478
|
+
|
|
2479
|
+
llama_pos llama_memory_seq_pos_max(
|
|
2480
|
+
llama_memory_t mem,
|
|
2481
|
+
llama_seq_id seq_id) {
|
|
2482
|
+
if (!mem) {
|
|
2483
|
+
return -1;
|
|
2484
|
+
}
|
|
2485
|
+
|
|
2486
|
+
return mem->seq_pos_max(seq_id);
|
|
2487
|
+
}
|
|
2488
|
+
|
|
2489
|
+
bool llama_memory_can_shift(llama_memory_t mem) {
|
|
2490
|
+
if (!mem) {
|
|
2491
|
+
return false;
|
|
2492
|
+
}
|
|
2493
|
+
|
|
2494
|
+
return mem->get_can_shift();
|
|
2495
|
+
}
|
|
2496
|
+
|
|
2313
2497
|
//
|
|
2314
2498
|
// kv cache
|
|
2315
2499
|
//
|
|
2316
2500
|
|
|
2317
2501
|
// deprecated
|
|
2318
2502
|
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
|
2319
|
-
const auto * kv = ctx
|
|
2503
|
+
const auto * kv = llama_get_memory(ctx);
|
|
2320
2504
|
if (!kv) {
|
|
2321
2505
|
return 0;
|
|
2322
2506
|
}
|
|
@@ -2338,7 +2522,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
|
|
2338
2522
|
// deprecated
|
|
2339
2523
|
// note: this is the same as above - will be removed anyway, so it's ok
|
|
2340
2524
|
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
|
2341
|
-
const auto * kv = ctx
|
|
2525
|
+
const auto * kv = llama_get_memory(ctx);
|
|
2342
2526
|
if (!kv) {
|
|
2343
2527
|
return 0;
|
|
2344
2528
|
}
|
|
@@ -2357,114 +2541,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
|
|
2357
2541
|
return res;
|
|
2358
2542
|
}
|
|
2359
2543
|
|
|
2544
|
+
// deprecated
|
|
2360
2545
|
void llama_kv_self_clear(llama_context * ctx) {
|
|
2361
|
-
auto * kv = ctx
|
|
2546
|
+
auto * kv = llama_get_memory(ctx);
|
|
2362
2547
|
if (!kv) {
|
|
2363
2548
|
return;
|
|
2364
2549
|
}
|
|
2365
2550
|
|
|
2366
|
-
kv
|
|
2551
|
+
llama_memory_clear(kv, true);
|
|
2367
2552
|
}
|
|
2368
2553
|
|
|
2554
|
+
// deprecated
|
|
2369
2555
|
bool llama_kv_self_seq_rm(
|
|
2370
2556
|
llama_context * ctx,
|
|
2371
2557
|
llama_seq_id seq_id,
|
|
2372
2558
|
llama_pos p0,
|
|
2373
2559
|
llama_pos p1) {
|
|
2374
|
-
auto * kv = ctx
|
|
2560
|
+
auto * kv = llama_get_memory(ctx);
|
|
2375
2561
|
if (!kv) {
|
|
2376
2562
|
return true;
|
|
2377
2563
|
}
|
|
2378
2564
|
|
|
2379
|
-
return kv
|
|
2565
|
+
return llama_memory_seq_rm(kv, seq_id, p0, p1);
|
|
2380
2566
|
}
|
|
2381
2567
|
|
|
2568
|
+
// deprecated
|
|
2382
2569
|
void llama_kv_self_seq_cp(
|
|
2383
2570
|
llama_context * ctx,
|
|
2384
2571
|
llama_seq_id seq_id_src,
|
|
2385
2572
|
llama_seq_id seq_id_dst,
|
|
2386
2573
|
llama_pos p0,
|
|
2387
2574
|
llama_pos p1) {
|
|
2388
|
-
auto * kv = ctx
|
|
2575
|
+
auto * kv = llama_get_memory(ctx);
|
|
2389
2576
|
if (!kv) {
|
|
2390
2577
|
return;
|
|
2391
2578
|
}
|
|
2392
2579
|
|
|
2393
|
-
kv
|
|
2580
|
+
llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
|
|
2394
2581
|
}
|
|
2395
2582
|
|
|
2583
|
+
// deprecated
|
|
2396
2584
|
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
|
2397
|
-
auto * kv = ctx
|
|
2585
|
+
auto * kv = llama_get_memory(ctx);
|
|
2398
2586
|
if (!kv) {
|
|
2399
2587
|
return;
|
|
2400
2588
|
}
|
|
2401
2589
|
|
|
2402
|
-
kv
|
|
2590
|
+
llama_memory_seq_keep(kv, seq_id);
|
|
2403
2591
|
}
|
|
2404
2592
|
|
|
2593
|
+
// deprecated
|
|
2405
2594
|
void llama_kv_self_seq_add(
|
|
2406
2595
|
llama_context * ctx,
|
|
2407
2596
|
llama_seq_id seq_id,
|
|
2408
2597
|
llama_pos p0,
|
|
2409
2598
|
llama_pos p1,
|
|
2410
2599
|
llama_pos delta) {
|
|
2411
|
-
auto * kv = ctx
|
|
2600
|
+
auto * kv = llama_get_memory(ctx);
|
|
2412
2601
|
if (!kv) {
|
|
2413
2602
|
return;
|
|
2414
2603
|
}
|
|
2415
2604
|
|
|
2416
|
-
kv
|
|
2605
|
+
llama_memory_seq_add(kv, seq_id, p0, p1, delta);
|
|
2417
2606
|
}
|
|
2418
2607
|
|
|
2608
|
+
// deprecated
|
|
2419
2609
|
void llama_kv_self_seq_div(
|
|
2420
2610
|
llama_context * ctx,
|
|
2421
2611
|
llama_seq_id seq_id,
|
|
2422
2612
|
llama_pos p0,
|
|
2423
2613
|
llama_pos p1,
|
|
2424
2614
|
int d) {
|
|
2425
|
-
auto * kv = ctx
|
|
2615
|
+
auto * kv = llama_get_memory(ctx);
|
|
2426
2616
|
if (!kv) {
|
|
2427
2617
|
return;
|
|
2428
2618
|
}
|
|
2429
2619
|
|
|
2430
|
-
kv
|
|
2620
|
+
llama_memory_seq_div(kv, seq_id, p0, p1, d);
|
|
2431
2621
|
}
|
|
2432
2622
|
|
|
2623
|
+
// deprecated
|
|
2433
2624
|
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
|
|
2434
|
-
|
|
2625
|
+
auto * kv = llama_get_memory(ctx);
|
|
2435
2626
|
if (!kv) {
|
|
2436
2627
|
return -1;
|
|
2437
2628
|
}
|
|
2438
2629
|
|
|
2439
|
-
return kv
|
|
2630
|
+
return llama_memory_seq_pos_min(kv, seq_id);
|
|
2440
2631
|
}
|
|
2441
2632
|
|
|
2633
|
+
// deprecated
|
|
2442
2634
|
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
|
2443
|
-
|
|
2635
|
+
auto * kv = llama_get_memory(ctx);
|
|
2444
2636
|
if (!kv) {
|
|
2445
2637
|
return -1;
|
|
2446
2638
|
}
|
|
2447
2639
|
|
|
2448
|
-
return kv
|
|
2640
|
+
return llama_memory_seq_pos_max(kv, seq_id);
|
|
2449
2641
|
}
|
|
2450
2642
|
|
|
2643
|
+
// deprecated
|
|
2451
2644
|
void llama_kv_self_defrag(llama_context * ctx) {
|
|
2452
|
-
auto * kv = ctx->get_kv_self();
|
|
2453
|
-
if (!kv) {
|
|
2454
|
-
return;
|
|
2455
|
-
}
|
|
2456
|
-
|
|
2457
2645
|
// force defrag
|
|
2458
|
-
|
|
2646
|
+
ctx->kv_self_defrag_sched();
|
|
2459
2647
|
}
|
|
2460
2648
|
|
|
2649
|
+
// deprecated
|
|
2461
2650
|
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
|
2462
|
-
|
|
2651
|
+
auto * kv = llama_get_memory(ctx);
|
|
2463
2652
|
if (!kv) {
|
|
2464
2653
|
return false;
|
|
2465
2654
|
}
|
|
2466
2655
|
|
|
2467
|
-
return kv
|
|
2656
|
+
return llama_memory_can_shift(kv);
|
|
2468
2657
|
}
|
|
2469
2658
|
|
|
2470
2659
|
// llama state API
|
|
@@ -2589,22 +2778,8 @@ int32_t llama_encode(
|
|
|
2589
2778
|
int32_t llama_decode(
|
|
2590
2779
|
llama_context * ctx,
|
|
2591
2780
|
llama_batch batch) {
|
|
2592
|
-
int ret = ctx->decode(batch);
|
|
2593
|
-
|
|
2594
|
-
// defrag and try again
|
|
2595
|
-
// TODO: distinguish return code when we are sure that even after defrag there is no space available
|
|
2596
|
-
if (ret == 1) {
|
|
2597
|
-
llama_kv_self_defrag(ctx);
|
|
2598
|
-
ret = ctx->decode(batch);
|
|
2599
|
-
|
|
2600
|
-
if (ret == 1) {
|
|
2601
|
-
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
|
|
2602
|
-
|
|
2603
|
-
return ret;
|
|
2604
|
-
}
|
|
2605
|
-
}
|
|
2606
|
-
|
|
2607
|
-
if (ret != 0) {
|
|
2781
|
+
const int ret = ctx->decode(batch);
|
|
2782
|
+
if (ret != 0 && ret != 1) {
|
|
2608
2783
|
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
|
2609
2784
|
}
|
|
2610
2785
|
|