cui-llama.rn 1.7.4 → 1.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +217 -17
- package/android/src/main/CMakeLists.txt +34 -15
- package/android/src/main/java/com/rnllama/LlamaContext.java +79 -5
- package/android/src/main/java/com/rnllama/RNLlama.java +237 -0
- package/android/src/main/jni.cpp +213 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
- package/cpp/README.md +1 -1
- package/cpp/chat-parser.cpp +385 -0
- package/cpp/chat-parser.h +120 -0
- package/cpp/chat.cpp +726 -596
- package/cpp/chat.h +71 -6
- package/cpp/common.cpp +56 -38
- package/cpp/common.h +9 -3
- package/cpp/ggml-backend-reg.cpp +5 -0
- package/cpp/ggml-backend.cpp +10 -2
- package/cpp/ggml-common.h +4 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
- package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
- package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- package/cpp/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/ggml-cpu/common.h +4 -3
- package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
- package/cpp/ggml-cpu/ggml-cpu.c +123 -104
- package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
- package/cpp/ggml-cpu/ops.cpp +330 -148
- package/cpp/ggml-cpu/ops.h +1 -0
- package/cpp/ggml-cpu/quants.c +1158 -0
- package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/ggml-cpu/repack.cpp +1571 -0
- package/cpp/ggml-cpu/repack.h +98 -0
- package/cpp/ggml-cpu/simd-mappings.h +330 -38
- package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/ggml-cpu/vec.cpp +87 -18
- package/cpp/ggml-cpu/vec.h +249 -94
- package/cpp/ggml-cpu.h +1 -0
- package/cpp/ggml-impl.h +63 -183
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal.m +152 -45
- package/cpp/ggml-quants.c +0 -2
- package/cpp/ggml.c +61 -21
- package/cpp/ggml.h +22 -3
- package/cpp/gguf.cpp +24 -3
- package/cpp/json-partial.cpp +256 -0
- package/cpp/json-partial.h +38 -0
- package/cpp/json-schema-to-grammar.cpp +5 -47
- package/cpp/json-schema-to-grammar.h +4 -4
- package/cpp/llama-arch.cpp +153 -3
- package/cpp/llama-arch.h +27 -1
- package/cpp/llama-batch.cpp +741 -272
- package/cpp/llama-batch.h +112 -54
- package/cpp/llama-chat.cpp +30 -8
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +524 -339
- package/cpp/llama-context.h +38 -17
- package/cpp/llama-cparams.cpp +4 -0
- package/cpp/llama-cparams.h +2 -0
- package/cpp/llama-grammar.cpp +12 -2
- package/cpp/llama-graph.cpp +431 -356
- package/cpp/llama-graph.h +126 -58
- package/cpp/llama-hparams.cpp +10 -2
- package/cpp/llama-hparams.h +19 -2
- package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
- package/cpp/llama-kv-cache-unified-iswa.h +128 -0
- package/cpp/llama-kv-cache-unified.cpp +1841 -0
- package/cpp/llama-kv-cache-unified.h +303 -0
- package/cpp/llama-kv-cells.h +439 -0
- package/cpp/llama-memory-hybrid.cpp +246 -0
- package/cpp/llama-memory-hybrid.h +138 -0
- package/cpp/llama-memory-recurrent.cpp +1112 -0
- package/cpp/llama-memory-recurrent.h +183 -0
- package/cpp/llama-memory.cpp +41 -0
- package/cpp/llama-memory.h +86 -5
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +42 -17
- package/cpp/llama-model-saver.cpp +1 -0
- package/cpp/llama-model.cpp +1639 -513
- package/cpp/llama-model.h +26 -0
- package/cpp/llama-sampling.cpp +2 -2
- package/cpp/llama-vocab.cpp +65 -28
- package/cpp/llama-vocab.h +1 -0
- package/cpp/llama.cpp +11 -7
- package/cpp/llama.h +150 -42
- package/cpp/minja/chat-template.hpp +1 -1
- package/cpp/minja/minja.hpp +1 -1
- package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
- package/cpp/nlohmann/json_fwd.hpp +187 -0
- package/cpp/regex-partial.cpp +204 -0
- package/cpp/regex-partial.h +56 -0
- package/cpp/rn-llama.cpp +646 -35
- package/cpp/rn-llama.h +32 -1
- package/cpp/rn-tts.h +39 -0
- package/cpp/sampling.cpp +7 -8
- package/cpp/tools/mtmd/clip-impl.h +5 -0
- package/cpp/tools/mtmd/clip.cpp +572 -436
- package/cpp/tools/mtmd/clip.h +14 -4
- package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
- package/cpp/tools/mtmd/mtmd-audio.h +2 -17
- package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
- package/cpp/tools/mtmd/mtmd-helper.h +91 -0
- package/cpp/tools/mtmd/mtmd.cpp +368 -248
- package/cpp/tools/mtmd/mtmd.h +6 -70
- package/cpp/unicode.cpp +5 -0
- package/ios/CMakeLists.txt +26 -6
- package/ios/RNLlama.h +1 -1
- package/ios/RNLlama.mm +153 -3
- package/ios/RNLlamaContext.h +9 -1
- package/ios/RNLlamaContext.mm +112 -9
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +24 -0
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +46 -2
- package/src/index.ts +105 -1
- package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
- package/cpp/ggml-cpu/sgemm.cpp +0 -3544
- package/cpp/ggml-cpu/sgemm.h +0 -14
- package/cpp/llama-kv-cache.cpp +0 -2827
- package/cpp/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
- /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
package/cpp/llama-context.cpp
CHANGED
@@ -1,14 +1,16 @@
 #include "llama-context.h"
 
 #include "llama-impl.h"
+#include "llama-batch.h"
 #include "llama-io.h"
+#include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
-#include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -17,7 +19,8 @@
 llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
-    model(model)
+    model(model),
+    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -25,7 +28,11 @@ llama_context::llama_context(
 
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max
+    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
+    }
+
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -118,6 +125,11 @@
             __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
+    if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
+        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+    }
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -255,15 +267,9 @@
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs =
+        const uint32_t n_seqs = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
-        // restore later
-        // TODO: something cleaner
-        const auto n_outputs_save = n_outputs;
-
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
@@ -273,25 +279,18 @@
         int n_nodes_tg = -1;
 
         // simulate full KV cache
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-
-
-            // max number of outputs
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
@@ -301,16 +300,8 @@
 
         // reserve with tg graph to get the number of splits and nodes
         {
-
-
-            n_outputs = ubatch_tg.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
 
@@ -320,22 +311,12 @@
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-
-
-            n_outputs = ubatch_pp.n_tokens;
-
-            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-            if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
         }
 
-        n_outputs = n_outputs_save;
-
         for (size_t i = 0; i < backend_ptrs.size(); ++i) {
             lm_ggml_backend_t backend = backend_ptrs[i];
             lm_ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -439,46 +420,71 @@ uint32_t llama_context::n_threads_batch() const {
     return cparams.n_threads_batch;
 }
 
-
-
-    return kv_self;
-}
-
-const llama_kv_cache * llama_context::get_kv_self() const {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
+llama_memory_t llama_context::get_memory() const {
+    return memory.get();
 }
 
-
-
+// deprecated
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
 
-
+    memory_force_optimize = true;
+}
 
-
+// deprecated
+bool llama_context::kv_self_update(bool optimize) {
+    if (!memory) {
+        return false;
+    }
 
-
-
-
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
 
-
-
-
+        const auto mctx = memory->init_update(this, optimize);
+        switch (mctx->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
 
-
-
+        if (!mctx->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
+    }
 
-
-
+    // if the memory module did any computation, we have to reserve a new worst-case graph
+    {
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory context");
+        }
 
-
-
+        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-
-
-
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
         }
     }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -490,7 +496,7 @@ float * llama_context::get_logits() {
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
-
+    int64_t j = -1;
 
     try {
         if (logits == nullptr) {
@@ -513,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
         }
         if (j >= n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%
+            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
         return logits + j*model.vocab.n_tokens();
@@ -532,7 +538,7 @@ float * llama_context::get_embeddings() {
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
-
+    int64_t j = -1;
 
     try {
         if (embd == nullptr) {
@@ -555,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
         }
         if (j >= n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%
+            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
        }
 
         return embd + j*model.hparams.n_embd;
@@ -672,63 +678,95 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-
-    if (
-        LLAMA_LOG_ERROR("%s:
-
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, lm_ggml_status & ret) {
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+        ret = LM_GGML_STATUS_FAILED;
+        return nullptr;
     }
 
-
-
-
+    auto * gf = graph_init();
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+        ret = LM_GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+        ret = LM_GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
+    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    if (!lm_ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        ret = LM_GGML_STATUS_ALLOC_FAILED;
+        return nullptr;
+    }
+
+    res->set_inputs(&ubatch);
+
+    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    if (status != LM_GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+        ret = status;
+        return nullptr;
+    }
+
+    ret = LM_GGML_STATUS_SUCCESS;
 
-
-
+    return res;
+}
+
+int llama_context::encode(const llama_batch & batch_inp) {
+    LM_GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
+    if (batch_inp.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
 
     const auto & hparams = model.hparams;
 
-
+    const int64_t n_embd = hparams.n_embd;
 
-
-
-
-
-        return -1;
-    }
-}
+    // note: during encode, we always pass the full sequence starting from pos = 0
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
     }
 
+    const uint32_t n_tokens = balloc->get_n_tokens();
+
+    const llama_ubatch ubatch = balloc->split_simple(n_tokens);
+
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    LM_GGML_ASSERT(cparams.n_ubatch >=
+    LM_GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
     if (t_compute_start_us == 0) {
         t_compute_start_us = lm_ggml_time_us();
     }
 
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
     n_queued_tokens += n_tokens;
 
-    const int64_t n_embd = hparams.n_embd;
-
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-
-    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
     // reserve output buffer
     if (output_reserve(n_tokens) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
         return -2;
     };
 
-    for (
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         output_ids[i] = i;
     }
 
     n_outputs = n_tokens;
 
-    //batch_manager->prepare(ubatch);
-
     lm_ggml_backend_sched_reset(sched.get());
     lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -739,26 +777,18 @@ int llama_context::encode(llama_batch & inp_batch) {
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
     cparams.causal_attn = false;
 
-
-    auto res =
-
-    lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    res->set_inputs(&ubatch);
+    lm_ggml_status status;
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
-
-
-
-
-
-
-
-            return -2;
-        case LM_GGML_STATUS_FAILED:
-        default:
-            return -3;
+    if (!res) {
+        switch (status) {
+            case LM_GGML_STATUS_ABORTED:      return  2;
+            case LM_GGML_STATUS_ALLOC_FAILED: return -2;
+            case LM_GGML_STATUS_FAILED:       return -3;
+            case LM_GGML_STATUS_SUCCESS:      LM_GGML_ABORT("should not happen");
+        }
     }
 
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -783,31 +813,28 @@ int llama_context::encode(llama_batch & inp_batch) {
             {
                 // extract sequence embeddings
                 auto & embd_seq_out = embd_seq;
-                embd_seq_out.clear();
 
-
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                    const int32_t     seq_idx = ubatch.seq_idx[seq_id];
 
-                for (int32_t i = 0; i < n_tokens; i++) {
-                    const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                    if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                        continue;
-                    }
                     embd_seq_out[seq_id].resize(n_embd);
-                    lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
+                    lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                 }
             } break;
         case LLAMA_POOLING_TYPE_RANK:
             {
-                // extract the rerank score -
+                // extract the rerank score - n_cls_out floats per sequence
                 auto & embd_seq_out = embd_seq;
 
-
-
-
-
-
-
-
+                const uint32_t n_cls_out = hparams.n_cls_out;
+
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                    const int32_t     seq_idx = ubatch.seq_idx[seq_id];
+
+                    embd_seq_out[seq_id].resize(n_cls_out);
+                    lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                 }
             } break;
         case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -832,12 +859,16 @@ int llama_context::encode(llama_batch & inp_batch) {
         cross.v_embd.resize(cross.n_embd*cross.n_enc);
         memcpy(cross.v_embd.data(), embd, lm_ggml_nbytes(t_embd));
 
+        const auto & batch = balloc->get_batch();
+
         // remember the sequence ids used during the encoding - needed for cross attention later
         cross.seq_ids_enc.resize(n_tokens);
-        for (
+        for (uint32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();
-
-
+
+            for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
                 cross.seq_ids_enc[i].insert(seq_id);
             }
         }
@@ -846,49 +877,42 @@ int llama_context::encode(llama_batch & inp_batch) {
     return 0;
 }
 
-int llama_context::decode(llama_batch &
+int llama_context::decode(const llama_batch & batch_inp) {
+    LM_GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
     if (!memory) {
-
-        return encode(
+        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
+        return encode(batch_inp);
     }
 
-    if (
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
-    if (!inp_batch.pos) {
-        if (inp_batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return -1;
-        }
-    }
-
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd  = hparams.n_embd;
 
-
-    const
+    // when computing embeddings, all tokens are output
+    const bool output_all = cparams.embeddings;
 
-
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
 
-
+    const uint32_t n_tokens_all  = balloc->get_n_tokens();
+    const uint32_t n_outputs_all = balloc->get_n_outputs();
 
-    if (
-
-
-
-
-
+    if (output_all) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
         }
     }
 
@@ -901,49 +925,77 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;
 
-    // this
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
-
+    bool did_optimize = false;
+
+    // handle any pending defrags/shifts
+    kv_self_update(false);
+
+    llama_memory_context_ptr mctx;
 
-
-
-
-
+    while (true) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
+            return -2;
         }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
 
-
+        switch (mctx->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
+
+                    return -2;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_optimize) {
+                        did_optimize = true;
+
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
+
+                    return -2;
+                }
+        }
+
+        break;
+    }
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
        return -2;
    };
 
-    // handle any pending defrags/shifts
-    kv_self_update();
-
    int64_t n_outputs_prev = 0;
 
-
-
+    do {
+        const auto & ubatch = mctx->get_ubatch();
 
-        // count the outputs in this
+        // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;
 
             if (n_outputs_all == n_tokens_all) {
                 n_outputs_new = ubatch.n_tokens;
             } else {
-                LM_GGML_ASSERT(ubatch.output);
                 for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
                     n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                 }
@@ -953,33 +1005,40 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }
 
-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            return 1;
-        }
-
         lm_ggml_backend_sched_reset(sched.get());
         lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-
-        auto res =
+        lm_ggml_status status;
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
-
+        if (!res) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_SEQ];
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+                pos_min[s] = std::numeric_limits<llama_pos>::max();
+            }
 
-
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
 
-
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
 
-
-
-
-
-
-
-
-
-
-
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
+
+                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+                memory->seq_rm(s, pos_min[s], -1);
+            }
+
+            switch (status) {
+                case LM_GGML_STATUS_ABORTED:      return  2;
+                case LM_GGML_STATUS_ALLOC_FAILED: return -2;
+                case LM_GGML_STATUS_FAILED:       return -3;
+                case LM_GGML_STATUS_SUCCESS:      LM_GGML_ABORT("should not happen");
             }
         }
 
@@ -988,7 +1047,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         //    lm_ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        auto * t_logits =
+        auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1035,27 +1094,27 @@ int llama_context::decode(llama_batch & inp_batch) {
                     // extract sequence embeddings (cleared before processing each batch)
                     auto & embd_seq_out = embd_seq;
 
-                    for (uint32_t s = 0; s < ubatch.
-                        const llama_seq_id seq_id
-
-
-                        }
+                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                        const int32_t     seq_idx = ubatch.seq_idx[seq_id];
+
                         embd_seq_out[seq_id].resize(n_embd);
-                        lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
+                        lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score -
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
 
-
-
-
-
-
-
-
+                    const uint32_t n_cls_out = hparams.n_cls_out;
+
+                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                        const int32_t     seq_idx = ubatch.seq_idx[seq_id];
+
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1066,23 +1125,20 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
-
-    // finalize the batch processing
-    kv_guard.commit();
+    } while (mctx->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
 
     // set output mappings
-    {
+    if (n_outputs > 0) {
         bool sorted_output = true;
 
-        auto & out_ids =
+        auto & out_ids = balloc->get_out_ids();
 
-        LM_GGML_ASSERT(out_ids.size() == (size_t)
+        LM_GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
 
-        for (int64_t i = 0; i <
+        for (int64_t i = 0; i < n_outputs; ++i) {
             int64_t out_id = out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
@@ -1094,20 +1150,22 @@ int llama_context::decode(llama_batch & inp_batch) {
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
             const uint32_t n_vocab = model.vocab.n_tokens();
-            const
+            const uint64_t n_embd  = model.hparams.n_embd;
 
             LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
             // TODO: is there something more efficient which also minimizes swaps?
             // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-            for (
-
-                for (
+            for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+                uint32_t j_min = i;
+                for (uint32_t j = i + 1; j < n_outputs; ++j) {
                     if (out_ids[j] < out_ids[j_min]) {
                         j_min = j;
                     }
                 }
-                if (j_min == i) {
+                if (j_min == i) {
+                    continue;
+                }
                 std::swap(out_ids[i], out_ids[j_min]);
                 if (logits_size > 0) {
                     for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1120,8 +1178,10 @@ int llama_context::decode(llama_batch & inp_batch) {
                     }
                 }
             }
+
             std::fill(output_ids.begin(), output_ids.end(), -1);
-
+
+            for (uint32_t i = 0; i < n_outputs; ++i) {
                 output_ids[out_ids[i]] = i;
             }
         }
@@ -1130,11 +1190,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     lm_ggml_backend_sched_reset(sched.get());
@@ -1146,7 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 // output
 //
 
-
+uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
@@ -1156,9 +1211,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
-
-    bool
-    bool has_embd   = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -1212,8 +1266,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
 
-    this->n_outputs
-    this->n_outputs_max = n_outputs_max;
+    this->n_outputs = 0;
 
     return n_outputs_max;
 }
@@ -1238,11 +1291,52 @@ lm_ggml_cgraph * llama_context::graph_init() {
     return lm_ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
+lm_ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+    LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+
+    if (n_tokens % n_seqs != 0) {
+        n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
+        n_outputs = std::min(n_outputs, n_tokens);
+
+        LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
+    }
+
+    // store the n_outputs as it is, and restore it afterwards
+    // TODO: not sure if needed, might simplify in the future by removing this
+    const auto save_n_outputs = this->n_outputs;
+
+    this->n_outputs = n_outputs;
+
+    llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
+    llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
+
+    auto * gf = graph_init();
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+
+    this->n_outputs = save_n_outputs;
+
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
+        return nullptr;
+    }
+
+    lm_ggml_backend_sched_reset(sched.get());
+
+    // initialize scheduler with the specified graph
+    if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        return nullptr;
+    }
+
+    return gf;
+}
+
 llm_graph_result_ptr llama_context::graph_build(
-
-
-
-
+                      lm_ggml_context * ctx,
+                       lm_ggml_cgraph * gf,
+                  const llama_ubatch & ubatch,
+                        llm_graph_type gtype,
+        const llama_memory_context_i * mctx) {
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
@@ -1254,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
-            /*.
+            /*.mctx        =*/ mctx,
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
@@ -1663,14 +1757,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
 
     std::vector<int32_t> w_output_pos;
 
-    LM_GGML_ASSERT(n_outputs <= n_outputs_max);
-
     w_output_pos.resize(n_outputs);
 
     // build a more compact representation of the output ids
     for (size_t i = 0; i < n_batch(); ++i) {
         // map an output id to a position in the batch
-
+        int64_t pos = output_ids[i];
         if (pos >= 0) {
             LM_GGML_ASSERT(pos < n_outputs);
             w_output_pos[pos] = i;
@@ -1710,11 +1802,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
-
-
-    if (kv_self != nullptr) {
+    if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
-
+        memory->state_write(io);
     }
 
     return io.n_bytes();
@@ -1801,9 +1891,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
-
-
-        kv_self->state_read(io);
+        memory->state_read(io);
     }
 
     return io.n_bytes();
@@ -1813,9 +1901,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
|
|
1813
1901
|
LM_GGML_UNUSED(seq_id);
|
1814
1902
|
|
1815
1903
|
if (memory) {
|
1816
|
-
|
1817
|
-
|
1818
|
-
kv_self->state_write(io, seq_id);
|
1904
|
+
memory->state_write(io, seq_id);
|
1819
1905
|
}
|
1820
1906
|
|
1821
1907
|
return io.n_bytes();
|
@@ -1825,9 +1911,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
|
|
1825
1911
|
LM_GGML_UNUSED(seq_id);
|
1826
1912
|
|
1827
1913
|
if (memory) {
|
1828
|
-
|
1829
|
-
|
1830
|
-
kv_self->state_read(io, seq_id);
|
1914
|
+
memory->state_read(io, seq_id);
|
1831
1915
|
}
|
1832
1916
|
|
1833
1917
|
return io.n_bytes();
|
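With the hunks above, full-context and per-sequence state serialization is routed through the context's `memory` module instead of a raw `kv_self` pointer. The public `llama_state_*` entry points are unchanged; the following is a hedged round-trip sketch under the assumption that a valid `llama_context *` was created elsewhere (error handling elided):

```cpp
#include <cstdint>
#include <vector>

#include "llama.h"

// Sketch only: save and restore the full context state (including the memory/KV
// data written by state_write_data above) through the public llama_state_* API.
static std::vector<uint8_t> save_state(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, buf.data(), buf.size());
    buf.resize(written);
    return buf;
}

static bool load_state(llama_context * ctx, const std::vector<uint8_t> & buf) {
    // returns true only if the whole buffer was consumed
    return llama_state_set_data(ctx, buf.data(), buf.size()) == buf.size();
}
```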
@@ -1932,10 +2016,7 @@ void llama_context::opt_epoch_iter(
     const uint32_t n_batch  = std::min(this->n_batch(), n_ctx);
     const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
 
-
-
-    kv_self->clear();
-    llama_kv_cache_guard kv_guard(kv_self);
+    memory->clear(true);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -1947,39 +2028,44 @@ void llama_context::opt_epoch_iter(
             batch.logits [pos_batch] = true;
         }
 
-
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+            LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+            return;
+        }
 
-
+        const uint32_t n_tokens_all = balloc->get_n_tokens();
 
-
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+        n_queued_tokens += n_tokens_all;
 
         embd_seq.clear();
 
-
+        uint32_t n_outputs_all = n_tokens_all;
 
-
+        auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+        if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+            break;
+        }
 
         // reserve output buffer
         if (output_reserve(n_outputs_all) < n_outputs_all) {
-            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
             LM_GGML_ABORT("TODO: handle this error");
         };
 
-
-
+        uint32_t pos_batch = 0;
+        do {
+            const auto & ubatch = mctx->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-
-
-
-
-            LM_GGML_ABORT("TODO: handle this error");
+            if (!mctx->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
+                break;
            }
 
            auto * gf = graph_init();
-           auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+           auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
            struct lm_ggml_context * ctx_compute_opt;
            {
@@ -1994,6 +2080,7 @@ void llama_context::opt_epoch_iter(
            }
            lm_ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
            lm_ggml_opt_alloc(opt_ctx, train);
+
            res->set_inputs(&ubatch);
            {
                struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
@@ -2011,10 +2098,10 @@ void llama_context::opt_epoch_iter(
                callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
            }
            lm_ggml_free(ctx_compute_opt);
-        }
-        }
 
-
+            pos_batch += ubatch.n_tokens;
+        } while (mctx->next());
+    }
 }
 
 void llama_context::opt_epoch(
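The rewritten `opt_epoch_iter` no longer steps through the batch with a manual index; it asks the memory context for successive micro-batches and loops `do { ... } while (mctx->next())`. The shape of that loop, reduced to a self-contained toy where all names are hypothetical stand-ins rather than the package's types:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the memory batch context: it hands out
// ubatch-sized chunks and advances with next(), mirroring the
// do { ... } while (mctx->next()) loop introduced above.
struct chunked_batch {
    uint32_t n_tokens;
    uint32_t n_ubatch;
    uint32_t cur = 0;

    uint32_t get_ubatch() const { return std::min(n_ubatch, n_tokens - cur); }
    bool     next()             { cur += get_ubatch(); return cur < n_tokens; }
};

int main() {
    chunked_batch mctx = { /*n_tokens =*/ 10, /*n_ubatch =*/ 4 };

    uint32_t pos_batch = 0;
    do {
        const uint32_t n = mctx.get_ubatch();
        std::printf("process ubatch of %u tokens at pos %u\n", n, pos_batch);
        pos_batch += n;
    } while (mctx.next());

    return 0;
}
```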
@@ -2174,12 +2261,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
+// deprecated
 llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return ctx->
+    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
 }
 
+// deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2294,13 +2383,118 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem, bool data) {
+    if (!mem) {
+        return;
+    }
+
+    mem->clear(data);
+}
+
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1) {
+    if (!mem) {
+        return true;
+    }
+
+    return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+        llama_seq_id seq_id_src,
+        llama_seq_id seq_id_dst,
+        llama_pos p0,
+        llama_pos p1) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        llama_pos delta) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        int d) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+    if (!mem) {
+        return false;
+    }
+
+    return mem->get_can_shift();
+}
+
 //
 // kv cache
 //
 
 // deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = ctx
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
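The new `llama_memory_*` entry points added in this hunk operate on the handle returned by `llama_get_memory()` and null-check it, so they are safe to call on contexts without a KV cache. A hedged usage sketch based only on the signatures above, assuming an initialized `llama_context *`:

```cpp
#include "llama.h"

// Sketch only: per-sequence memory management through the new API.
static void reset_sequence(llama_context * ctx, llama_seq_id seq_id, llama_pos keep_until) {
    llama_memory_t mem = llama_get_memory(ctx);

    // drop everything in seq_id from position keep_until onwards
    // (p1 == -1 is assumed to mean "to the end", as with the older kv_self calls)
    llama_memory_seq_rm(mem, seq_id, keep_until, -1);
}

static void fork_sequence(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
    llama_memory_t mem = llama_get_memory(ctx);

    // copy the whole source sequence into dst, e.g. to branch a conversation
    llama_memory_seq_cp(mem, src, dst, -1, -1);
}

static void wipe_all(llama_context * ctx) {
    // data == true also clears the underlying buffers, not just the metadata
    llama_memory_clear(llama_get_memory(ctx), /*data =*/ true);
}
```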
@@ -2322,7 +2516,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
 // deprecated
 // note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = ctx
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@@ -2341,114 +2535,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     return res;
 }
 
+// deprecated
 void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv
+    llama_memory_clear(kv, true);
 }
 
+// deprecated
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
         llama_seq_id seq_id,
         llama_pos p0,
         llama_pos p1) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return true;
     }
 
-    return kv
+    return llama_memory_seq_rm(kv, seq_id, p0, p1);
 }
 
+// deprecated
 void llama_kv_self_seq_cp(
         llama_context * ctx,
         llama_seq_id seq_id_src,
         llama_seq_id seq_id_dst,
         llama_pos p0,
         llama_pos p1) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv
+    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
 }
 
+// deprecated
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv
+    llama_memory_seq_keep(kv, seq_id);
 }
 
+// deprecated
 void llama_kv_self_seq_add(
         llama_context * ctx,
         llama_seq_id seq_id,
         llama_pos p0,
         llama_pos p1,
         llama_pos delta) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv
+    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
 }
 
+// deprecated
 void llama_kv_self_seq_div(
         llama_context * ctx,
         llama_seq_id seq_id,
         llama_pos p0,
         llama_pos p1,
         int d) {
-    auto * kv = ctx
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv
+    llama_memory_seq_div(kv, seq_id, p0, p1, d);
 }
 
+// deprecated
 llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }
 
-    return kv
+    return llama_memory_seq_pos_min(kv, seq_id);
 }
 
+// deprecated
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-
+    auto * kv = llama_get_memory(ctx);
    if (!kv) {
        return -1;
    }
 
-    return kv
+    return llama_memory_seq_pos_max(kv, seq_id);
 }
 
+// deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-
+    ctx->kv_self_defrag_sched();
 }
 
+// deprecated
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return false;
     }
 
-    return kv
+    return llama_memory_can_shift(kv);
 }
 
 // llama state API
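Every `llama_kv_self_*` wrapper above is now a thin, deprecated shim that resolves the memory handle and forwards to the corresponding `llama_memory_*` call, so callers can migrate mechanically. A hedged sketch of that mapping (the commented-out lines are the deprecated forms kept by this diff):

```cpp
#include "llama.h"

// Sketch only: mechanical migration from the deprecated wrappers to the new API.
static void migrate_example(llama_context * ctx, llama_seq_id seq_id) {
    llama_memory_t mem = llama_get_memory(ctx);

    // llama_kv_self_clear(ctx);
    llama_memory_clear(mem, true);

    // llama_kv_self_seq_keep(ctx, seq_id);
    llama_memory_seq_keep(mem, seq_id);

    // llama_kv_self_seq_pos_max(ctx, seq_id);
    const llama_pos pos_max = llama_memory_seq_pos_max(mem, seq_id);
    (void) pos_max;

    // llama_kv_self_can_shift(ctx);
    const bool can_shift = llama_memory_can_shift(mem);
    (void) can_shift;
}
```

Only `llama_kv_self_defrag` has no one-to-one replacement here; in this diff it simply schedules a defrag on the context via `kv_self_defrag_sched()`.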
@@ -2573,22 +2772,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
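`llama_decode` no longer defrags and retries internally when the memory module reports that a batch does not fit (return code 1); it passes that code through and only logs for other non-zero values. Callers that relied on the automatic retry now need to react to a return of 1 themselves, for example by freeing sequence data or shrinking the batch. A hedged sketch of such a caller-side check (the retry policy shown is illustrative, not the library's):

```cpp
#include <cstdio>

#include "llama.h"

// Sketch only: caller-side handling now that llama_decode passes ret == 1 through.
static int decode_with_fallback(llama_context * ctx, llama_batch batch, llama_seq_id seq_id) {
    int ret = llama_decode(ctx, batch);

    if (ret == 1) {
        // no memory slot was found for the batch - free some space and retry once
        llama_memory_t mem = llama_get_memory(ctx);
        llama_memory_seq_rm(mem, seq_id, 0, -1);

        ret = llama_decode(ctx, batch);
    }

    if (ret != 0) {
        std::fprintf(stderr, "decode failed with ret = %d\n", ret);
    }

    return ret;
}
```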