cui-llama.rn 1.4.6 → 1.6.0
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- package/LICENSE +20 -20
- package/README.md +317 -319
- package/android/build.gradle +116 -116
- package/android/gradle.properties +5 -5
- package/android/src/main/AndroidManifest.xml +4 -4
- package/android/src/main/CMakeLists.txt +124 -117
- package/android/src/main/java/com/rnllama/LlamaContext.java +645 -645
- package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
- package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
- package/android/src/main/jni-utils.h +100 -100
- package/android/src/main/jni.cpp +1263 -1245
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
- package/cpp/README.md +4 -4
- package/cpp/binary-ops.cpp +158 -0
- package/cpp/binary-ops.h +16 -0
- package/cpp/chat.cpp +1769 -1779
- package/cpp/chat.h +9 -1
- package/cpp/common.cpp +20 -522
- package/cpp/common.h +13 -36
- package/cpp/cpu-common.h +72 -0
- package/cpp/ggml-common.h +12 -6
- package/cpp/ggml-cpu-aarch64.cpp +1557 -80
- package/cpp/ggml-cpu-impl.h +2 -21
- package/cpp/ggml-cpu-quants.c +904 -405
- package/cpp/ggml-cpu.c +909 -13237
- package/cpp/ggml-impl.h +50 -23
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +597 -523
- package/cpp/ggml-metal.m +798 -580
- package/cpp/ggml.c +92 -3
- package/cpp/ggml.h +30 -6
- package/cpp/gguf.cpp +1 -0
- package/cpp/llama-adapter.cpp +55 -20
- package/cpp/llama-adapter.h +11 -9
- package/cpp/llama-arch.cpp +217 -16
- package/cpp/llama-arch.h +25 -0
- package/cpp/llama-batch.h +2 -2
- package/cpp/llama-chat.cpp +54 -2
- package/cpp/llama-chat.h +3 -0
- package/cpp/llama-context.cpp +2294 -1238
- package/cpp/llama-context.h +214 -77
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +1695 -0
- package/cpp/llama-graph.h +592 -0
- package/cpp/llama-hparams.cpp +8 -0
- package/cpp/llama-hparams.h +17 -0
- package/cpp/llama-io.cpp +15 -0
- package/cpp/llama-io.h +35 -0
- package/cpp/llama-kv-cache.cpp +965 -303
- package/cpp/llama-kv-cache.h +145 -151
- package/cpp/llama-memory.cpp +1 -0
- package/cpp/llama-memory.h +21 -0
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +10 -5
- package/cpp/llama-model-loader.h +5 -3
- package/cpp/llama-model.cpp +9194 -201
- package/cpp/llama-model.h +40 -1
- package/cpp/llama-sampling.cpp +5 -0
- package/cpp/llama-vocab.cpp +36 -5
- package/cpp/llama.cpp +51 -9984
- package/cpp/llama.h +102 -22
- package/cpp/log.cpp +34 -0
- package/cpp/minja/chat-template.hpp +15 -7
- package/cpp/minja/minja.hpp +120 -94
- package/cpp/ops.cpp +8723 -0
- package/cpp/ops.h +128 -0
- package/cpp/rn-llama.cpp +873 -882
- package/cpp/rn-llama.h +138 -148
- package/cpp/sampling.cpp +3 -0
- package/cpp/sampling.h +107 -107
- package/cpp/sgemm.cpp +533 -88
- package/cpp/simd-mappings.h +888 -0
- package/cpp/speculative.cpp +4 -4
- package/cpp/unary-ops.cpp +186 -0
- package/cpp/unary-ops.h +28 -0
- package/cpp/unicode-data.cpp +7034 -7034
- package/cpp/unicode-data.h +20 -20
- package/cpp/unicode.cpp +849 -849
- package/cpp/unicode.h +66 -66
- package/cpp/vec.cpp +258 -0
- package/cpp/vec.h +802 -0
- package/ios/CMakeLists.txt +116 -105
- package/ios/RNLlama.h +7 -7
- package/ios/RNLlama.mm +418 -405
- package/ios/RNLlamaContext.h +57 -57
- package/ios/RNLlamaContext.mm +835 -819
- package/ios/rnllama.xcframework/Info.plist +74 -74
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +677 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2222 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +265 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +409 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1434 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
- package/{cpp → ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja}/chat-template.hpp +15 -7
- package/{cpp → ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja}/minja.hpp +120 -94
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +802 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +677 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2222 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +265 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +409 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1434 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +802 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +203 -203
- package/lib/commonjs/NativeRNLlama.js +1 -2
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/chat.js.map +1 -1
- package/lib/commonjs/grammar.js +12 -31
- package/lib/commonjs/grammar.js.map +1 -1
- package/lib/commonjs/index.js +47 -47
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/package.json +1 -0
- package/lib/module/NativeRNLlama.js +2 -0
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/chat.js +2 -0
- package/lib/module/chat.js.map +1 -1
- package/lib/module/grammar.js +14 -31
- package/lib/module/grammar.js.map +1 -1
- package/lib/module/index.js +47 -45
- package/lib/module/index.js.map +1 -1
- package/lib/module/package.json +1 -0
- package/lib/typescript/NativeRNLlama.d.ts +6 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/llama-rn.podspec +48 -48
- package/package.json +233 -233
- package/src/NativeRNLlama.ts +426 -424
- package/src/chat.ts +44 -44
- package/src/grammar.ts +854 -854
- package/src/index.ts +495 -485
package/cpp/llama-context.cpp
CHANGED
@@ -1,1404 +1,1729 @@
 #include "llama-context.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-mmap.h"
+#include "llama-model.h"
+#include "llama-kv-cache.h"
 
 #include <cassert>
-#include <cmath>
 #include <cstring>
 #include <stdexcept>
+#include <cinttypes>
 
[The remainder of this hunk was published as a side-by-side rendering whose line numbers and change markers did not survive extraction, and it breaks off partway through the file. The recoverable fragments show the old free-function helpers being removed (the K-shift and llama_set_inputs input setup, the relative-position-bucket computation, and the output-buffer code that managed lctx.logits and lctx.embd) and the llama_context class being added in their place: a constructor that fills cparams from llama_context_params, initializes the GPU, ACCEL and CPU backends, creates a unified KV cache via model.create_memory(), reserves worst-case prompt-processing and token-generation graphs through lm_ggml_backend_sched_reserve, and logs buffer sizes and graph node/split counts, followed by the defaulted destructor and llama_context::synchronize().]
n_eval++;
|
408
|
+
} else if (n_queued_tokens > 1) {
|
409
|
+
if (!cparams.no_perf) {
|
410
|
+
t_p_eval_us += lm_ggml_time_us() - t_compute_start_us;
|
576
411
|
}
|
577
|
-
|
412
|
+
n_p_eval += n_queued_tokens;
|
578
413
|
}
|
579
|
-
}
|
580
414
|
|
581
|
-
//
|
582
|
-
|
583
|
-
|
415
|
+
// get a more accurate load time, upon first eval
|
416
|
+
if (n_queued_tokens > 0 && !has_evaluated_once) {
|
417
|
+
t_load_us = lm_ggml_time_us() - t_start_us;
|
418
|
+
has_evaluated_once = true;
|
419
|
+
}
|
584
420
|
|
585
|
-
|
586
|
-
|
421
|
+
n_queued_tokens = 0;
|
422
|
+
t_compute_start_us = 0;
|
587
423
|
}
|
588
424
|
|
589
|
-
|
590
|
-
return
|
425
|
+
const llama_model & llama_context::get_model() const {
|
426
|
+
return model;
|
591
427
|
}
|
592
428
|
|
593
|
-
uint32_t
|
594
|
-
return
|
429
|
+
uint32_t llama_context::n_ctx() const {
|
430
|
+
return cparams.n_ctx;
|
595
431
|
}
|
596
432
|
|
597
|
-
uint32_t
|
598
|
-
return
|
433
|
+
uint32_t llama_context::n_ctx_per_seq() const {
|
434
|
+
return cparams.n_ctx / cparams.n_seq_max;
|
599
435
|
}
|
600
436
|
|
601
|
-
uint32_t
|
602
|
-
return
|
437
|
+
uint32_t llama_context::n_batch() const {
|
438
|
+
return cparams.n_batch;
|
603
439
|
}
|
604
440
|
|
605
|
-
|
606
|
-
return
|
441
|
+
uint32_t llama_context::n_ubatch() const {
|
442
|
+
return cparams.n_ubatch;
|
607
443
|
}
|
608
444
|
|
609
|
-
|
610
|
-
return
|
445
|
+
uint32_t llama_context::n_seq_max() const {
|
446
|
+
return cparams.n_seq_max;
|
611
447
|
}
|
612
448
|
|
613
|
-
|
614
|
-
|
615
|
-
lm_ggml_threadpool_t threadpool,
|
616
|
-
lm_ggml_threadpool_t threadpool_batch) {
|
617
|
-
ctx->threadpool = threadpool;
|
618
|
-
ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
|
449
|
+
uint32_t llama_context::n_threads() const {
|
450
|
+
return cparams.n_threads;
|
619
451
|
}
|
620
452
|
|
621
|
-
|
622
|
-
|
623
|
-
ctx->threadpool_batch = nullptr;
|
453
|
+
uint32_t llama_context::n_threads_batch() const {
|
454
|
+
return cparams.n_threads_batch;
|
624
455
|
}
|
625
456
|
|
626
|
-
|
627
|
-
|
628
|
-
ctx->cparams.n_threads_batch = n_threads_batch;
|
457
|
+
llama_kv_cache * llama_context::get_kv_self() {
|
458
|
+
return kv_self.get();
|
629
459
|
}
|
630
460
|
|
631
|
-
|
632
|
-
return
|
461
|
+
const llama_kv_cache * llama_context::get_kv_self() const {
|
462
|
+
return kv_self.get();
|
633
463
|
}
|
634
464
|
|
635
|
-
|
636
|
-
|
637
|
-
|
465
|
+
lm_ggml_tensor * llama_context::build_rope_shift(
|
466
|
+
lm_ggml_context * ctx0,
|
467
|
+
lm_ggml_tensor * cur,
|
468
|
+
lm_ggml_tensor * shift,
|
469
|
+
lm_ggml_tensor * factors,
|
470
|
+
float freq_base,
|
471
|
+
float freq_scale,
|
472
|
+
lm_ggml_backend_buffer * bbuf) const {
|
473
|
+
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
474
|
+
|
475
|
+
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
476
|
+
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
477
|
+
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
478
|
+
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
479
|
+
|
480
|
+
const auto & hparams = model.hparams;
|
481
|
+
|
482
|
+
const auto & n_rot = hparams.n_rot;
|
483
|
+
const auto & rope_type = hparams.rope_type;
|
484
|
+
|
485
|
+
lm_ggml_tensor * tmp;
|
486
|
+
|
487
|
+
if (lm_ggml_is_quantized(cur->type)) {
|
488
|
+
// dequantize to f32 -> RoPE -> quantize back
|
489
|
+
tmp = lm_ggml_cast(ctx0, cur, LM_GGML_TYPE_F32);
|
490
|
+
|
491
|
+
if (bbuf) {
|
492
|
+
for (const auto & backend : backends) {
|
493
|
+
// Figure out which backend KV cache belongs to
|
494
|
+
if (lm_ggml_backend_supports_buft(backend.get(), lm_ggml_backend_buffer_get_type(bbuf))) {
|
495
|
+
lm_ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
|
496
|
+
break;
|
497
|
+
}
|
498
|
+
}
|
499
|
+
}
|
638
500
|
|
639
|
-
|
640
|
-
|
641
|
-
|
501
|
+
tmp = lm_ggml_rope_ext_inplace(ctx0, tmp,
|
502
|
+
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
503
|
+
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
642
504
|
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
505
|
+
tmp = lm_ggml_cpy(ctx0, tmp, cur);
|
506
|
+
} else {
|
507
|
+
// we rotate only the first n_rot dimensions
|
508
|
+
tmp = lm_ggml_rope_ext_inplace(ctx0, cur,
|
509
|
+
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
510
|
+
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
649
511
|
}
|
650
|
-
}
|
651
512
|
|
652
|
-
|
653
|
-
ctx->cparams.embeddings = embeddings;
|
513
|
+
return tmp;
|
654
514
|
}
|
655
515
|
|
656
|
-
|
657
|
-
|
658
|
-
}
|
516
|
+
class llm_graph_input_k_shift : public llm_graph_input_i {
|
517
|
+
public:
|
518
|
+
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
519
|
+
virtual ~llm_graph_input_k_shift() = default;
|
659
520
|
|
660
|
-
void
|
661
|
-
lm_ggml_backend_sched_synchronize(ctx->sched.get());
|
521
|
+
void set_input(const llama_ubatch * ubatch) override;
|
662
522
|
|
663
|
-
|
664
|
-
// the stats will be added to the prompt evaluation stats
|
665
|
-
// this should only happen when using batch size 1 to evaluate a batch
|
523
|
+
lm_ggml_tensor * k_shift; // I32 [kv_size]
|
666
524
|
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
525
|
+
const llama_kv_cache_unified * kv_self;
|
526
|
+
};
|
527
|
+
|
528
|
+
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
529
|
+
LM_GGML_UNUSED(ubatch);
|
530
|
+
|
531
|
+
if (k_shift) {
|
532
|
+
assert(lm_ggml_backend_buffer_is_host(k_shift->buffer));
|
533
|
+
|
534
|
+
int32_t * data = (int32_t *) k_shift->data;
|
535
|
+
|
536
|
+
for (uint32_t i = 0; i < kv_self->size; ++i) {
|
537
|
+
data[i] = kv_self->cells[i].delta;
|
676
538
|
}
|
677
|
-
ctx->n_p_eval += ctx->n_queued_tokens;
|
678
539
|
}
|
540
|
+
}
|
679
541
|
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
}
|
542
|
+
llm_graph_result_ptr llama_context::build_kv_self_shift(
|
543
|
+
lm_ggml_context * ctx0,
|
544
|
+
lm_ggml_cgraph * gf) const {
|
545
|
+
auto res = std::make_unique<llm_graph_result>();
|
685
546
|
|
686
|
-
|
687
|
-
ctx->t_compute_start_us = 0;
|
688
|
-
}
|
547
|
+
const auto & hparams = model.hparams;
|
689
548
|
|
690
|
-
|
691
|
-
llama_synchronize(ctx);
|
549
|
+
const auto & n_layer = hparams.n_layer;
|
692
550
|
|
693
|
-
|
694
|
-
|
695
|
-
llama_output_reorder(*ctx);
|
551
|
+
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
552
|
+
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
696
553
|
|
697
|
-
|
698
|
-
}
|
554
|
+
//LM_GGML_ASSERT(kv_self->size == n_ctx);
|
699
555
|
|
700
|
-
|
701
|
-
int32_t j = -1;
|
556
|
+
auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
|
702
557
|
|
703
|
-
|
558
|
+
inp->k_shift = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, cparams.n_ctx);
|
559
|
+
lm_ggml_set_input(inp->k_shift);
|
704
560
|
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
}
|
561
|
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
562
|
+
const int64_t n_head_kv = hparams.n_head_kv(il);
|
563
|
+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
709
564
|
|
710
|
-
|
711
|
-
j = ctx->n_outputs + i;
|
712
|
-
if (j < 0) {
|
713
|
-
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
|
714
|
-
}
|
715
|
-
} else if ((size_t) i >= ctx->output_ids.size()) {
|
716
|
-
throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
|
717
|
-
} else {
|
718
|
-
j = ctx->output_ids[i];
|
719
|
-
}
|
565
|
+
const bool is_swa = hparams.is_swa(il);
|
720
566
|
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
// This should not happen
|
726
|
-
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
727
|
-
}
|
567
|
+
// note: the swa rope params could become part of the cparams in the future
|
568
|
+
// if we decide to make them configurable, like the non-sliding ones
|
569
|
+
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
|
570
|
+
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
|
728
571
|
|
729
|
-
|
730
|
-
} catch (const std::exception & err) {
|
731
|
-
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
732
|
-
#ifndef NDEBUG
|
733
|
-
LM_GGML_ABORT("fatal error");
|
734
|
-
#else
|
735
|
-
return nullptr;
|
736
|
-
#endif
|
737
|
-
}
|
738
|
-
}
|
572
|
+
lm_ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
|
739
573
|
|
740
|
-
|
741
|
-
|
574
|
+
lm_ggml_tensor * k =
|
575
|
+
lm_ggml_view_3d(ctx0, kv_self->k_l[il],
|
576
|
+
n_embd_head_k, n_head_kv, kv_self->size,
|
577
|
+
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
578
|
+
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
579
|
+
0);
|
742
580
|
|
743
|
-
|
744
|
-
// TODO: maybe deprecate this
|
745
|
-
llama_output_reorder(*ctx);
|
581
|
+
lm_ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
|
746
582
|
|
747
|
-
|
583
|
+
lm_ggml_build_forward_expand(gf, cur);
|
584
|
+
}
|
585
|
+
|
586
|
+
res->add_input(std::move(inp));
|
587
|
+
|
588
|
+
return res;
|
748
589
|
}
|
749
590
|
|
750
|
-
|
751
|
-
|
591
|
+
llm_graph_result_ptr llama_context::build_kv_self_defrag(
|
592
|
+
lm_ggml_context * ctx0,
|
593
|
+
lm_ggml_cgraph * gf) const {
|
594
|
+
auto res = std::make_unique<llm_graph_result>();
|
752
595
|
|
753
|
-
|
596
|
+
const auto & hparams = model.hparams;
|
754
597
|
|
755
|
-
|
756
|
-
if (ctx->embd == nullptr) {
|
757
|
-
throw std::runtime_error("no embeddings");
|
758
|
-
}
|
598
|
+
const auto & ids = kv_self->defrag_info.ids;
|
759
599
|
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
}
|
600
|
+
#if 0
|
601
|
+
// CPU defrag
|
602
|
+
//
|
603
|
+
// TODO: optimizations are possible:
|
604
|
+
// - multiple threads
|
605
|
+
// - avoid copying to the host memory when already there
|
606
|
+
//
|
607
|
+
// likely not worth the effort, as we have lm_ggml_graph based defrag
|
608
|
+
//
|
770
609
|
|
771
|
-
|
772
|
-
|
773
|
-
}
|
774
|
-
if (j >= ctx->n_outputs) {
|
775
|
-
// This should not happen
|
776
|
-
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
777
|
-
}
|
610
|
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
611
|
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
778
612
|
|
779
|
-
|
780
|
-
} catch (const std::exception & err) {
|
781
|
-
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
782
|
-
#ifndef NDEBUG
|
783
|
-
LM_GGML_ABORT("fatal error");
|
784
|
-
#else
|
785
|
-
return nullptr;
|
786
|
-
#endif
|
787
|
-
}
|
788
|
-
}
|
613
|
+
const uint32_t kv_size = size;
|
789
614
|
|
790
|
-
|
791
|
-
|
615
|
+
std::vector<uint8_t> buf_k;
|
616
|
+
std::vector<uint8_t> buf_v;
|
792
617
|
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
}
|
618
|
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
619
|
+
const size_t k_size_row = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
620
|
+
const size_t k_size = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
|
797
621
|
|
798
|
-
|
799
|
-
|
622
|
+
const size_t v_size_el = lm_ggml_type_size(v_l[il]->type);
|
623
|
+
const size_t v_size = lm_ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
|
800
624
|
|
801
|
-
|
625
|
+
buf_k.resize(k_size);
|
626
|
+
buf_v.resize(v_size);
|
802
627
|
|
803
|
-
|
804
|
-
|
805
|
-
return llama_state_get_size(ctx);
|
806
|
-
}
|
628
|
+
lm_ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
|
629
|
+
lm_ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
|
807
630
|
|
808
|
-
//
|
809
|
-
|
810
|
-
|
811
|
-
|
631
|
+
// batch move [i, i+nm) to [id, id+nm)
|
632
|
+
// note: cells can move only to a lower index
|
633
|
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
634
|
+
const uint32_t id = ids[i];
|
812
635
|
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
}
|
636
|
+
if (i == id || id == n_kv) {
|
637
|
+
continue;
|
638
|
+
}
|
817
639
|
|
818
|
-
|
819
|
-
bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
820
|
-
return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
821
|
-
}
|
640
|
+
uint32_t nm = 1;
|
822
641
|
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
}
|
642
|
+
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
643
|
+
nm++;
|
644
|
+
}
|
827
645
|
|
828
|
-
//
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
virtual size_t get_size_written() = 0;
|
833
|
-
virtual ~llama_data_write() = default;
|
646
|
+
// move keys
|
647
|
+
{
|
648
|
+
const int64_t os = i*k_size_row;
|
649
|
+
const int64_t od = id*k_size_row;
|
834
650
|
|
835
|
-
|
836
|
-
|
651
|
+
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
652
|
+
}
|
837
653
|
|
838
|
-
|
839
|
-
|
840
|
-
|
654
|
+
// move values (note: they are transposed)
|
655
|
+
{
|
656
|
+
const int64_t os = i;
|
657
|
+
const int64_t od = id;
|
841
658
|
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
659
|
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
660
|
+
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
661
|
+
}
|
662
|
+
}
|
663
|
+
|
664
|
+
i += nm - 1;
|
665
|
+
}
|
666
|
+
|
667
|
+
lm_ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
|
668
|
+
lm_ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
846
669
|
}
|
670
|
+
#else
|
671
|
+
for (uint32_t i = 0; i < ids.size(); ++i) {
|
672
|
+
const uint32_t id = ids[i];
|
847
673
|
|
848
|
-
|
849
|
-
|
850
|
-
|
674
|
+
if (i == id || id == ids.size()) {
|
675
|
+
continue;
|
676
|
+
}
|
851
677
|
|
852
|
-
|
678
|
+
uint32_t nm = 1;
|
853
679
|
|
854
|
-
|
855
|
-
|
680
|
+
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
681
|
+
nm++;
|
682
|
+
}
|
856
683
|
|
857
|
-
|
858
|
-
|
684
|
+
for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
|
685
|
+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
686
|
+
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
687
|
+
|
688
|
+
lm_ggml_tensor * view_k_src = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
|
689
|
+
n_embd_k_gqa, nm,
|
690
|
+
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
691
|
+
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
|
692
|
+
|
693
|
+
lm_ggml_tensor * view_k_dst = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
|
694
|
+
n_embd_k_gqa, nm,
|
695
|
+
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
696
|
+
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
|
697
|
+
|
698
|
+
lm_ggml_tensor * view_v_src;
|
699
|
+
lm_ggml_tensor * view_v_dst;
|
700
|
+
|
701
|
+
if (cparams.flash_attn) {
|
702
|
+
// NOTE: the V cache is not transposed when using flash attention
|
703
|
+
view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
|
704
|
+
n_embd_v_gqa, nm,
|
705
|
+
lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
706
|
+
lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
|
707
|
+
|
708
|
+
view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
|
709
|
+
n_embd_v_gqa, nm,
|
710
|
+
lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
711
|
+
lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
|
712
|
+
} else {
|
713
|
+
view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
|
714
|
+
nm, n_embd_v_gqa,
|
715
|
+
lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
716
|
+
lm_ggml_row_size(kv_self->v_l[il]->type, i));
|
717
|
+
|
718
|
+
view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
|
719
|
+
nm, n_embd_v_gqa,
|
720
|
+
lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
721
|
+
lm_ggml_row_size(kv_self->v_l[il]->type, id));
|
722
|
+
}
|
859
723
|
|
860
|
-
|
724
|
+
lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_k_src, view_k_dst));
|
725
|
+
lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_v_src, view_v_dst));
|
726
|
+
}
|
861
727
|
|
862
|
-
|
728
|
+
i += nm - 1;
|
729
|
+
}
|
863
730
|
|
864
|
-
|
865
|
-
|
731
|
+
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
732
|
+
#endif
|
866
733
|
|
867
|
-
|
734
|
+
return res;
|
735
|
+
}
|
868
736
|
|
869
|
-
|
737
|
+
void llama_context::kv_self_update() {
|
738
|
+
auto & kv = kv_self;
|
870
739
|
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
LM_GGML_ASSERT((uint32_t) pos < n_outputs);
|
877
|
-
output_pos[pos] = i;
|
878
|
-
}
|
740
|
+
bool need_reserve = false;
|
741
|
+
|
742
|
+
if (kv->has_shift) {
|
743
|
+
if (!kv->get_can_shift()) {
|
744
|
+
LM_GGML_ABORT("The current context does not support K-shift");
|
879
745
|
}
|
880
746
|
|
881
|
-
|
747
|
+
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
|
882
748
|
|
883
|
-
if
|
884
|
-
|
885
|
-
|
886
|
-
}
|
749
|
+
// apply K-shift if needed
|
750
|
+
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
751
|
+
lm_ggml_backend_sched_reset(sched.get());
|
887
752
|
|
888
|
-
|
889
|
-
const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
|
753
|
+
auto * gf = graph_init();
|
890
754
|
|
891
|
-
|
755
|
+
auto res = build_kv_self_shift(ctx_compute.get(), gf);
|
892
756
|
|
893
|
-
|
894
|
-
write(ctx->logits, logits_size * sizeof(float));
|
895
|
-
}
|
896
|
-
}
|
757
|
+
lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
|
897
758
|
|
898
|
-
|
899
|
-
const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
|
759
|
+
res->set_inputs(nullptr);
|
900
760
|
|
901
|
-
|
761
|
+
graph_compute(gf, false);
|
902
762
|
|
903
|
-
|
904
|
-
write(ctx->embd, embeddings_size * sizeof(float));
|
763
|
+
need_reserve = true;
|
905
764
|
}
|
906
|
-
}
|
907
|
-
|
908
|
-
void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
|
909
|
-
for (const auto & range : cell_ranges) {
|
910
|
-
for (uint32_t i = range.first; i < range.second; ++i) {
|
911
|
-
const auto & cell = kv_self.cells[i];
|
912
|
-
const llama_pos pos = cell.pos;
|
913
|
-
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
|
914
765
|
|
915
|
-
|
916
|
-
|
766
|
+
{
|
767
|
+
kv->has_shift = false;
|
917
768
|
|
918
|
-
|
919
|
-
|
920
|
-
write(&seq_id, sizeof(seq_id));
|
921
|
-
}
|
922
|
-
}
|
769
|
+
for (uint32_t i = 0; i < kv->size; ++i) {
|
770
|
+
kv->cells[i].delta = 0;
|
923
771
|
}
|
924
772
|
}
|
925
773
|
}
|
926
774
|
|
927
|
-
|
928
|
-
|
929
|
-
|
775
|
+
// defragment the KV cache if needed
|
776
|
+
if (kv->do_defrag) {
|
777
|
+
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
930
778
|
|
931
|
-
|
932
|
-
|
779
|
+
if (kv->defrag_prepare(graph_max_nodes())) {
|
780
|
+
lm_ggml_backend_sched_reset(sched.get());
|
933
781
|
|
934
|
-
|
935
|
-
write(&n_layer, sizeof(n_layer));
|
782
|
+
auto * gf = graph_init();
|
936
783
|
|
937
|
-
|
784
|
+
auto res = build_kv_self_defrag(ctx_compute.get(), gf);
|
938
785
|
|
939
|
-
|
940
|
-
// Get whole range at a time
|
941
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
942
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
786
|
+
lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
|
943
787
|
|
944
|
-
|
945
|
-
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
946
|
-
write(&k_type_i, sizeof(k_type_i));
|
788
|
+
res->set_inputs(nullptr);
|
947
789
|
|
948
|
-
|
949
|
-
const uint64_t k_size_row = lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
950
|
-
write(&k_size_row, sizeof(k_size_row));
|
790
|
+
graph_compute(gf, false);
|
951
791
|
|
952
|
-
|
953
|
-
for (const auto & range : cell_ranges) {
|
954
|
-
const size_t range_size = range.second - range.first;
|
955
|
-
const size_t buf_size = range_size * k_size_row;
|
956
|
-
write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
|
957
|
-
}
|
792
|
+
need_reserve = true;
|
958
793
|
}
|
959
794
|
|
960
|
-
|
961
|
-
|
962
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
795
|
+
kv->do_defrag = false;
|
796
|
+
}
|
963
797
|
|
964
|
-
|
965
|
-
|
966
|
-
|
798
|
+
// reserve a worst case graph if needed
|
799
|
+
if (need_reserve) {
|
800
|
+
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
|
967
801
|
|
968
|
-
|
969
|
-
|
970
|
-
|
802
|
+
// build worst-case graph
|
803
|
+
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
804
|
+
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
971
805
|
|
972
|
-
|
973
|
-
|
974
|
-
const size_t range_size = range.second - range.first;
|
975
|
-
const size_t buf_size = range_size * v_size_row;
|
976
|
-
write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
|
977
|
-
}
|
978
|
-
}
|
979
|
-
} else {
|
980
|
-
// When v is transposed, we also need the element size and get the element ranges from each row
|
981
|
-
const uint32_t kv_size = kv_self.size;
|
982
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
983
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
806
|
+
// simulate full KV cache
|
807
|
+
kv_self->n = kv_self->size;
|
984
808
|
|
985
|
-
|
986
|
-
|
987
|
-
write(&v_type_i, sizeof(v_type_i));
|
809
|
+
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
810
|
+
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
988
811
|
|
989
|
-
|
990
|
-
|
991
|
-
write(&v_size_el, sizeof(v_size_el));
|
812
|
+
auto * gf = graph_init();
|
813
|
+
graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
992
814
|
|
993
|
-
|
994
|
-
|
815
|
+
// initialize scheduler with the worst-case graph
|
816
|
+
lm_ggml_backend_sched_reset(sched.get());
|
817
|
+
if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
|
818
|
+
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
819
|
+
}
|
820
|
+
}
|
821
|
+
}
|
995
822
|
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
823
|
+
enum llama_pooling_type llama_context::pooling_type() const {
|
824
|
+
return cparams.pooling_type;
|
825
|
+
}
|
826
|
+
|
827
|
+
float * llama_context::get_logits() {
|
828
|
+
// reorder logits for backward compatibility
|
829
|
+
output_reorder();
|
830
|
+
|
831
|
+
return logits;
|
832
|
+
}
|
833
|
+
|
834
|
+
float * llama_context::get_logits_ith(int32_t i) {
|
835
|
+
int32_t j = -1;
|
836
|
+
|
837
|
+
try {
|
838
|
+
if (logits == nullptr) {
|
839
|
+
throw std::runtime_error("no logits");
|
840
|
+
}
|
841
|
+
|
842
|
+
if (i < 0) {
|
843
|
+
j = n_outputs + i;
|
844
|
+
if (j < 0) {
|
845
|
+
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
1006
846
|
}
|
847
|
+
} else if ((size_t) i >= output_ids.size()) {
|
848
|
+
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
849
|
+
} else {
|
850
|
+
j = output_ids[i];
|
851
|
+
}
|
852
|
+
|
853
|
+
if (j < 0) {
|
854
|
+
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
855
|
+
}
|
856
|
+
if (j >= n_outputs) {
|
857
|
+
// This should not happen
|
858
|
+
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
|
1007
859
|
}
|
860
|
+
|
861
|
+
return logits + j*model.vocab.n_tokens();
|
862
|
+
} catch (const std::exception & err) {
|
863
|
+
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
864
|
+
#ifndef NDEBUG
|
865
|
+
LM_GGML_ABORT("fatal error");
|
866
|
+
#else
|
867
|
+
return nullptr;
|
868
|
+
#endif
|
1008
869
|
}
|
870
|
+
}
|
1009
871
|
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
uint32_t cell_count = 0;
|
872
|
+
float * llama_context::get_embeddings() {
|
873
|
+
// reorder embeddings for backward compatibility
|
874
|
+
output_reorder();
|
1014
875
|
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
}
|
1025
|
-
} else {
|
1026
|
-
if (cell_range_begin != kv_self.size) {
|
1027
|
-
cell_ranges.emplace_back(cell_range_begin, i);
|
1028
|
-
cell_range_begin = kv_self.size;
|
1029
|
-
}
|
1030
|
-
}
|
876
|
+
return embd;
|
877
|
+
}
|
878
|
+
|
879
|
+
float * llama_context::get_embeddings_ith(int32_t i) {
|
880
|
+
int32_t j = -1;
|
881
|
+
|
882
|
+
try {
|
883
|
+
if (embd == nullptr) {
|
884
|
+
throw std::runtime_error("no embeddings");
|
1031
885
|
}
|
1032
|
-
|
1033
|
-
|
886
|
+
|
887
|
+
if (i < 0) {
|
888
|
+
j = n_outputs + i;
|
889
|
+
if (j < 0) {
|
890
|
+
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
891
|
+
}
|
892
|
+
} else if ((size_t) i >= output_ids.size()) {
|
893
|
+
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
894
|
+
} else {
|
895
|
+
j = output_ids[i];
|
1034
896
|
}
|
1035
897
|
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
898
|
+
if (j < 0) {
|
899
|
+
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
900
|
+
}
|
901
|
+
if (j >= n_outputs) {
|
902
|
+
// This should not happen
|
903
|
+
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
|
1040
904
|
}
|
1041
|
-
LM_GGML_ASSERT(cell_count == cell_count_check);
|
1042
905
|
|
1043
|
-
|
906
|
+
return embd + j*model.hparams.n_embd;
|
907
|
+
} catch (const std::exception & err) {
|
908
|
+
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
909
|
+
#ifndef NDEBUG
|
910
|
+
LM_GGML_ABORT("fatal error");
|
911
|
+
#else
|
912
|
+
return nullptr;
|
913
|
+
#endif
|
914
|
+
}
|
915
|
+
}
|
1044
916
|
|
1045
|
-
|
1046
|
-
|
917
|
+
float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
918
|
+
auto it = embd_seq.find(seq_id);
|
919
|
+
if (it == embd_seq.end()) {
|
920
|
+
return nullptr;
|
1047
921
|
}
|
1048
|
-
};
|
1049
922
|
|
1050
|
-
|
1051
|
-
|
1052
|
-
virtual void read_to(void * dst, size_t size) = 0;
|
1053
|
-
virtual size_t get_size_read() = 0;
|
1054
|
-
virtual ~llama_data_read() = default;
|
923
|
+
return it->second.data();
|
924
|
+
}
|
1055
925
|
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
926
|
+
void llama_context::attach_threadpool(
|
927
|
+
lm_ggml_threadpool_t threadpool,
|
928
|
+
lm_ggml_threadpool_t threadpool_batch) {
|
929
|
+
LLAMA_LOG_DEBUG("%s: call\n", __func__);
|
1059
930
|
|
1060
|
-
|
1061
|
-
|
931
|
+
this->threadpool = threadpool;
|
932
|
+
this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
|
933
|
+
}
|
1062
934
|
|
1063
|
-
|
1064
|
-
|
1065
|
-
const std::string cur_arch_str = llm_arch_name(ctx->model.arch);
|
935
|
+
void llama_context::detach_threadpool() {
|
936
|
+
LLAMA_LOG_DEBUG("%s: call\n", __func__);
|
1066
937
|
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
938
|
+
this->threadpool = nullptr;
|
939
|
+
this->threadpool_batch = nullptr;
|
940
|
+
}
|
941
|
+
|
942
|
+
void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) {
|
943
|
+
LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch);
|
944
|
+
|
945
|
+
cparams.n_threads = n_threads;
|
946
|
+
cparams.n_threads_batch = n_threads_batch;
|
947
|
+
}
|
948
|
+
|
949
|
+
void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) {
|
950
|
+
LLAMA_LOG_DEBUG("%s: call\n", __func__);
|
951
|
+
|
952
|
+
this->abort_callback = abort_callback;
|
953
|
+
this->abort_callback_data = abort_callback_data;
|
954
|
+
|
955
|
+
for (auto & backend : backends) {
|
956
|
+
auto * reg = lm_ggml_backend_dev_backend_reg(lm_ggml_backend_get_device(backend.get()));
|
957
|
+
auto * set_abort_callback_fn = (lm_ggml_backend_set_abort_callback_t) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_set_abort_callback");
|
958
|
+
if (set_abort_callback_fn) {
|
959
|
+
set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
|
1071
960
|
}
|
1072
|
-
// TODO: add more info which needs to be identical but which is not verified otherwise
|
1073
961
|
}
|
962
|
+
}
|
963
|
+
|
964
|
+
void llama_context::set_embeddings(bool value) {
|
965
|
+
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
1074
966
|
|
1075
|
-
|
1076
|
-
|
1077
|
-
// read_string(rng_str);
|
967
|
+
cparams.embeddings = value;
|
968
|
+
}
|
1078
969
|
|
1079
|
-
|
1080
|
-
|
970
|
+
void llama_context::set_causal_attn(bool value) {
|
971
|
+
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
1081
972
|
|
1082
|
-
|
1083
|
-
|
1084
|
-
// }
|
1085
|
-
//}
|
973
|
+
cparams.causal_attn = value;
|
974
|
+
}
|
1086
975
|
|
1087
|
-
|
1088
|
-
|
976
|
+
void llama_context::set_warmup(bool value) {
|
977
|
+
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
1089
978
|
|
1090
|
-
|
1091
|
-
|
979
|
+
cparams.warmup = value;
|
980
|
+
}
|
1092
981
|
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
982
|
+
void llama_context::set_adapter_lora(
|
983
|
+
llama_adapter_lora * adapter,
|
984
|
+
float scale) {
|
985
|
+
LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
|
1096
986
|
|
1097
|
-
|
1098
|
-
|
1099
|
-
read_to(output_pos.data(), n_outputs * sizeof(int32_t));
|
987
|
+
loras[adapter] = scale;
|
988
|
+
}
|
1100
989
|
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
|
1105
|
-
}
|
1106
|
-
ctx->output_ids[id] = i;
|
1107
|
-
}
|
990
|
+
bool llama_context::rm_adapter_lora(
|
991
|
+
llama_adapter_lora * adapter) {
|
992
|
+
LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
|
1108
993
|
|
1109
|
-
|
1110
|
-
|
994
|
+
auto pos = loras.find(adapter);
|
995
|
+
if (pos != loras.end()) {
|
996
|
+
loras.erase(pos);
|
997
|
+
return true;
|
1111
998
|
}
|
1112
999
|
|
1113
|
-
|
1114
|
-
|
1115
|
-
read_to(&logits_size, sizeof(logits_size));
|
1000
|
+
return false;
|
1001
|
+
}
|
1116
1002
|
|
1117
|
-
|
1118
|
-
|
1119
|
-
}
|
1003
|
+
void llama_context::clear_adapter_lora() {
|
1004
|
+
LLAMA_LOG_DEBUG("%s: call\n", __func__);
|
1120
1005
|
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1006
|
+
loras.clear();
|
1007
|
+
}
|
1008
|
+
|
1009
|
+
bool llama_context::apply_adapter_cvec(
|
1010
|
+
const float * data,
|
1011
|
+
size_t len,
|
1012
|
+
int32_t n_embd,
|
1013
|
+
int32_t il_start,
|
1014
|
+
int32_t il_end) {
|
1015
|
+
LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
|
1016
|
+
|
1017
|
+
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
1018
|
+
}
|
1019
|
+
|
1020
|
+
int llama_context::encode(llama_batch & inp_batch) {
|
1021
|
+
if (inp_batch.n_tokens == 0) {
|
1022
|
+
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
1023
|
+
return -1;
|
1124
1024
|
}
|
1125
1025
|
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1026
|
+
// temporary allocate memory for the input batch if needed
|
1027
|
+
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
1028
|
+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
1129
1029
|
|
1130
|
-
|
1131
|
-
|
1132
|
-
}
|
1030
|
+
const llama_batch & batch = batch_allocr.batch;
|
1031
|
+
const int32_t n_tokens = batch.n_tokens;
|
1133
1032
|
|
1134
|
-
|
1135
|
-
|
1033
|
+
const auto & hparams = model.hparams;
|
1034
|
+
|
1035
|
+
LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
1036
|
+
|
1037
|
+
if (batch.token) {
|
1038
|
+
for (int32_t i = 0; i < n_tokens; ++i) {
|
1039
|
+
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
1040
|
+
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
1041
|
+
return -1;
|
1042
|
+
}
|
1136
1043
|
}
|
1137
1044
|
}
|
1138
1045
|
|
1139
|
-
|
1140
|
-
|
1046
|
+
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
1047
|
+
LM_GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
|
1048
|
+
|
1049
|
+
if (t_compute_start_us == 0) {
|
1050
|
+
t_compute_start_us = lm_ggml_time_us();
|
1051
|
+
}
|
1052
|
+
|
1053
|
+
n_queued_tokens += n_tokens;
|
1141
1054
|
|
1142
|
-
|
1143
|
-
// single sequence
|
1055
|
+
const int64_t n_embd = hparams.n_embd;
|
1144
1056
|
|
1145
|
-
|
1057
|
+
sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
1146
1058
|
|
1147
|
-
|
1148
|
-
batch.n_tokens = cell_count;
|
1149
|
-
batch.n_seq_tokens = cell_count;
|
1150
|
-
batch.n_seqs = 1;
|
1059
|
+
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
1151
1060
|
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1061
|
+
// reserve output buffer
|
1062
|
+
if (output_reserve(n_tokens) < n_tokens) {
|
1063
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
1064
|
+
return -2;
|
1065
|
+
};
|
1155
1066
|
|
1156
|
-
|
1157
|
-
|
1067
|
+
for (int32_t i = 0; i < n_tokens; ++i) {
|
1068
|
+
output_ids[i] = i;
|
1069
|
+
}
|
1158
1070
|
|
1159
|
-
|
1160
|
-
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
1161
|
-
return false;
|
1162
|
-
}
|
1071
|
+
n_outputs = n_tokens;
|
1163
1072
|
|
1164
|
-
|
1165
|
-
}
|
1166
|
-
batch.n_seq_id[0] = 1;
|
1167
|
-
batch.seq_id[0] = &dest_seq_id;
|
1168
|
-
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
1169
|
-
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
1170
|
-
return false;
|
1171
|
-
}
|
1073
|
+
//batch_manager->prepare(ubatch);
|
1172
1074
|
|
1173
|
-
|
1174
|
-
|
1175
|
-
LM_GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
|
1176
|
-
LM_GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
|
1177
|
-
LM_GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
1178
|
-
LM_GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
1179
|
-
LM_GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
1180
|
-
} else {
|
1181
|
-
// whole KV cache restore
|
1075
|
+
lm_ggml_backend_sched_reset(sched.get());
|
1076
|
+
lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
1182
1077
|
|
1183
|
-
|
1184
|
-
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
1185
|
-
return false;
|
1186
|
-
}
|
1078
|
+
const auto causal_attn_org = cparams.causal_attn;
|
1187
1079
|
|
1188
|
-
|
1080
|
+
// always use non-causal attention for encoder graphs
|
1081
|
+
// TODO: this is a tmp solution until we have a proper way to support enc-dec models
|
1082
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
1083
|
+
cparams.causal_attn = false;
|
1189
1084
|
|
1190
|
-
|
1191
|
-
|
1085
|
+
auto * gf = graph_init();
|
1086
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
|
1192
1087
|
|
1193
|
-
|
1194
|
-
uint32_t n_seq_id;
|
1088
|
+
lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
|
1195
1089
|
|
1196
|
-
|
1197
|
-
read_to(&n_seq_id, sizeof(n_seq_id));
|
1090
|
+
res->set_inputs(&ubatch);
|
1198
1091
|
|
1199
|
-
|
1092
|
+
cparams.causal_attn = causal_attn_org;
|
1200
1093
|
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1094
|
+
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
1095
|
+
switch (compute_status) {
|
1096
|
+
case LM_GGML_STATUS_SUCCESS:
|
1097
|
+
break;
|
1098
|
+
case LM_GGML_STATUS_ABORTED:
|
1099
|
+
return 2;
|
1100
|
+
case LM_GGML_STATUS_ALLOC_FAILED:
|
1101
|
+
return -2;
|
1102
|
+
case LM_GGML_STATUS_FAILED:
|
1103
|
+
default:
|
1104
|
+
return -3;
|
1105
|
+
}
|
1204
1106
|
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1107
|
+
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
1108
|
+
|
1109
|
+
// extract embeddings
|
1110
|
+
if (t_embd) {
|
1111
|
+
lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
1112
|
+
LM_GGML_ASSERT(backend_embd != nullptr);
|
1113
|
+
|
1114
|
+
LM_GGML_ASSERT(embd != nullptr);
|
1209
1115
|
|
1210
|
-
|
1116
|
+
switch (cparams.pooling_type) {
|
1117
|
+
case LLAMA_POOLING_TYPE_NONE:
|
1118
|
+
{
|
1119
|
+
// extract token embeddings
|
1120
|
+
LM_GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
|
1121
|
+
lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
|
1122
|
+
} break;
|
1123
|
+
case LLAMA_POOLING_TYPE_MEAN:
|
1124
|
+
case LLAMA_POOLING_TYPE_CLS:
|
1125
|
+
case LLAMA_POOLING_TYPE_LAST:
|
1126
|
+
{
|
1127
|
+
// extract sequence embeddings
|
1128
|
+
auto & embd_seq_out = embd_seq;
|
1129
|
+
embd_seq_out.clear();
|
1211
1130
|
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1131
|
+
LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
1132
|
+
|
1133
|
+
for (int32_t i = 0; i < n_tokens; i++) {
|
1134
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
1135
|
+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
1136
|
+
continue;
|
1217
1137
|
}
|
1218
|
-
|
1138
|
+
embd_seq_out[seq_id].resize(n_embd);
|
1139
|
+
lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
1219
1140
|
}
|
1141
|
+
} break;
|
1142
|
+
case LLAMA_POOLING_TYPE_RANK:
|
1143
|
+
{
|
1144
|
+
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
|
1145
|
+
// wait for an encoder model that requires this pooling type in order to test it
|
1146
|
+
// https://github.com/ggerganov/llama.cpp/pull/9510
|
1147
|
+
LM_GGML_ABORT("RANK pooling not implemented yet");
|
1148
|
+
}
|
1149
|
+
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
1150
|
+
{
|
1151
|
+
LM_GGML_ABORT("unknown pooling type");
|
1220
1152
|
}
|
1221
|
-
}
|
1222
|
-
|
1223
|
-
kv_self.head = 0;
|
1224
|
-
kv_self.used = cell_count;
|
1225
1153
|
}
|
1154
|
+
}
|
1155
|
+
|
1156
|
+
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
1157
|
+
// overlap with device computation.
|
1158
|
+
lm_ggml_backend_sched_reset(sched.get());
|
1159
|
+
|
1160
|
+
// TODO: hacky solution
|
1161
|
+
if (model.arch == LLM_ARCH_T5 && t_embd) {
|
1162
|
+
//cross.t_embd = t_embd;
|
1163
|
+
|
1164
|
+
synchronize();
|
1226
1165
|
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1166
|
+
cross.n_embd = t_embd->ne[0];
|
1167
|
+
cross.n_enc = t_embd->ne[1];
|
1168
|
+
cross.v_embd.resize(cross.n_embd*cross.n_enc);
|
1169
|
+
memcpy(cross.v_embd.data(), embd, lm_ggml_nbytes(t_embd));
|
1170
|
+
|
1171
|
+
// remember the sequence ids used during the encoding - needed for cross attention later
|
1172
|
+
cross.seq_ids_enc.resize(n_tokens);
|
1173
|
+
for (int32_t i = 0; i < n_tokens; i++) {
|
1174
|
+
cross.seq_ids_enc[i].clear();
|
1175
|
+
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
|
1176
|
+
llama_seq_id seq_id = ubatch.seq_id[i][s];
|
1177
|
+
cross.seq_ids_enc[i].insert(seq_id);
|
1232
1178
|
}
|
1233
1179
|
}
|
1180
|
+
}
|
1234
1181
|
|
1235
|
-
|
1182
|
+
return 0;
|
1183
|
+
}
|
1184
|
+
|
1185
|
+
int llama_context::decode(llama_batch & inp_batch) {
|
1186
|
+
if (inp_batch.n_tokens == 0) {
|
1187
|
+
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
1188
|
+
return -1;
|
1236
1189
|
}
|
1237
1190
|
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
uint32_t v_trans;
|
1242
|
-
uint32_t n_layer;
|
1243
|
-
read_to(&v_trans, sizeof(v_trans));
|
1244
|
-
read_to(&n_layer, sizeof(n_layer));
|
1191
|
+
// temporary allocate memory for the input batch if needed
|
1192
|
+
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
1193
|
+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
1245
1194
|
|
1246
|
-
|
1247
|
-
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
1248
|
-
return false;
|
1249
|
-
}
|
1250
|
-
if (cell_count > kv_self.size) {
|
1251
|
-
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
|
1252
|
-
return false;
|
1253
|
-
}
|
1254
|
-
if (kv_self.v_trans != (bool) v_trans) {
|
1255
|
-
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
1256
|
-
return false;
|
1257
|
-
}
|
1195
|
+
const llama_batch & batch = batch_allocr.batch;
|
1258
1196
|
|
1259
|
-
|
1260
|
-
|
1261
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
1197
|
+
const auto & vocab = model.vocab;
|
1198
|
+
const auto & hparams = model.hparams;
|
1262
1199
|
|
1263
|
-
|
1264
|
-
int32_t k_type_i_ref;
|
1265
|
-
read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
1266
|
-
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
1267
|
-
if (k_type_i != k_type_i_ref) {
|
1268
|
-
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
1269
|
-
return false;
|
1270
|
-
}
|
1200
|
+
const int32_t n_vocab = vocab.n_tokens();
|
1271
1201
|
|
1272
|
-
|
1273
|
-
|
1274
|
-
|
1275
|
-
|
1276
|
-
if (k_size_row != k_size_row_ref) {
|
1277
|
-
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
1278
|
-
return false;
|
1279
|
-
}
|
1202
|
+
const int64_t n_tokens_all = batch.n_tokens;
|
1203
|
+
const int64_t n_embd = hparams.n_embd;
|
1204
|
+
|
1205
|
+
llama_kv_cache_guard kv_guard(kv_self.get());
|
1280
1206
|
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1207
|
+
LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
1208
|
+
|
1209
|
+
if (batch.token) {
|
1210
|
+
for (int64_t i = 0; i < n_tokens_all; ++i) {
|
1211
|
+
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
1212
|
+
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
|
1213
|
+
throw std::runtime_error("invalid token");
|
1284
1214
|
}
|
1285
1215
|
}
|
1216
|
+
}
|
1286
1217
|
|
1287
|
-
|
1288
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
1289
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
1218
|
+
LM_GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
1290
1219
|
|
1291
|
-
|
1292
|
-
int32_t v_type_i_ref;
|
1293
|
-
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
1294
|
-
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
1295
|
-
if (v_type_i != v_type_i_ref) {
|
1296
|
-
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
1297
|
-
return false;
|
1298
|
-
}
|
1220
|
+
LM_GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
1299
1221
|
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
if (v_size_row != v_size_row_ref) {
|
1305
|
-
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
1306
|
-
return false;
|
1307
|
-
}
|
1222
|
+
if (t_compute_start_us == 0) {
|
1223
|
+
t_compute_start_us = lm_ggml_time_us();
|
1224
|
+
}
|
1225
|
+
n_queued_tokens += n_tokens_all;
|
1308
1226
|
|
1309
|
-
|
1310
|
-
|
1311
|
-
lm_ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
|
1312
|
-
}
|
1313
|
-
}
|
1314
|
-
} else {
|
1315
|
-
// For each layer, read the values for each cell (transposed)
|
1316
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
1317
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
1318
|
-
|
1319
|
-
// Read type of value
|
1320
|
-
int32_t v_type_i_ref;
|
1321
|
-
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
1322
|
-
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
1323
|
-
if (v_type_i != v_type_i_ref) {
|
1324
|
-
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
1325
|
-
return false;
|
1326
|
-
}
|
1227
|
+
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
1228
|
+
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
1327
1229
|
|
1328
|
-
|
1329
|
-
uint32_t v_size_el_ref;
|
1330
|
-
read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
1331
|
-
const size_t v_size_el = lm_ggml_type_size(kv_self.v_l[il]->type);
|
1332
|
-
if (v_size_el != v_size_el_ref) {
|
1333
|
-
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
1334
|
-
return false;
|
1335
|
-
}
|
1230
|
+
embd_seq.clear();
|
1336
1231
|
|
1337
|
-
|
1338
|
-
uint32_t n_embd_v_gqa_ref;
|
1339
|
-
read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
1340
|
-
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
1341
|
-
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
1342
|
-
return false;
|
1343
|
-
}
|
1232
|
+
int64_t n_outputs_all = 0;
|
1344
1233
|
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
lm_ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
1350
|
-
}
|
1351
|
-
}
|
1352
|
-
}
|
1234
|
+
// count outputs
|
1235
|
+
if (batch.logits && !embd_pooled) {
|
1236
|
+
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
1237
|
+
n_outputs_all += batch.logits[i] != 0;
|
1353
1238
|
}
|
1354
|
-
|
1239
|
+
} else if (logits_all || embd_pooled) {
|
1240
|
+
n_outputs_all = n_tokens_all;
|
1241
|
+
} else {
|
1242
|
+
// keep last output only
|
1243
|
+
n_outputs_all = 1;
|
1355
1244
|
}
|
1356
1245
|
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1246
|
+
const bool logits_all = n_outputs_all == n_tokens_all;
|
1247
|
+
|
1248
|
+
sbatch.from_batch(batch, n_embd,
|
1249
|
+
/* simple_split */ !kv_self->recurrent,
|
1250
|
+
/* logits_all */ logits_all);
|
1251
|
+
|
1252
|
+
// reserve output buffer
|
1253
|
+
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
1254
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
|
1255
|
+
return -2;
|
1256
|
+
};
|
1360
1257
|
|
1361
|
-
|
1258
|
+
// handle any pending defrags/shifts
|
1259
|
+
kv_self_update();
|
1362
1260
|
|
1363
|
-
|
1364
|
-
|
1365
|
-
|
1261
|
+
int64_t n_outputs_prev = 0;
|
1262
|
+
|
1263
|
+
while (sbatch.n_tokens > 0) {
|
1264
|
+
llama_ubatch ubatch = llama_ubatch();
|
1265
|
+
|
1266
|
+
const auto & n_ubatch = cparams.n_ubatch;
|
1267
|
+
|
1268
|
+
if (kv_self->recurrent) {
|
1269
|
+
if (embd_pooled) {
|
1270
|
+
// Pooled embeddings cannot be split across ubatches (yet)
|
1271
|
+
ubatch = sbatch.split_seq(cparams.n_ubatch);
|
1366
1272
|
} else {
|
1367
|
-
|
1273
|
+
// recurrent model architectures are easier to implement
|
1274
|
+
// with equal-length sequences
|
1275
|
+
ubatch = sbatch.split_equal(cparams.n_ubatch);
|
1368
1276
|
}
|
1369
|
-
|
1277
|
+
} else {
|
1278
|
+
ubatch = sbatch.split_simple(n_ubatch);
|
1370
1279
|
}
|
1371
|
-
}
|
1372
|
-
};
|
1373
1280
|
|
1374
|
-
|
1375
|
-
|
1281
|
+
// count the outputs in this u_batch
|
1282
|
+
{
|
1283
|
+
int32_t n_outputs_new = 0;
|
1376
1284
|
|
1377
|
-
|
1285
|
+
if (n_outputs_all == n_tokens_all) {
|
1286
|
+
n_outputs_new = ubatch.n_tokens;
|
1287
|
+
} else {
|
1288
|
+
LM_GGML_ASSERT(ubatch.output);
|
1289
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
1290
|
+
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
1291
|
+
}
|
1292
|
+
}
|
1378
1293
|
|
1379
|
-
|
1380
|
-
|
1381
|
-
|
1294
|
+
// needs to happen before the graph is built
|
1295
|
+
n_outputs = n_outputs_new;
|
1296
|
+
}
|
1382
1297
|
|
1383
|
-
|
1384
|
-
|
1385
|
-
|
1298
|
+
// find KV slot
|
1299
|
+
{
|
1300
|
+
if (!kv_self->find_slot(ubatch)) {
|
1301
|
+
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
1386
1302
|
|
1387
|
-
|
1388
|
-
|
1389
|
-
}
|
1390
|
-
};
|
1303
|
+
return 1;
|
1304
|
+
}
|
1391
1305
|
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1395
|
-
|
1306
|
+
if (!kv_self->recurrent) {
|
1307
|
+
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
1308
|
+
// after enough generations, the benefit from this heuristic disappears
|
1309
|
+
// if we start defragmenting the cache, the benefit from this will be more important
|
1310
|
+
const uint32_t pad = kv_self->get_padding(cparams);
|
1311
|
+
kv_self->n = std::min(kv_self->size, std::max(pad, LM_GGML_PAD(kv_self->cell_max(), pad)));
|
1312
|
+
}
|
1313
|
+
}
|
1396
1314
|
|
1397
|
-
|
1315
|
+
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
|
1398
1316
|
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1317
|
+
lm_ggml_backend_sched_reset(sched.get());
|
1318
|
+
lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
1319
|
+
|
1320
|
+
auto * gf = graph_init();
|
1321
|
+
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
|
1322
|
+
|
1323
|
+
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
1324
|
+
|
1325
|
+
lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
|
1326
|
+
|
1327
|
+
res->set_inputs(&ubatch);
|
1328
|
+
|
1329
|
+
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
1330
|
+
if (compute_status != LM_GGML_STATUS_SUCCESS) {
|
1331
|
+
switch (compute_status) {
|
1332
|
+
case LM_GGML_STATUS_ABORTED:
|
1333
|
+
return 2;
|
1334
|
+
case LM_GGML_STATUS_ALLOC_FAILED:
|
1335
|
+
return -2;
|
1336
|
+
case LM_GGML_STATUS_FAILED:
|
1337
|
+
default:
|
1338
|
+
return -3;
|
1339
|
+
}
|
1340
|
+
}
|
1341
|
+
|
1342
|
+
// plot the computation graph in dot format (for debugging purposes)
|
1343
|
+
//if (n_past%100 == 0) {
|
1344
|
+
// lm_ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
1345
|
+
//}
|
1346
|
+
|
1347
|
+
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
|
1348
|
+
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
1349
|
+
|
1350
|
+
if (t_embd && res->get_embd_pooled()) {
|
1351
|
+
t_embd = res->get_embd_pooled();
|
1352
|
+
}
|
1353
|
+
|
1354
|
+
// extract logits
|
1355
|
+
if (t_logits && n_outputs > 0) {
|
1356
|
+
lm_ggml_backend_t backend_res = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
1357
|
+
LM_GGML_ASSERT(backend_res != nullptr);
|
1358
|
+
LM_GGML_ASSERT(logits != nullptr);
|
1359
|
+
|
1360
|
+
float * logits_out = logits + n_outputs_prev*n_vocab;
|
1361
|
+
|
1362
|
+
if (n_outputs) {
|
1363
|
+
LM_GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
1364
|
+
LM_GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
|
1365
|
+
lm_ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
1366
|
+
}
|
1367
|
+
}
|
1368
|
+
|
1369
|
+
// extract embeddings
|
1370
|
+
if (t_embd && n_outputs > 0) {
|
1371
|
+
lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
1372
|
+
LM_GGML_ASSERT(backend_embd != nullptr);
|
1373
|
+
|
1374
|
+
switch (cparams.pooling_type) {
|
1375
|
+
case LLAMA_POOLING_TYPE_NONE:
|
1376
|
+
{
|
1377
|
+
// extract token embeddings
|
1378
|
+
LM_GGML_ASSERT(embd != nullptr);
|
1379
|
+
float * embd_out = embd + n_outputs_prev*n_embd;
|
1380
|
+
|
1381
|
+
if (n_outputs) {
|
1382
|
+
LM_GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
1383
|
+
LM_GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
|
1384
|
+
lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
|
1385
|
+
}
|
1386
|
+
} break;
|
1387
|
+
case LLAMA_POOLING_TYPE_MEAN:
|
1388
|
+
case LLAMA_POOLING_TYPE_CLS:
|
1389
|
+
case LLAMA_POOLING_TYPE_LAST:
|
1390
|
+
{
|
1391
|
+
// extract sequence embeddings (cleared before processing each batch)
|
1392
|
+
auto & embd_seq_out = embd_seq;
|
1393
|
+
|
1394
|
+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
1395
|
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
1396
|
+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
1397
|
+
continue;
|
1398
|
+
}
|
1399
|
+
embd_seq_out[seq_id].resize(n_embd);
|
1400
|
+
lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
1401
|
+
}
|
1402
|
+
} break;
|
1403
|
+
case LLAMA_POOLING_TYPE_RANK:
|
1404
|
+
{
|
1405
|
+
// extract the rerank score - a single float per sequence
|
1406
|
+
auto & embd_seq_out = embd_seq;
|
1407
|
+
|
1408
|
+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
1409
|
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
1410
|
+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
1411
|
+
continue;
|
1412
|
+
}
|
1413
|
+
embd_seq_out[seq_id].resize(1);
|
1414
|
+
lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
1415
|
+
}
|
1416
|
+
} break;
|
1417
|
+
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
1418
|
+
{
|
1419
|
+
LM_GGML_ABORT("unknown pooling type");
|
1420
|
+
}
|
1421
|
+
}
|
1422
|
+
}
|
1423
|
+
|
1424
|
+
n_outputs_prev += n_outputs;
|
1425
|
+
}
|
1426
|
+
|
1427
|
+
// finalize the batch processing
|
1428
|
+
kv_guard.commit();
|
1429
|
+
|
1430
|
+
// set output mappings
|
1431
|
+
{
|
1432
|
+
bool sorted_output = true;
|
1433
|
+
|
1434
|
+
LM_GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
|
1435
|
+
|
1436
|
+
for (int64_t i = 0; i < n_outputs_all; ++i) {
|
1437
|
+
int64_t out_id = sbatch.out_ids[i];
|
1438
|
+
output_ids[out_id] = i;
|
1439
|
+
if (out_id != i) {
|
1440
|
+
sorted_output = false;
|
1441
|
+
}
|
1442
|
+
}
|
1443
|
+
|
1444
|
+
if (sorted_output) {
|
1445
|
+
sbatch.out_ids.clear();
|
1446
|
+
}
|
1447
|
+
}
|
1448
|
+
|
1449
|
+
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
1450
|
+
n_outputs = n_outputs_all;
|
1451
|
+
|
1452
|
+
// wait for the computation to finish (automatically done when obtaining the model output)
|
1453
|
+
//synchronize();
|
1454
|
+
|
1455
|
+
// decide if we need to defrag the kv cache
|
1456
|
+
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
|
1457
|
+
// - do not defrag small contexts (i.e. < 2048 tokens)
|
1458
|
+
// - count the padding towards the number of used tokens
|
1459
|
+
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
|
1460
|
+
|
1461
|
+
// queue defragmentation for next llama_kv_cache_update
|
1462
|
+
if (fragmentation > cparams.defrag_thold) {
|
1463
|
+
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
1464
|
+
|
1465
|
+
kv_self->defrag();
|
1466
|
+
}
|
1467
|
+
}
|
1468
|
+
|
1469
|
+
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
1470
|
+
// overlap with device computation.
|
1471
|
+
lm_ggml_backend_sched_reset(sched.get());
|
1472
|
+
|
1473
|
+
return 0;
|
1474
|
+
}
|
1475
|
+
|
1476
|
+
//
|
1477
|
+
// output
|
1478
|
+
//
|
1479
|
+
|
1480
|
+
int32_t llama_context::output_reserve(int32_t n_outputs) {
|
1481
|
+
const auto & hparams = model.hparams;
|
1482
|
+
const auto & vocab = model.vocab;
|
1483
|
+
|
1484
|
+
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
|
1485
|
+
|
1486
|
+
const auto n_batch = cparams.n_batch;
|
1487
|
+
const auto n_vocab = vocab.n_tokens();
|
1488
|
+
const auto n_embd = hparams.n_embd;
|
1489
|
+
|
1490
|
+
// TODO: use a per-batch flag for logits presence instead
|
1491
|
+
bool has_logits = !cparams.embeddings;
|
1492
|
+
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
1493
|
+
|
1494
|
+
// TODO: hacky enc-dec support
|
1495
|
+
if (model.arch == LLM_ARCH_T5) {
|
1496
|
+
has_logits = true;
|
1497
|
+
has_embd = true;
|
1498
|
+
}
|
1499
|
+
|
1500
|
+
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
1501
|
+
embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
1502
|
+
|
1503
|
+
if (output_ids.empty()) {
|
1504
|
+
// init, never resized afterwards
|
1505
|
+
output_ids.resize(n_batch);
|
1506
|
+
}
|
1507
|
+
|
1508
|
+
const size_t prev_size = buf_output ? lm_ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
1509
|
+
const size_t new_size = (logits_size + embd_size) * sizeof(float);
|
1510
|
+
|
1511
|
+
// alloc only when more than the current capacity is required
|
1512
|
+
// TODO: also consider shrinking the buffer
|
1513
|
+
if (!buf_output || prev_size < new_size) {
|
1514
|
+
if (buf_output) {
|
1515
|
+
#ifndef NDEBUG
|
1516
|
+
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
1517
|
+
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
1518
|
+
#endif
|
1519
|
+
buf_output = nullptr;
|
1520
|
+
logits = nullptr;
|
1521
|
+
embd = nullptr;
|
1522
|
+
}
|
1523
|
+
|
1524
|
+
auto * buft = lm_ggml_backend_cpu_buffer_type();
|
1525
|
+
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
|
1526
|
+
auto * output_dev = model.dev_output();
|
1527
|
+
auto * output_dev_host_buft = output_dev ? lm_ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
|
1528
|
+
if (output_dev_host_buft) {
|
1529
|
+
buft = output_dev_host_buft;
|
1530
|
+
}
|
1531
|
+
buf_output.reset(lm_ggml_backend_buft_alloc_buffer(buft, new_size));
|
1532
|
+
if (buf_output == nullptr) {
|
1533
|
+
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
|
1534
|
+
return 0;
|
1535
|
+
}
|
1536
|
+
}
|
1537
|
+
|
1538
|
+
float * output_base = (float *) lm_ggml_backend_buffer_get_base(buf_output.get());
|
1539
|
+
|
1540
|
+
logits = has_logits ? output_base : nullptr;
|
1541
|
+
embd = has_embd ? output_base + logits_size : nullptr;
|
1542
|
+
|
1543
|
+
// set all ids as invalid (negative)
|
1544
|
+
std::fill(output_ids.begin(), output_ids.end(), -1);
|
1545
|
+
|
1546
|
+
lm_ggml_backend_buffer_clear(buf_output.get(), 0);
|
1547
|
+
|
1548
|
+
this->n_outputs = 0;
|
1549
|
+
this->n_outputs_max = n_outputs_max;
|
1550
|
+
|
1551
|
+
return n_outputs_max;
|
1552
|
+
}
|
1553
|
+
|
1554
|
+
void llama_context::output_reorder() {
|
1555
|
+
auto & out_ids = sbatch.out_ids;
|
1556
|
+
if (!out_ids.empty()) {
|
1557
|
+
const uint32_t n_vocab = model.vocab.n_tokens();
|
1558
|
+
const uint32_t n_embd = model.hparams.n_embd;
|
1559
|
+
|
1560
|
+
LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
1561
|
+
|
1562
|
+
// TODO: is there something more efficient which also minimizes swaps?
|
1563
|
+
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
1564
|
+
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
1565
|
+
int32_t j_min = i;
|
1566
|
+
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
1567
|
+
if (out_ids[j] < out_ids[j_min]) {
|
1568
|
+
j_min = j;
|
1569
|
+
}
|
1570
|
+
}
|
1571
|
+
if (j_min == i) { continue; }
|
1572
|
+
std::swap(out_ids[i], out_ids[j_min]);
|
1573
|
+
if (logits_size > 0) {
|
1574
|
+
for (uint32_t k = 0; k < n_vocab; k++) {
|
1575
|
+
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
|
1576
|
+
}
|
1577
|
+
}
|
1578
|
+
if (embd_size > 0) {
|
1579
|
+
for (uint32_t k = 0; k < n_embd; k++) {
|
1580
|
+
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
1581
|
+
}
|
1582
|
+
}
|
1583
|
+
}
|
1584
|
+
std::fill(output_ids.begin(), output_ids.end(), -1);
|
1585
|
+
for (int32_t i = 0; i < n_outputs; ++i) {
|
1586
|
+
output_ids[out_ids[i]] = i;
|
1587
|
+
}
|
1588
|
+
out_ids.clear();
|
1589
|
+
}
|
1590
|
+
}
|
1591
|
+
|
1592
|
+
//
|
1593
|
+
// graph
|
1594
|
+
//
|
1595
|
+
|
1596
|
+
int32_t llama_context::graph_max_nodes() const {
|
1597
|
+
return std::max<int32_t>(65536, 5*model.n_tensors());
|
1598
|
+
}
|
1599
|
+
|
1600
|
+
lm_ggml_cgraph * llama_context::graph_init() {
|
1601
|
+
lm_ggml_init_params params = {
|
1602
|
+
/*.mem_size =*/ buf_compute_meta.size(),
|
1603
|
+
/*.mem_buffer =*/ buf_compute_meta.data(),
|
1604
|
+
/*.no_alloc =*/ true,
|
1605
|
+
};
|
1606
|
+
|
1607
|
+
ctx_compute.reset(lm_ggml_init(params));
|
1608
|
+
|
1609
|
+
return lm_ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
1610
|
+
}
|
1611
|
+
|
1612
|
+
llm_graph_result_ptr llama_context::graph_build(
|
1613
|
+
lm_ggml_context * ctx,
|
1614
|
+
lm_ggml_cgraph * gf,
|
1615
|
+
const llama_ubatch & ubatch,
|
1616
|
+
llm_graph_type gtype) {
|
1617
|
+
return model.build_graph(
|
1618
|
+
{
|
1619
|
+
/*.ctx =*/ ctx,
|
1620
|
+
/*.arch =*/ model.arch,
|
1621
|
+
/*.hparams =*/ model.hparams,
|
1622
|
+
/*.cparams =*/ cparams,
|
1623
|
+
/*.ubatch =*/ ubatch,
|
1624
|
+
/*.sched =*/ sched.get(),
|
1625
|
+
/*.backend_cpu =*/ backend_cpu,
|
1626
|
+
/*.cvec =*/ &cvec,
|
1627
|
+
/*.loras =*/ &loras,
|
1628
|
+
/*.memory =*/ kv_self.get(),
|
1629
|
+
/*.cross =*/ &cross,
|
1630
|
+
/*.n_outputs =*/ n_outputs,
|
1631
|
+
/*.cb =*/ graph_get_cb(),
|
1632
|
+
}, gf, gtype);
|
1633
|
+
}
|
1634
|
+
|
1635
|
+
lm_ggml_status llama_context::graph_compute(
|
1636
|
+
lm_ggml_cgraph * gf,
|
1637
|
+
bool batched) {
|
1638
|
+
int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
|
1639
|
+
lm_ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
|
1640
|
+
|
1641
|
+
if (backend_cpu != nullptr) {
|
1642
|
+
auto * reg = lm_ggml_backend_dev_backend_reg(lm_ggml_backend_get_device(backend_cpu));
|
1643
|
+
auto * set_threadpool_fn = (decltype(lm_ggml_backend_cpu_set_threadpool) *) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_cpu_set_threadpool");
|
1644
|
+
set_threadpool_fn(backend_cpu, tp);
|
1645
|
+
}
|
1646
|
+
|
1647
|
+
// set the number of threads for all the backends
|
1648
|
+
for (const auto & set_n_threads_fn : set_n_threads_fns) {
|
1649
|
+
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
|
1650
|
+
}
|
1651
|
+
|
1652
|
+
auto status = lm_ggml_backend_sched_graph_compute_async(sched.get(), gf);
|
1653
|
+
if (status != LM_GGML_STATUS_SUCCESS) {
|
1654
|
+
LLAMA_LOG_ERROR("%s: lm_ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
|
1655
|
+
}
|
1656
|
+
|
1657
|
+
// fprintf(stderr, "splits: %d\n", lm_ggml_backend_sched_get_n_splits(sched));
|
1658
|
+
|
1659
|
+
return status;
|
1660
|
+
}
|
1661
|
+
|
1662
|
+
llm_graph_cb llama_context::graph_get_cb() const {
|
1663
|
+
return [&](const llama_ubatch & ubatch, lm_ggml_tensor * cur, const char * name, int il) {
|
1664
|
+
if (il >= 0) {
|
1665
|
+
lm_ggml_format_name(cur, "%s-%d", name, il);
|
1666
|
+
} else {
|
1667
|
+
lm_ggml_set_name(cur, name);
|
1668
|
+
}
|
1669
|
+
|
1670
|
+
if (!cparams.offload_kqv) {
|
1671
|
+
if (strcmp(name, "kqv_merged_cont") == 0) {
|
1672
|
+
// all nodes between the KV store and the attention output are run on the CPU
|
1673
|
+
lm_ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
|
1674
|
+
}
|
1675
|
+
}
|
1676
|
+
|
1677
|
+
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
|
1678
|
+
// FIXME: fix in lm_ggml_backend_sched
|
1679
|
+
const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
|
1680
|
+
if (ubatch.n_tokens < 32 || full_offload) {
|
1681
|
+
if (il != -1 && strcmp(name, "norm") == 0) {
|
1682
|
+
const auto & dev_layer = model.dev_layer(il);
|
1683
|
+
for (const auto & backend : backends) {
|
1684
|
+
if (lm_ggml_backend_get_device(backend.get()) == dev_layer) {
|
1685
|
+
if (lm_ggml_backend_supports_op(backend.get(), cur)) {
|
1686
|
+
lm_ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
|
1687
|
+
}
|
1688
|
+
}
|
1689
|
+
}
|
1690
|
+
}
|
1691
|
+
}
|
1692
|
+
};
|
1693
|
+
}
|
1694
|
+
|
1695
|
+
//
|
1696
|
+
// state save/load
|
1697
|
+
//
|
1698
|
+
|
1699
|
+
class llama_io_write_dummy : public llama_io_write_i {
|
1700
|
+
public:
|
1701
|
+
llama_io_write_dummy() = default;
|
1702
|
+
|
1703
|
+
void write(const void * /* src */, size_t size) override {
|
1704
|
+
size_written += size;
|
1705
|
+
}
|
1706
|
+
|
1707
|
+
void write_tensor(const lm_ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
|
1708
|
+
size_written += size;
|
1709
|
+
}
|
1710
|
+
|
1711
|
+
size_t n_bytes() override {
|
1712
|
+
return size_written;
|
1713
|
+
}
|
1714
|
+
|
1715
|
+
private:
|
1716
|
+
size_t size_written = 0;
|
1717
|
+
};
|
1718
|
+
|
1719
|
+
class llama_io_write_buffer : public llama_io_write_i {
|
1720
|
+
public:
|
1721
|
+
llama_io_write_buffer(
|
1722
|
+
uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
1723
|
+
|
1724
|
+
void write(const void * src, size_t size) override {
|
1725
|
+
if (size > buf_size) {
|
1726
|
+
throw std::runtime_error("unexpectedly reached end of buffer");
|
1402
1727
|
}
|
1403
1728
|
memcpy(ptr, src, size);
|
1404
1729
|
ptr += size;
|
@@ -1406,7 +1731,7 @@ struct llama_data_write_buffer : llama_data_write {
|
|
1406
1731
|
buf_size -= size;
|
1407
1732
|
}
|
1408
1733
|
|
1409
|
-
void
|
1734
|
+
void write_tensor(const lm_ggml_tensor * tensor, size_t offset, size_t size) override {
|
1410
1735
|
if (size > buf_size) {
|
1411
1736
|
throw std::runtime_error("unexpectedly reached end of buffer");
|
1412
1737
|
}
|
@@ -1416,17 +1741,19 @@ struct llama_data_write_buffer : llama_data_write {
|
|
1416
1741
|
buf_size -= size;
|
1417
1742
|
}
|
1418
1743
|
|
1419
|
-
size_t
|
1744
|
+
size_t n_bytes() override {
|
1420
1745
|
return size_written;
|
1421
1746
|
}
|
1422
|
-
};
|
1423
1747
|
|
1424
|
-
|
1425
|
-
|
1748
|
+
private:
|
1749
|
+
uint8_t * ptr;
|
1426
1750
|
size_t buf_size = 0;
|
1427
|
-
size_t
|
1751
|
+
size_t size_written = 0;
|
1752
|
+
};
|
1428
1753
|
|
1429
|
-
|
1754
|
+
class llama_io_read_buffer : public llama_io_read_i {
|
1755
|
+
public:
|
1756
|
+
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
1430
1757
|
|
1431
1758
|
const uint8_t * read(size_t size) override {
|
1432
1759
|
const uint8_t * base_ptr = ptr;
|
@@ -1443,40 +1770,44 @@ struct llama_data_read_buffer : llama_data_read {
|
|
1443
1770
|
memcpy(dst, read(size), size);
|
1444
1771
|
}
|
1445
1772
|
|
1446
|
-
size_t
|
1773
|
+
size_t n_bytes() override {
|
1447
1774
|
return size_read;
|
1448
1775
|
}
|
1449
|
-
};
|
1450
1776
|
|
1451
|
-
|
1452
|
-
|
1453
|
-
size_t
|
1454
|
-
|
1777
|
+
private:
|
1778
|
+
const uint8_t * ptr;
|
1779
|
+
size_t buf_size = 0;
|
1780
|
+
size_t size_read = 0;
|
1781
|
+
};
|
1455
1782
|
|
1456
|
-
|
1783
|
+
class llama_io_write_file : public llama_io_write_i {
|
1784
|
+
public:
|
1785
|
+
llama_io_write_file(llama_file * f) : file(f) {}
|
1457
1786
|
|
1458
1787
|
void write(const void * src, size_t size) override {
|
1459
1788
|
file->write_raw(src, size);
|
1460
1789
|
size_written += size;
|
1461
1790
|
}
|
1462
1791
|
|
1463
|
-
void
|
1792
|
+
void write_tensor(const lm_ggml_tensor * tensor, size_t offset, size_t size) override {
|
1464
1793
|
temp_buffer.resize(size);
|
1465
1794
|
lm_ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
|
1466
1795
|
write(temp_buffer.data(), temp_buffer.size());
|
1467
1796
|
}
|
1468
1797
|
|
1469
|
-
size_t
|
1798
|
+
size_t n_bytes() override {
|
1470
1799
|
return size_written;
|
1471
1800
|
}
|
1472
|
-
};
|
1473
1801
|
|
1474
|
-
|
1802
|
+
private:
|
1475
1803
|
llama_file * file;
|
1476
|
-
size_t
|
1804
|
+
size_t size_written = 0;
|
1477
1805
|
std::vector<uint8_t> temp_buffer;
|
1806
|
+
};
|
1478
1807
|
|
1479
|
-
|
1808
|
+
class llama_io_read_file : public llama_io_read_i {
|
1809
|
+
public:
|
1810
|
+
llama_io_read_file(llama_file * f) : file(f) {}
|
1480
1811
|
|
1481
1812
|
void read_to(void * dst, size_t size) override {
|
1482
1813
|
file->read_raw(dst, size);
|
@@ -1489,89 +1820,78 @@ struct llama_data_read_file : llama_data_read {
|
|
1489
1820
|
return temp_buffer.data();
|
1490
1821
|
}
|
1491
1822
|
|
1492
|
-
size_t
|
1823
|
+
size_t n_bytes() override {
|
1493
1824
|
return size_read;
|
1494
1825
|
}
|
1495
|
-
};
|
1496
|
-
|
1497
|
-
/** copy state data into either a buffer or file depending on the passed in context
|
1498
|
-
*
|
1499
|
-
* file context:
|
1500
|
-
* llama_file file("/path", "wb");
|
1501
|
-
* llama_data_write_file data_ctx(&file);
|
1502
|
-
* llama_state_get_data_internal(ctx, data_ctx);
|
1503
|
-
*
|
1504
|
-
* buffer context:
|
1505
|
-
* std::vector<uint8_t> buf(max_size, 0);
|
1506
|
-
* llama_data_write_buffer data_ctx(buf.data(), max_size);
|
1507
|
-
* llama_state_get_data_internal(ctx, data_ctx);
|
1508
|
-
*
|
1509
|
-
*/
|
1510
|
-
static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
|
1511
|
-
llama_synchronize(ctx);
|
1512
|
-
|
1513
|
-
data_ctx.write_model_info(ctx);
|
1514
|
-
|
1515
|
-
// copy outputs
|
1516
|
-
data_ctx.write_output_ids(ctx);
|
1517
|
-
data_ctx.write_logits(ctx);
|
1518
|
-
data_ctx.write_embeddings(ctx);
|
1519
1826
|
|
1520
|
-
|
1827
|
+
private:
|
1828
|
+
llama_file * file;
|
1829
|
+
size_t size_read = 0;
|
1830
|
+
std::vector<uint8_t> temp_buffer;
|
1831
|
+
};
|
1521
1832
|
|
1522
|
-
|
1833
|
+
size_t llama_context::state_get_size() {
|
1834
|
+
llama_io_write_dummy io;
|
1835
|
+
try {
|
1836
|
+
return state_write_data(io);
|
1837
|
+
} catch (const std::exception & err) {
|
1838
|
+
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
1839
|
+
return 0;
|
1840
|
+
}
|
1523
1841
|
}
|
1524
1842
|
|
1525
|
-
size_t
|
1526
|
-
|
1843
|
+
size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
|
1844
|
+
llama_io_write_buffer io(dst, size);
|
1527
1845
|
try {
|
1528
|
-
return
|
1846
|
+
return state_write_data(io);
|
1529
1847
|
} catch (const std::exception & err) {
|
1530
1848
|
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
1531
1849
|
return 0;
|
1532
1850
|
}
|
1533
1851
|
}
|
1534
1852
|
|
1535
|
-
|
1536
|
-
|
1537
|
-
size_t llama_state_get_size(struct llama_context * ctx) {
|
1538
|
-
llama_data_write_dummy data_ctx;
|
1853
|
+
size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
|
1854
|
+
llama_io_read_buffer io(src, size);
|
1539
1855
|
try {
|
1540
|
-
return
|
1856
|
+
return state_read_data(io);
|
1541
1857
|
} catch (const std::exception & err) {
|
1542
|
-
LLAMA_LOG_ERROR("%s: error
|
1858
|
+
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
1543
1859
|
return 0;
|
1544
1860
|
}
|
1545
1861
|
}
|
1546
1862
|
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
data_ctx.read_kv_cache(ctx);
|
1863
|
+
size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
|
1864
|
+
llama_io_write_dummy io;
|
1865
|
+
try {
|
1866
|
+
return state_seq_write_data(io, seq_id);
|
1867
|
+
} catch (const std::exception & err) {
|
1868
|
+
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
1869
|
+
return 0;
|
1870
|
+
}
|
1871
|
+
}
|
1558
1872
|
|
1559
|
-
|
1873
|
+
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
|
1874
|
+
llama_io_write_buffer io(dst, size);
|
1875
|
+
try {
|
1876
|
+
return state_seq_write_data(io, seq_id);
|
1877
|
+
} catch (const std::exception & err) {
|
1878
|
+
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
1879
|
+
return 0;
|
1880
|
+
}
|
1560
1881
|
}
|
1561
1882
|
|
1562
|
-
|
1563
|
-
|
1564
|
-
llama_data_read_buffer data_ctx(src, size);
|
1883
|
+
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
|
1884
|
+
llama_io_read_buffer io(src, size);
|
1565
1885
|
try {
|
1566
|
-
return
|
1886
|
+
return state_seq_read_data(io, seq_id);
|
1567
1887
|
} catch (const std::exception & err) {
|
1568
1888
|
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
1569
1889
|
return 0;
|
1570
1890
|
}
|
1571
1891
|
}
|
1572
1892
|
|
1573
|
-
|
1574
|
-
llama_file file(
|
1893
|
+
bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
1894
|
+
llama_file file(filepath, "rb");
|
1575
1895
|
|
1576
1896
|
// sanity checks
|
1577
1897
|
{
|
@@ -1601,28 +1921,20 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
|
|
1601
1921
|
{
|
1602
1922
|
const size_t n_state_size_cur = file.size() - file.tell();
|
1603
1923
|
|
1604
|
-
|
1605
|
-
const size_t n_read =
|
1924
|
+
llama_io_read_file io( &file);
|
1925
|
+
const size_t n_read = state_read_data(io);
|
1606
1926
|
|
1607
1927
|
if (n_read != n_state_size_cur) {
|
1608
1928
|
LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
|
1609
1929
|
return false;
|
1610
1930
|
}
|
1611
1931
|
}
|
1612
|
-
return true;
|
1613
|
-
}
|
1614
1932
|
|
1615
|
-
|
1616
|
-
try {
|
1617
|
-
return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
1618
|
-
} catch (const std::exception & err) {
|
1619
|
-
LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
|
1620
|
-
return false;
|
1621
|
-
}
|
1933
|
+
return true;
|
1622
1934
|
}
|
1623
1935
|
|
1624
|
-
|
1625
|
-
llama_file file(
|
1936
|
+
bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
|
1937
|
+
llama_file file(filepath, "wb");
|
1626
1938
|
|
1627
1939
|
file.write_u32(LLAMA_SESSION_MAGIC);
|
1628
1940
|
file.write_u32(LLAMA_SESSION_VERSION);
|
@@ -1632,63 +1944,56 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha
|
|
1632
1944
|
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
1633
1945
|
|
1634
1946
|
// save the context state using stream saving
|
1635
|
-
|
1636
|
-
|
1947
|
+
llama_io_write_file io(&file);
|
1948
|
+
state_write_data(io);
|
1637
1949
|
|
1638
1950
|
return true;
|
1639
1951
|
}
|
1640
1952
|
|
1641
|
-
|
1642
|
-
|
1643
|
-
return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
|
1644
|
-
} catch (const std::exception & err) {
|
1645
|
-
LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
|
1646
|
-
return false;
|
1647
|
-
}
|
1648
|
-
}
|
1649
|
-
|
1650
|
-
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
1651
|
-
llama_synchronize(ctx);
|
1953
|
+
size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
1954
|
+
llama_file file(filepath, "rb");
|
1652
1955
|
|
1653
|
-
|
1956
|
+
// version checks
|
1957
|
+
{
|
1958
|
+
const uint32_t magic = file.read_u32();
|
1959
|
+
const uint32_t version = file.read_u32();
|
1654
1960
|
|
1655
|
-
|
1656
|
-
|
1961
|
+
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
|
1962
|
+
LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
|
1963
|
+
return 0;
|
1964
|
+
}
|
1965
|
+
}
|
1657
1966
|
|
1658
|
-
|
1659
|
-
|
1660
|
-
|
1661
|
-
}
|
1662
|
-
|
1663
|
-
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
1664
|
-
llama_data_write_buffer data_ctx(dst, size);
|
1665
|
-
try {
|
1666
|
-
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
1667
|
-
} catch (const std::exception & err) {
|
1668
|
-
LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
|
1669
|
-
return 0;
|
1670
|
-
}
|
1671
|
-
}
|
1672
|
-
|
1673
|
-
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
1674
|
-
llama_synchronize(ctx);
|
1967
|
+
// load the prompt
|
1968
|
+
{
|
1969
|
+
const uint32_t n_token_count = file.read_u32();
|
1675
1970
|
|
1676
|
-
|
1971
|
+
if (n_token_count > n_token_capacity) {
|
1972
|
+
LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
|
1973
|
+
return 0;
|
1974
|
+
}
|
1677
1975
|
|
1678
|
-
|
1679
|
-
|
1976
|
+
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
|
1977
|
+
*n_token_count_out = n_token_count;
|
1978
|
+
}
|
1680
1979
|
|
1681
|
-
|
1682
|
-
|
1683
|
-
|
1684
|
-
|
1685
|
-
|
1686
|
-
|
1687
|
-
|
1980
|
+
// restore the context state
|
1981
|
+
{
|
1982
|
+
const size_t state_size = file.size() - file.tell();
|
1983
|
+
llama_io_read_file io(&file);
|
1984
|
+
const size_t nread = state_seq_read_data(io, seq_id);
|
1985
|
+
if (!nread) {
|
1986
|
+
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
1987
|
+
return 0;
|
1988
|
+
}
|
1989
|
+
LM_GGML_ASSERT(nread <= state_size);
|
1990
|
+
LM_GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
|
1688
1991
|
}
|
1992
|
+
|
1993
|
+
return file.tell();
|
1689
1994
|
}
|
1690
1995
|
|
1691
|
-
|
1996
|
+
size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
|
1692
1997
|
llama_file file(filepath, "wb");
|
1693
1998
|
|
1694
1999
|
file.write_u32(LLAMA_STATE_SEQ_MAGIC);
|
@@ -1699,77 +2004,828 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
|
|
1699
2004
|
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
1700
2005
|
|
1701
2006
|
// save the context state using stream saving
|
1702
|
-
|
1703
|
-
|
2007
|
+
llama_io_write_file io(&file);
|
2008
|
+
state_seq_write_data(io, seq_id);
|
1704
2009
|
|
1705
2010
|
const size_t res = file.tell();
|
1706
|
-
LM_GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count +
|
2011
|
+
LM_GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
|
2012
|
+
|
1707
2013
|
return res;
|
1708
2014
|
}
|
1709
2015
|
|
1710
|
-
|
1711
|
-
|
2016
|
+
size_t llama_context::state_write_data(llama_io_write_i & io) {
|
2017
|
+
LLAMA_LOG_DEBUG("%s: writing state\n", __func__);
|
1712
2018
|
|
1713
|
-
//
|
2019
|
+
// write model info
|
1714
2020
|
{
|
1715
|
-
|
1716
|
-
const uint32_t version = file.read_u32();
|
2021
|
+
LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__);
|
1717
2022
|
|
1718
|
-
|
1719
|
-
|
1720
|
-
|
2023
|
+
const std::string arch_str = llm_arch_name(model.arch);
|
2024
|
+
io.write_string(arch_str);
|
2025
|
+
// TODO: add more model-specific info which should prevent loading the session file if not identical
|
2026
|
+
}
|
2027
|
+
|
2028
|
+
// write output ids
|
2029
|
+
{
|
2030
|
+
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
|
2031
|
+
|
2032
|
+
output_reorder();
|
2033
|
+
|
2034
|
+
const auto n_outputs = this->n_outputs;
|
2035
|
+
const auto & output_ids = this->output_ids;
|
2036
|
+
|
2037
|
+
std::vector<int32_t> w_output_pos;
|
2038
|
+
|
2039
|
+
LM_GGML_ASSERT(n_outputs <= n_outputs_max);
|
2040
|
+
|
2041
|
+
w_output_pos.resize(n_outputs);
|
2042
|
+
|
2043
|
+
// build a more compact representation of the output ids
|
2044
|
+
for (size_t i = 0; i < n_batch(); ++i) {
|
2045
|
+
// map an output id to a position in the batch
|
2046
|
+
int32_t pos = output_ids[i];
|
2047
|
+
if (pos >= 0) {
|
2048
|
+
LM_GGML_ASSERT(pos < n_outputs);
|
2049
|
+
w_output_pos[pos] = i;
|
2050
|
+
}
|
2051
|
+
}
|
2052
|
+
|
2053
|
+
io.write(&n_outputs, sizeof(n_outputs));
|
2054
|
+
|
2055
|
+
if (n_outputs) {
|
2056
|
+
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
|
1721
2057
|
}
|
1722
2058
|
}
|
1723
2059
|
|
1724
|
-
//
|
2060
|
+
// write logits
|
1725
2061
|
{
|
1726
|
-
|
2062
|
+
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
|
1727
2063
|
|
1728
|
-
|
1729
|
-
|
1730
|
-
|
2064
|
+
const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
|
2065
|
+
|
2066
|
+
io.write(&logits_size, sizeof(logits_size));
|
2067
|
+
|
2068
|
+
if (logits_size) {
|
2069
|
+
io.write(logits, logits_size * sizeof(float));
|
1731
2070
|
}
|
2071
|
+
}
|
1732
2072
|
|
1733
|
-
|
1734
|
-
|
2073
|
+
// write embeddings
|
2074
|
+
{
|
2075
|
+
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
|
2076
|
+
|
2077
|
+
const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
|
2078
|
+
|
2079
|
+
io.write(&embd_size, sizeof(embd_size));
|
2080
|
+
|
2081
|
+
if (embd_size) {
|
2082
|
+
io.write(embd, embd_size * sizeof(float));
|
2083
|
+
}
|
1735
2084
|
}
|
1736
2085
|
|
1737
|
-
|
2086
|
+
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
2087
|
+
kv_self->state_write(io);
|
2088
|
+
|
2089
|
+
return io.n_bytes();
|
2090
|
+
}
|
2091
|
+
|
2092
|
+
size_t llama_context::state_read_data(llama_io_read_i & io) {
|
2093
|
+
LLAMA_LOG_DEBUG("%s: reading state\n", __func__);
|
2094
|
+
|
2095
|
+
// read model info
|
1738
2096
|
{
|
1739
|
-
|
1740
|
-
|
1741
|
-
const
|
1742
|
-
|
1743
|
-
|
1744
|
-
|
2097
|
+
LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__);
|
2098
|
+
|
2099
|
+
const std::string cur_arch_str = llm_arch_name(model.arch);
|
2100
|
+
|
2101
|
+
std::string arch_str;
|
2102
|
+
io.read_string(arch_str);
|
2103
|
+
if (cur_arch_str != arch_str) {
|
2104
|
+
throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
|
1745
2105
|
}
|
1746
|
-
|
1747
|
-
LM_GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
|
2106
|
+
// TODO: add more info which needs to be identical but which is not verified otherwise
|
1748
2107
|
}
|
1749
2108
|
|
1750
|
-
|
2109
|
+
// read output ids
|
2110
|
+
{
|
2111
|
+
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
|
2112
|
+
|
2113
|
+
auto n_outputs = this->n_outputs;
|
2114
|
+
io.read_to(&n_outputs, sizeof(n_outputs));
|
2115
|
+
|
2116
|
+
if (n_outputs > output_reserve(n_outputs)) {
|
2117
|
+
throw std::runtime_error("could not reserve outputs");
|
2118
|
+
}
|
2119
|
+
|
2120
|
+
std::vector<int32_t> output_pos;
|
2121
|
+
|
2122
|
+
if (n_outputs) {
|
2123
|
+
output_pos.resize(n_outputs);
|
2124
|
+
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
|
2125
|
+
|
2126
|
+
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
|
2127
|
+
int32_t id = output_pos[i];
|
2128
|
+
if ((uint32_t) id >= n_batch()) {
|
2129
|
+
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
|
2130
|
+
}
|
2131
|
+
this->output_ids[id] = i;
|
2132
|
+
}
|
2133
|
+
|
2134
|
+
this->n_outputs = n_outputs;
|
2135
|
+
}
|
2136
|
+
}
|
2137
|
+
|
2138
|
+
// read logits
|
2139
|
+
{
|
2140
|
+
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
|
2141
|
+
|
2142
|
+
uint64_t logits_size;
|
2143
|
+
io.read_to(&logits_size, sizeof(logits_size));
|
2144
|
+
|
2145
|
+
if (this->logits_size < logits_size) {
|
2146
|
+
throw std::runtime_error("logits buffer too small");
|
2147
|
+
}
|
2148
|
+
|
2149
|
+
if (logits_size) {
|
2150
|
+
io.read_to(this->logits, logits_size * sizeof(float));
|
2151
|
+
}
|
2152
|
+
}
|
2153
|
+
|
2154
|
+
// read embeddings
|
2155
|
+
{
|
2156
|
+
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
|
2157
|
+
|
2158
|
+
uint64_t embd_size;
|
2159
|
+
io.read_to(&embd_size, sizeof(embd_size));
|
2160
|
+
|
2161
|
+
if (this->embd_size < embd_size) {
|
2162
|
+
throw std::runtime_error("embeddings buffer too small");
|
2163
|
+
}
|
2164
|
+
|
2165
|
+
if (embd_size) {
|
2166
|
+
io.read_to(this->embd, embd_size * sizeof(float));
|
2167
|
+
}
|
2168
|
+
}
|
2169
|
+
|
2170
|
+
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
2171
|
+
kv_self->state_read(io);
|
2172
|
+
|
2173
|
+
return io.n_bytes();
|
2174
|
+
}
|
2175
|
+
|
2176
|
+
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
2177
|
+
LM_GGML_UNUSED(seq_id);
|
2178
|
+
|
2179
|
+
kv_self->state_write(io, seq_id);
|
2180
|
+
|
2181
|
+
return io.n_bytes();
|
2182
|
+
}
|
2183
|
+
|
2184
|
+
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
2185
|
+
LM_GGML_UNUSED(seq_id);
|
2186
|
+
|
2187
|
+
kv_self->state_read(io, seq_id);
|
2188
|
+
|
2189
|
+
return io.n_bytes();
|
2190
|
+
}
|
2191
|
+
|
2192
|
+
//
|
2193
|
+
// perf
|
2194
|
+
//
|
2195
|
+
|
2196
|
+
llama_perf_context_data llama_context::perf_get_data() const {
|
2197
|
+
llama_perf_context_data data = {};
|
2198
|
+
|
2199
|
+
data.t_start_ms = 1e-3 * t_start_us;
|
2200
|
+
data.t_load_ms = 1e-3 * t_load_us;
|
2201
|
+
data.t_p_eval_ms = 1e-3 * t_p_eval_us;
|
2202
|
+
data.t_eval_ms = 1e-3 * t_eval_us;
|
2203
|
+
data.n_p_eval = std::max(1, n_p_eval);
|
2204
|
+
data.n_eval = std::max(1, n_eval);
|
2205
|
+
|
2206
|
+
return data;
|
1751
2207
|
}
|
1752
2208
|
|
1753
|
-
|
2209
|
+
void llama_context::perf_reset() {
|
2210
|
+
t_start_us = lm_ggml_time_us();
|
2211
|
+
t_eval_us = n_eval = 0;
|
2212
|
+
t_p_eval_us = n_p_eval = 0;
|
2213
|
+
}
|
2214
|
+
|
2215
|
+
//
|
2216
|
+
// interface implementation
|
2217
|
+
//
|
2218
|
+
|
2219
|
+
llama_context_params llama_context_default_params() {
|
2220
|
+
llama_context_params result = {
|
2221
|
+
/*.n_ctx =*/ 512,
|
2222
|
+
/*.n_batch =*/ 2048,
|
2223
|
+
/*.n_ubatch =*/ 512,
|
2224
|
+
/*.n_seq_max =*/ 1,
|
2225
|
+
/*.n_threads =*/ LM_GGML_DEFAULT_N_THREADS, // TODO: better default
|
2226
|
+
/*.n_threads_batch =*/ LM_GGML_DEFAULT_N_THREADS,
|
2227
|
+
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
|
2228
|
+
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
|
2229
|
+
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
|
2230
|
+
/*.rope_freq_base =*/ 0.0f,
|
2231
|
+
/*.rope_freq_scale =*/ 0.0f,
|
2232
|
+
/*.yarn_ext_factor =*/ -1.0f,
|
2233
|
+
/*.yarn_attn_factor =*/ 1.0f,
|
2234
|
+
/*.yarn_beta_fast =*/ 32.0f,
|
2235
|
+
/*.yarn_beta_slow =*/ 1.0f,
|
2236
|
+
/*.yarn_orig_ctx =*/ 0,
|
2237
|
+
/*.defrag_thold =*/ -1.0f,
|
2238
|
+
/*.cb_eval =*/ nullptr,
|
2239
|
+
/*.cb_eval_user_data =*/ nullptr,
|
2240
|
+
/*.type_k =*/ LM_GGML_TYPE_F16,
|
2241
|
+
/*.type_v =*/ LM_GGML_TYPE_F16,
|
2242
|
+
/*.logits_all =*/ false,
|
2243
|
+
/*.embeddings =*/ false,
|
2244
|
+
/*.offload_kqv =*/ true,
|
2245
|
+
/*.flash_attn =*/ false,
|
2246
|
+
/*.no_perf =*/ true,
|
2247
|
+
/*.abort_callback =*/ nullptr,
|
2248
|
+
/*.abort_callback_data =*/ nullptr,
|
2249
|
+
};
|
2250
|
+
|
2251
|
+
return result;
|
2252
|
+
}
|
2253
|
+
|
2254
|
+
llama_context * llama_init_from_model(
|
2255
|
+
llama_model * model,
|
2256
|
+
llama_context_params params) {
|
2257
|
+
if (!model) {
|
2258
|
+
LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
|
2259
|
+
return nullptr;
|
2260
|
+
}
|
2261
|
+
|
2262
|
+
if (params.n_batch == 0 && params.n_ubatch == 0) {
|
2263
|
+
LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
|
2264
|
+
return nullptr;
|
2265
|
+
}
|
2266
|
+
|
2267
|
+
if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
|
2268
|
+
LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
|
2269
|
+
return nullptr;
|
2270
|
+
}
|
2271
|
+
|
2272
|
+
if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
|
2273
|
+
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
|
2274
|
+
params.flash_attn = false;
|
2275
|
+
}
|
2276
|
+
|
2277
|
+
if (lm_ggml_is_quantized(params.type_v) && !params.flash_attn) {
|
2278
|
+
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
|
2279
|
+
return nullptr;
|
2280
|
+
}
|
2281
|
+
|
1754
2282
|
try {
|
1755
|
-
|
2283
|
+
auto * ctx = new llama_context(*model, params);
|
2284
|
+
return ctx;
|
2285
|
+
} catch (const std::exception & err) {
|
2286
|
+
LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what());
|
2287
|
+
}
|
2288
|
+
|
2289
|
+
return nullptr;
|
2290
|
+
}
|
2291
|
+
|
2292
|
+
// deprecated
|
2293
|
+
llama_context * llama_new_context_with_model(
|
2294
|
+
llama_model * model,
|
2295
|
+
llama_context_params params) {
|
2296
|
+
return llama_init_from_model(model, params);
|
2297
|
+
}
|
2298
|
+
|
2299
|
+
void llama_free(llama_context * ctx) {
|
2300
|
+
delete ctx;
|
2301
|
+
}
|
2302
|
+
|
2303
|
+
uint32_t llama_n_ctx(const llama_context * ctx) {
|
2304
|
+
return ctx->n_ctx();
|
2305
|
+
}
|
2306
|
+
|
2307
|
+
uint32_t llama_n_batch(const llama_context * ctx) {
|
2308
|
+
return ctx->n_batch();
|
2309
|
+
}
|
2310
|
+
|
2311
|
+
uint32_t llama_n_ubatch(const llama_context * ctx) {
|
2312
|
+
return ctx->n_ubatch();
|
2313
|
+
}
|
2314
|
+
|
2315
|
+
uint32_t llama_n_seq_max(const llama_context * ctx) {
|
2316
|
+
return ctx->n_seq_max();
|
2317
|
+
}
|
2318
|
+
|
2319
|
+
const llama_model * llama_get_model(const llama_context * ctx) {
|
2320
|
+
return &ctx->get_model();
|
2321
|
+
}
|
2322
|
+
|
2323
|
+
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
2324
|
+
return ctx->get_kv_self();
|
2325
|
+
}
|
2326
|
+
|
2327
|
+
void llama_kv_self_update(llama_context * ctx) {
|
2328
|
+
ctx->kv_self_update();
|
2329
|
+
}
|
2330
|
+
|
2331
|
+
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
2332
|
+
return ctx->pooling_type();
|
2333
|
+
}
|
2334
|
+
|
2335
|
+
void llama_attach_threadpool(
|
2336
|
+
llama_context * ctx,
|
2337
|
+
lm_ggml_threadpool_t threadpool,
|
2338
|
+
lm_ggml_threadpool_t threadpool_batch) {
|
2339
|
+
ctx->attach_threadpool(threadpool, threadpool_batch);
|
2340
|
+
}
|
2341
|
+
|
2342
|
+
void llama_detach_threadpool(llama_context * ctx) {
|
2343
|
+
ctx->detach_threadpool();
|
2344
|
+
}
|
2345
|
+
|
2346
|
+
void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
|
2347
|
+
ctx->set_n_threads(n_threads, n_threads_batch);
|
2348
|
+
}
|
2349
|
+
|
2350
|
+
int32_t llama_n_threads(llama_context * ctx) {
|
2351
|
+
return ctx->n_threads();
|
2352
|
+
}
|
2353
|
+
|
2354
|
+
int32_t llama_n_threads_batch(llama_context * ctx) {
|
2355
|
+
return ctx->n_threads_batch();
|
2356
|
+
}
|
2357
|
+
|
2358
|
+
void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
|
2359
|
+
ctx->set_abort_callback(abort_callback, abort_callback_data);
|
2360
|
+
}
|
2361
|
+
|
2362
|
+
void llama_set_embeddings(llama_context * ctx, bool embeddings) {
|
2363
|
+
ctx->set_embeddings(embeddings);
|
2364
|
+
}
|
2365
|
+
|
2366
|
+
void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
|
2367
|
+
ctx->set_causal_attn(causal_attn);
|
2368
|
+
}
|
2369
|
+
|
2370
|
+
void llama_set_warmup(llama_context * ctx, bool warmup) {
|
2371
|
+
ctx->set_warmup(warmup);
|
2372
|
+
}
|
2373
|
+
|
2374
|
+
void llama_synchronize(llama_context * ctx) {
|
2375
|
+
ctx->synchronize();
|
2376
|
+
}
|
2377
|
+
|
2378
|
+
float * llama_get_logits(llama_context * ctx) {
|
2379
|
+
ctx->synchronize();
|
2380
|
+
|
2381
|
+
return ctx->get_logits();
|
2382
|
+
}
|
2383
|
+
|
2384
|
+
float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
|
2385
|
+
ctx->synchronize();
|
2386
|
+
|
2387
|
+
return ctx->get_logits_ith(i);
|
2388
|
+
}
|
2389
|
+
|
2390
|
+
float * llama_get_embeddings(llama_context * ctx) {
|
2391
|
+
ctx->synchronize();
|
2392
|
+
|
2393
|
+
return ctx->get_embeddings();
|
2394
|
+
}
|
2395
|
+
|
2396
|
+
float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) {
|
2397
|
+
ctx->synchronize();
|
2398
|
+
|
2399
|
+
return ctx->get_embeddings_ith(i);
|
2400
|
+
}
|
2401
|
+
|
2402
|
+
float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
2403
|
+
ctx->synchronize();
|
2404
|
+
|
2405
|
+
return ctx->get_embeddings_seq(seq_id);
|
2406
|
+
}
|
2407
|
+
|
2408
|
+
// llama adapter API
|
2409
|
+
|
2410
|
+
int32_t llama_set_adapter_lora(
|
2411
|
+
llama_context * ctx,
|
2412
|
+
llama_adapter_lora * adapter,
|
2413
|
+
float scale) {
|
2414
|
+
ctx->set_adapter_lora(adapter, scale);
|
2415
|
+
|
2416
|
+
return 0;
|
2417
|
+
}
|
2418
|
+
|
2419
|
+
int32_t llama_rm_adapter_lora(
|
2420
|
+
llama_context * ctx,
|
2421
|
+
llama_adapter_lora * adapter) {
|
2422
|
+
bool res = ctx->rm_adapter_lora(adapter);
|
2423
|
+
|
2424
|
+
return res ? 0 : -1;
|
2425
|
+
}
|
2426
|
+
|
2427
|
+
void llama_clear_adapter_lora(llama_context * ctx) {
|
2428
|
+
ctx->clear_adapter_lora();
|
2429
|
+
}
|
2430
|
+
|
2431
|
+
int32_t llama_apply_adapter_cvec(
|
2432
|
+
llama_context * ctx,
|
2433
|
+
const float * data,
|
2434
|
+
size_t len,
|
2435
|
+
int32_t n_embd,
|
2436
|
+
int32_t il_start,
|
2437
|
+
int32_t il_end) {
|
2438
|
+
bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
|
2439
|
+
|
2440
|
+
return res ? 0 : -1;
|
2441
|
+
}
|
2442
|
+
|
2443
|
+
//
|
2444
|
+
// kv cache view
|
2445
|
+
//
|
2446
|
+
|
2447
|
+
llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
|
2448
|
+
const auto * kv = ctx->get_kv_self();
|
2449
|
+
if (kv == nullptr) {
|
2450
|
+
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
|
2451
|
+
return {};
|
2452
|
+
}
|
2453
|
+
|
2454
|
+
return llama_kv_cache_view_init(*kv, n_seq_max);
|
2455
|
+
}
|
2456
|
+
|
2457
|
+
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
|
2458
|
+
const auto * kv = ctx->get_kv_self();
|
2459
|
+
if (kv == nullptr) {
|
2460
|
+
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
|
2461
|
+
return;
|
2462
|
+
}
|
2463
|
+
|
2464
|
+
llama_kv_cache_view_update(view, kv);
|
2465
|
+
}
|
2466
|
+
|
2467
|
+
//
|
2468
|
+
// kv cache
|
2469
|
+
//
|
2470
|
+
|
2471
|
+
// deprecated
|
2472
|
+
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
|
2473
|
+
return llama_kv_self_n_tokens(ctx);
|
2474
|
+
}
|
2475
|
+
|
2476
|
+
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
2477
|
+
const auto * kv = ctx->get_kv_self();
|
2478
|
+
if (!kv) {
|
2479
|
+
return 0;
|
2480
|
+
}
|
2481
|
+
|
2482
|
+
return kv->get_n_tokens();
|
2483
|
+
}
|
2484
|
+
|
2485
|
+
// deprecated
|
2486
|
+
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
|
2487
|
+
return llama_kv_self_used_cells(ctx);
|
2488
|
+
}
|
2489
|
+
|
2490
|
+
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
2491
|
+
const auto * kv = ctx->get_kv_self();
|
2492
|
+
if (!kv) {
|
2493
|
+
return 0;
|
2494
|
+
}
|
2495
|
+
|
2496
|
+
return kv->get_used_cells();
|
2497
|
+
}
|
2498
|
+
|
2499
|
+
// deprecated
|
2500
|
+
void llama_kv_cache_clear(llama_context * ctx) {
|
2501
|
+
llama_kv_self_clear(ctx);
|
2502
|
+
}
|
2503
|
+
|
2504
|
+
void llama_kv_self_clear(llama_context * ctx) {
|
2505
|
+
auto * kv = ctx->get_kv_self();
|
2506
|
+
if (!kv) {
|
2507
|
+
return;
|
2508
|
+
}
|
2509
|
+
|
2510
|
+
kv->clear();
|
2511
|
+
}
|
2512
|
+
|
2513
|
+
// deprecated
|
2514
|
+
bool llama_kv_cache_seq_rm(
|
2515
|
+
llama_context * ctx,
|
2516
|
+
llama_seq_id seq_id,
|
2517
|
+
llama_pos p0,
|
2518
|
+
llama_pos p1) {
|
2519
|
+
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
|
2520
|
+
}
|
2521
|
+
|
2522
|
+
bool llama_kv_self_seq_rm(
|
2523
|
+
llama_context * ctx,
|
2524
|
+
llama_seq_id seq_id,
|
2525
|
+
llama_pos p0,
|
2526
|
+
llama_pos p1) {
|
2527
|
+
auto * kv = ctx->get_kv_self();
|
2528
|
+
if (!kv) {
|
2529
|
+
return true;
|
2530
|
+
}
|
2531
|
+
|
2532
|
+
return kv->seq_rm(seq_id, p0, p1);
|
2533
|
+
}
|
2534
|
+
|
2535
|
+
// deprecated
|
2536
|
+
void llama_kv_cache_seq_cp(
|
2537
|
+
llama_context * ctx,
|
2538
|
+
llama_seq_id seq_id_src,
|
2539
|
+
llama_seq_id seq_id_dst,
|
2540
|
+
llama_pos p0,
|
2541
|
+
llama_pos p1) {
|
2542
|
+
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
|
2543
|
+
}
|
2544
|
+
|
2545
|
+
void llama_kv_self_seq_cp(
|
2546
|
+
llama_context * ctx,
|
2547
|
+
llama_seq_id seq_id_src,
|
2548
|
+
llama_seq_id seq_id_dst,
|
2549
|
+
llama_pos p0,
|
2550
|
+
llama_pos p1) {
|
2551
|
+
auto * kv = ctx->get_kv_self();
|
2552
|
+
if (!kv) {
|
2553
|
+
return;
|
2554
|
+
}
|
2555
|
+
|
2556
|
+
return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
2557
|
+
}
|
2558
|
+
|
2559
|
+
// deprecated
|
2560
|
+
void llama_kv_cache_seq_keep(
|
2561
|
+
llama_context * ctx,
|
2562
|
+
llama_seq_id seq_id) {
|
2563
|
+
return llama_kv_self_seq_keep(ctx, seq_id);
|
2564
|
+
}
|
2565
|
+
|
2566
|
+
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
2567
|
+
auto * kv = ctx->get_kv_self();
|
2568
|
+
if (!kv) {
|
2569
|
+
return;
|
2570
|
+
}
|
2571
|
+
|
2572
|
+
return kv->seq_keep(seq_id);
|
2573
|
+
}
|
2574
|
+
|
2575
|
+
// deprecated
|
2576
|
+
void llama_kv_cache_seq_add(
|
2577
|
+
llama_context * ctx,
|
2578
|
+
llama_seq_id seq_id,
|
2579
|
+
llama_pos p0,
|
2580
|
+
llama_pos p1,
|
2581
|
+
llama_pos delta) {
|
2582
|
+
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
|
2583
|
+
}
|
2584
|
+
|
2585
|
+
void llama_kv_self_seq_add(
|
2586
|
+
llama_context * ctx,
|
2587
|
+
llama_seq_id seq_id,
|
2588
|
+
llama_pos p0,
|
2589
|
+
llama_pos p1,
|
2590
|
+
llama_pos delta) {
|
2591
|
+
auto * kv = ctx->get_kv_self();
|
2592
|
+
if (!kv) {
|
2593
|
+
return;
|
2594
|
+
}
|
2595
|
+
|
2596
|
+
return kv->seq_add(seq_id, p0, p1, delta);
|
2597
|
+
}
|
2598
|
+
|
2599
|
+
// deprecated
|
2600
|
+
void llama_kv_cache_seq_div(
|
2601
|
+
llama_context * ctx,
|
2602
|
+
llama_seq_id seq_id,
|
2603
|
+
llama_pos p0,
|
2604
|
+
llama_pos p1,
|
2605
|
+
int d) {
|
2606
|
+
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
|
2607
|
+
}
|
2608
|
+
|
2609
|
+
void llama_kv_self_seq_div(
|
2610
|
+
llama_context * ctx,
|
2611
|
+
llama_seq_id seq_id,
|
2612
|
+
llama_pos p0,
|
2613
|
+
llama_pos p1,
|
2614
|
+
int d) {
|
2615
|
+
auto * kv = ctx->get_kv_self();
|
2616
|
+
if (!kv) {
|
2617
|
+
return;
|
2618
|
+
}
|
2619
|
+
|
2620
|
+
return kv->seq_div(seq_id, p0, p1, d);
|
2621
|
+
}
|
2622
|
+
|
2623
|
+
// deprecated
|
2624
|
+
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
2625
|
+
return llama_kv_self_seq_pos_max(ctx, seq_id);
|
2626
|
+
}
|
2627
|
+
|
2628
|
+
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
2629
|
+
const auto * kv = ctx->get_kv_self();
|
2630
|
+
if (!kv) {
|
2631
|
+
return 0;
|
2632
|
+
}
|
2633
|
+
|
2634
|
+
return kv->seq_pos_max(seq_id);
|
2635
|
+
}
|
2636
|
+
|
2637
|
+
// deprecated
|
2638
|
+
void llama_kv_cache_defrag(llama_context * ctx) {
|
2639
|
+
return llama_kv_self_defrag(ctx);
|
2640
|
+
}
|
2641
|
+
|
2642
|
+
void llama_kv_self_defrag(llama_context * ctx) {
|
2643
|
+
auto * kv = ctx->get_kv_self();
|
2644
|
+
if (!kv) {
|
2645
|
+
return;
|
2646
|
+
}
|
2647
|
+
|
2648
|
+
return kv->defrag();
|
2649
|
+
}
|
2650
|
+
|
2651
|
+
// deprecated
|
2652
|
+
bool llama_kv_cache_can_shift(const llama_context * ctx) {
|
2653
|
+
return llama_kv_self_can_shift(ctx);
|
2654
|
+
}
|
2655
|
+
|
2656
|
+
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
2657
|
+
const auto * kv = ctx->get_kv_self();
|
2658
|
+
if (!kv) {
|
2659
|
+
return false;
|
2660
|
+
}
|
2661
|
+
|
2662
|
+
return kv->get_can_shift();
|
2663
|
+
}
|
2664
|
+
|
2665
|
+
// deprecated
|
2666
|
+
void llama_kv_cache_update(llama_context * ctx) {
|
2667
|
+
llama_kv_self_update(ctx);
|
2668
|
+
}
|
2669
|
+
|
2670
|
+
// llama state API
|
2671
|
+
|
2672
|
+
// deprecated
|
2673
|
+
size_t llama_get_state_size(llama_context * ctx) {
|
2674
|
+
return llama_state_get_size(ctx);
|
2675
|
+
}
|
2676
|
+
|
2677
|
+
// deprecated
|
2678
|
+
size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) {
|
2679
|
+
return llama_state_get_data(ctx, dst, -1);
|
2680
|
+
}
|
2681
|
+
|
2682
|
+
// deprecated
|
2683
|
+
size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) {
|
2684
|
+
return llama_state_set_data(ctx, src, -1);
|
2685
|
+
}
|
2686
|
+
|
2687
|
+
// deprecated
|
2688
|
+
bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
2689
|
+
return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
2690
|
+
}
|
2691
|
+
|
2692
|
+
// deprecated
|
2693
|
+
bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
2694
|
+
return llama_state_save_file(ctx, path_session, tokens, n_token_count);
|
2695
|
+
}
|
2696
|
+
|
2697
|
+
// Returns the *actual* size of the state.
|
2698
|
+
// Intended to be used when saving to state to a buffer.
|
2699
|
+
size_t llama_state_get_size(llama_context * ctx) {
|
2700
|
+
return ctx->state_get_size();
|
2701
|
+
}
|
2702
|
+
|
2703
|
+
size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) {
|
2704
|
+
ctx->synchronize();
|
2705
|
+
|
2706
|
+
return ctx->state_get_data(dst, size);
|
2707
|
+
}
|
2708
|
+
|
2709
|
+
// Sets the state reading from the specified source address
|
2710
|
+
size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) {
|
2711
|
+
ctx->synchronize();
|
2712
|
+
|
2713
|
+
return ctx->state_set_data(src, size);
|
2714
|
+
}
|
2715
|
+
|
2716
|
+
bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
2717
|
+
ctx->synchronize();
|
2718
|
+
|
2719
|
+
try {
|
2720
|
+
return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
|
2721
|
+
} catch (const std::exception & err) {
|
2722
|
+
LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
|
2723
|
+
return false;
|
2724
|
+
}
|
2725
|
+
}
|
2726
|
+
|
2727
|
+
bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
2728
|
+
ctx->synchronize();
|
2729
|
+
|
2730
|
+
try {
|
2731
|
+
return ctx->state_save_file(path_session, tokens, n_token_count);
|
2732
|
+
} catch (const std::exception & err) {
|
2733
|
+
LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
|
2734
|
+
return false;
|
2735
|
+
}
|
2736
|
+
}
|
2737
|
+
|
2738
|
+
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
|
2739
|
+
return ctx->state_seq_get_size(seq_id);
|
2740
|
+
}
|
2741
|
+
|
2742
|
+
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
2743
|
+
ctx->synchronize();
|
2744
|
+
|
2745
|
+
return ctx->state_seq_get_data(seq_id, dst, size);
|
2746
|
+
}
|
2747
|
+
|
2748
|
+
size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
|
2749
|
+
ctx->synchronize();
|
2750
|
+
|
2751
|
+
return ctx->state_seq_set_data(seq_id, src, size);
|
2752
|
+
}
|
2753
|
+
|
2754
|
+
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
2755
|
+
ctx->synchronize();
|
2756
|
+
|
2757
|
+
try {
|
2758
|
+
return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
|
1756
2759
|
} catch (const std::exception & err) {
|
1757
2760
|
LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
|
1758
2761
|
return 0;
|
1759
2762
|
}
|
1760
2763
|
}
|
1761
2764
|
|
1762
|
-
size_t llama_state_seq_load_file(
|
2765
|
+
size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
2766
|
+
ctx->synchronize();
|
2767
|
+
|
1763
2768
|
try {
|
1764
|
-
return
|
2769
|
+
return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
|
1765
2770
|
} catch (const std::exception & err) {
|
1766
2771
|
LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
|
1767
2772
|
return 0;
|
1768
2773
|
}
|
1769
2774
|
}
|
1770
2775
|
|
1771
|
-
|
1772
|
-
|
1773
|
-
|
1774
|
-
|
2776
|
+
///
|
2777
|
+
|
2778
|
+
int32_t llama_encode(
|
2779
|
+
llama_context * ctx,
|
2780
|
+
llama_batch batch) {
|
2781
|
+
const int ret = ctx->encode(batch);
|
2782
|
+
if (ret != 0) {
|
2783
|
+
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
2784
|
+
}
|
2785
|
+
|
2786
|
+
return ret;
|
2787
|
+
}
|
2788
|
+
|
2789
|
+
int32_t llama_decode(
|
2790
|
+
llama_context * ctx,
|
2791
|
+
llama_batch batch) {
|
2792
|
+
const int ret = ctx->decode(batch);
|
2793
|
+
if (ret != 0) {
|
2794
|
+
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
2795
|
+
}
|
2796
|
+
|
2797
|
+
return ret;
|
2798
|
+
}
|
2799
|
+
|
2800
|
+
//
|
2801
|
+
// perf
|
2802
|
+
//
|
2803
|
+
|
2804
|
+
llama_perf_context_data llama_perf_context(const llama_context * ctx) {
|
2805
|
+
llama_perf_context_data data = {};
|
2806
|
+
|
2807
|
+
if (ctx == nullptr) {
|
2808
|
+
return data;
|
2809
|
+
}
|
2810
|
+
|
2811
|
+
data = ctx->perf_get_data();
|
2812
|
+
|
2813
|
+
return data;
|
2814
|
+
}
|
2815
|
+
|
2816
|
+
void llama_perf_context_print(const llama_context * ctx) {
|
2817
|
+
const auto data = llama_perf_context(ctx);
|
2818
|
+
|
2819
|
+
const double t_end_ms = 1e-3 * lm_ggml_time_us();
|
2820
|
+
|
2821
|
+
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
|
2822
|
+
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
2823
|
+
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
|
2824
|
+
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
2825
|
+
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
2826
|
+
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
2827
|
+
}
|
2828
|
+
|
2829
|
+
void llama_perf_context_reset(llama_context * ctx) {
|
2830
|
+
ctx->perf_reset();
|
1775
2831
|
}
|