cui-llama.rn 1.6.0 → 1.7.0
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- package/README.md +35 -7
- package/android/src/main/CMakeLists.txt +22 -11
- package/android/src/main/java/com/rnllama/LlamaContext.java +42 -6
- package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
- package/android/src/main/jni.cpp +173 -18
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/LICENSE +21 -0
- package/cpp/chat.cpp +129 -107
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +58 -78
- package/cpp/common.h +29 -21
- package/cpp/ggml-alloc.c +4 -1
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpp.h +1 -1
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -103
- package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +306 -6
- package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +114 -55
- package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +32 -16
- package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +353 -173
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
- package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
- package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
- package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -6
- package/{ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/vec.h +16 -0
- package/cpp/ggml-cpu.h +5 -0
- package/cpp/ggml-impl.h +16 -9
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +810 -176
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +227 -282
- package/cpp/ggml.h +82 -101
- package/cpp/gguf.cpp +33 -33
- package/cpp/json-schema-to-grammar.cpp +3 -0
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +49 -17
- package/cpp/llama-arch.h +9 -0
- package/cpp/llama-batch.cpp +8 -2
- package/cpp/llama-batch.h +2 -1
- package/cpp/llama-chat.cpp +39 -16
- package/cpp/llama-chat.h +4 -2
- package/cpp/llama-context.cpp +440 -611
- package/cpp/llama-context.h +44 -33
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +214 -291
- package/cpp/llama-graph.h +69 -21
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +39 -5
- package/cpp/llama-kv-cache.cpp +2067 -620
- package/cpp/llama-kv-cache.h +410 -108
- package/cpp/llama-memory.h +12 -1
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +1089 -359
- package/cpp/llama-model.h +19 -3
- package/cpp/llama-sampling.cpp +20 -7
- package/cpp/llama-vocab.cpp +54 -9
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +86 -142
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +602 -190
- package/cpp/rn-llama.h +34 -8
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +20 -10
- package/ios/RNLlama.h +6 -0
- package/ios/RNLlama.mm +82 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +131 -38
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +54 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +72 -4
- package/src/index.ts +212 -38
- package/cpp/binary-ops.h +0 -16
- package/cpp/ops.h +0 -128
- package/cpp/simd-mappings.h +0 -888
- package/cpp/unary-ops.h +0 -28
- package/cpp/vec.h +0 -802
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
- /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
- /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
- /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
- /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
- /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
- /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
- /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
- /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
package/cpp/llama-context.cpp
CHANGED
@@ -6,7 +6,6 @@
|
|
6
6
|
#include "llama-model.h"
|
7
7
|
#include "llama-kv-cache.h"
|
8
8
|
|
9
|
-
#include <cassert>
|
10
9
|
#include <cstring>
|
11
10
|
#include <stdexcept>
|
12
11
|
#include <cinttypes>
|
@@ -95,6 +94,8 @@ llama_context::llama_context(
|
|
95
94
|
|
96
95
|
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
97
96
|
|
97
|
+
cparams.op_offload = params.op_offload;
|
98
|
+
|
98
99
|
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
99
100
|
|
100
101
|
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
@@ -113,12 +114,10 @@ llama_context::llama_context(
|
|
113
114
|
}
|
114
115
|
|
115
116
|
if (n_ctx_per_seq > hparams.n_ctx_train) {
|
116
|
-
LLAMA_LOG_WARN("%s:
|
117
|
+
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
117
118
|
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
118
119
|
}
|
119
120
|
|
120
|
-
logits_all = params.logits_all;
|
121
|
-
|
122
121
|
if (!hparams.vocab_only) {
|
123
122
|
// GPU backends
|
124
123
|
for (auto * dev : model.devices) {
|
@@ -176,44 +175,14 @@ llama_context::llama_context(
|
|
176
175
|
}
|
177
176
|
|
178
177
|
// init the memory module
|
179
|
-
// TODO: for now, always create a unified KV cache
|
180
178
|
if (!hparams.vocab_only) {
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
188
|
-
|
189
|
-
uint32_t kv_size = cparams.n_ctx;
|
190
|
-
lm_ggml_type type_k = params.type_k;
|
191
|
-
lm_ggml_type type_v = params.type_v;
|
192
|
-
|
193
|
-
if (llama_model_is_recurrent(&model)) {
|
194
|
-
// Mamba needs at least as many KV cells as there are sequences kept at any time
|
195
|
-
kv_size = std::max((uint32_t) 1, params.n_seq_max);
|
196
|
-
// it's probably best to keep as much precision as possible for the states
|
197
|
-
type_k = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_conv for Mamba's conv_states
|
198
|
-
type_v = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_scan for Mamba's ssm_states
|
199
|
-
}
|
200
|
-
|
201
|
-
LM_GGML_ASSERT(hparams.n_embd_head_k % lm_ggml_blck_size(type_k) == 0);
|
202
|
-
LM_GGML_ASSERT(hparams.n_embd_head_v % lm_ggml_blck_size(type_v) == 0);
|
203
|
-
|
204
|
-
if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
205
|
-
throw std::runtime_error("failed to initialize self-attention cache");
|
206
|
-
}
|
207
|
-
|
208
|
-
{
|
209
|
-
const size_t memory_size_k = kv_self->size_k_bytes();
|
210
|
-
const size_t memory_size_v = kv_self->size_v_bytes();
|
179
|
+
llama_memory_params params_mem = {
|
180
|
+
/*.type_k =*/ params.type_k,
|
181
|
+
/*.type_v =*/ params.type_v,
|
182
|
+
/*.swa_full =*/ params.swa_full,
|
183
|
+
};
|
211
184
|
|
212
|
-
|
213
|
-
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
214
|
-
lm_ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
215
|
-
lm_ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
216
|
-
}
|
185
|
+
memory.reset(model.create_memory(params_mem, cparams));
|
217
186
|
}
|
218
187
|
|
219
188
|
// init backends
|
@@ -277,7 +246,7 @@ llama_context::llama_context(
|
|
277
246
|
}
|
278
247
|
}
|
279
248
|
|
280
|
-
sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
|
249
|
+
sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
|
281
250
|
|
282
251
|
if (pipeline_parallel) {
|
283
252
|
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, lm_ggml_backend_sched_get_n_copies(sched.get()));
|
@@ -285,7 +254,7 @@ llama_context::llama_context(
|
|
285
254
|
}
|
286
255
|
|
287
256
|
// reserve worst-case graph
|
288
|
-
if (!hparams.vocab_only) {
|
257
|
+
if (!hparams.vocab_only && memory) {
|
289
258
|
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
290
259
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
291
260
|
|
@@ -304,7 +273,9 @@ llama_context::llama_context(
|
|
304
273
|
int n_nodes_tg = -1;
|
305
274
|
|
306
275
|
// simulate full KV cache
|
307
|
-
kv_self
|
276
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
277
|
+
|
278
|
+
kv_self->set_full();
|
308
279
|
|
309
280
|
cross.v_embd.clear();
|
310
281
|
|
@@ -390,7 +361,9 @@ llama_context::llama_context(
|
|
390
361
|
}
|
391
362
|
}
|
392
363
|
|
393
|
-
llama_context::~llama_context()
|
364
|
+
llama_context::~llama_context() {
|
365
|
+
lm_ggml_opt_free(opt_ctx);
|
366
|
+
}
|
394
367
|
|
395
368
|
void llama_context::synchronize() {
|
396
369
|
lm_ggml_backend_sched_synchronize(sched.get());
|
@@ -426,6 +399,18 @@ const llama_model & llama_context::get_model() const {
|
|
426
399
|
return model;
|
427
400
|
}
|
428
401
|
|
402
|
+
const llama_cparams & llama_context::get_cparams() const {
|
403
|
+
return cparams;
|
404
|
+
}
|
405
|
+
|
406
|
+
lm_ggml_backend_sched_t llama_context::get_sched() const {
|
407
|
+
return sched.get();
|
408
|
+
}
|
409
|
+
|
410
|
+
lm_ggml_context * llama_context::get_ctx_compute() const {
|
411
|
+
return ctx_compute.get();
|
412
|
+
}
|
413
|
+
|
429
414
|
uint32_t llama_context::n_ctx() const {
|
430
415
|
return cparams.n_ctx;
|
431
416
|
}
|
@@ -455,345 +440,21 @@ uint32_t llama_context::n_threads_batch() const {
|
|
455
440
|
}
|
456
441
|
|
457
442
|
llama_kv_cache * llama_context::get_kv_self() {
|
458
|
-
|
443
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
444
|
+
return kv_self;
|
459
445
|
}
|
460
446
|
|
461
447
|
const llama_kv_cache * llama_context::get_kv_self() const {
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
lm_ggml_tensor * llama_context::build_rope_shift(
|
466
|
-
lm_ggml_context * ctx0,
|
467
|
-
lm_ggml_tensor * cur,
|
468
|
-
lm_ggml_tensor * shift,
|
469
|
-
lm_ggml_tensor * factors,
|
470
|
-
float freq_base,
|
471
|
-
float freq_scale,
|
472
|
-
lm_ggml_backend_buffer * bbuf) const {
|
473
|
-
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
474
|
-
|
475
|
-
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
476
|
-
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
477
|
-
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
478
|
-
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
479
|
-
|
480
|
-
const auto & hparams = model.hparams;
|
481
|
-
|
482
|
-
const auto & n_rot = hparams.n_rot;
|
483
|
-
const auto & rope_type = hparams.rope_type;
|
484
|
-
|
485
|
-
lm_ggml_tensor * tmp;
|
486
|
-
|
487
|
-
if (lm_ggml_is_quantized(cur->type)) {
|
488
|
-
// dequantize to f32 -> RoPE -> quantize back
|
489
|
-
tmp = lm_ggml_cast(ctx0, cur, LM_GGML_TYPE_F32);
|
490
|
-
|
491
|
-
if (bbuf) {
|
492
|
-
for (const auto & backend : backends) {
|
493
|
-
// Figure out which backend KV cache belongs to
|
494
|
-
if (lm_ggml_backend_supports_buft(backend.get(), lm_ggml_backend_buffer_get_type(bbuf))) {
|
495
|
-
lm_ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
|
496
|
-
break;
|
497
|
-
}
|
498
|
-
}
|
499
|
-
}
|
500
|
-
|
501
|
-
tmp = lm_ggml_rope_ext_inplace(ctx0, tmp,
|
502
|
-
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
503
|
-
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
504
|
-
|
505
|
-
tmp = lm_ggml_cpy(ctx0, tmp, cur);
|
506
|
-
} else {
|
507
|
-
// we rotate only the first n_rot dimensions
|
508
|
-
tmp = lm_ggml_rope_ext_inplace(ctx0, cur,
|
509
|
-
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
510
|
-
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
511
|
-
}
|
512
|
-
|
513
|
-
return tmp;
|
514
|
-
}
|
515
|
-
|
516
|
-
class llm_graph_input_k_shift : public llm_graph_input_i {
|
517
|
-
public:
|
518
|
-
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
519
|
-
virtual ~llm_graph_input_k_shift() = default;
|
520
|
-
|
521
|
-
void set_input(const llama_ubatch * ubatch) override;
|
522
|
-
|
523
|
-
lm_ggml_tensor * k_shift; // I32 [kv_size]
|
524
|
-
|
525
|
-
const llama_kv_cache_unified * kv_self;
|
526
|
-
};
|
527
|
-
|
528
|
-
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
529
|
-
LM_GGML_UNUSED(ubatch);
|
530
|
-
|
531
|
-
if (k_shift) {
|
532
|
-
assert(lm_ggml_backend_buffer_is_host(k_shift->buffer));
|
533
|
-
|
534
|
-
int32_t * data = (int32_t *) k_shift->data;
|
535
|
-
|
536
|
-
for (uint32_t i = 0; i < kv_self->size; ++i) {
|
537
|
-
data[i] = kv_self->cells[i].delta;
|
538
|
-
}
|
539
|
-
}
|
540
|
-
}
|
541
|
-
|
542
|
-
llm_graph_result_ptr llama_context::build_kv_self_shift(
|
543
|
-
lm_ggml_context * ctx0,
|
544
|
-
lm_ggml_cgraph * gf) const {
|
545
|
-
auto res = std::make_unique<llm_graph_result>();
|
546
|
-
|
547
|
-
const auto & hparams = model.hparams;
|
548
|
-
|
549
|
-
const auto & n_layer = hparams.n_layer;
|
550
|
-
|
551
|
-
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
552
|
-
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
553
|
-
|
554
|
-
//LM_GGML_ASSERT(kv_self->size == n_ctx);
|
555
|
-
|
556
|
-
auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
|
557
|
-
|
558
|
-
inp->k_shift = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, cparams.n_ctx);
|
559
|
-
lm_ggml_set_input(inp->k_shift);
|
560
|
-
|
561
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
562
|
-
const int64_t n_head_kv = hparams.n_head_kv(il);
|
563
|
-
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
564
|
-
|
565
|
-
const bool is_swa = hparams.is_swa(il);
|
566
|
-
|
567
|
-
// note: the swa rope params could become part of the cparams in the future
|
568
|
-
// if we decide to make them configurable, like the non-sliding ones
|
569
|
-
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
|
570
|
-
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
|
571
|
-
|
572
|
-
lm_ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
|
573
|
-
|
574
|
-
lm_ggml_tensor * k =
|
575
|
-
lm_ggml_view_3d(ctx0, kv_self->k_l[il],
|
576
|
-
n_embd_head_k, n_head_kv, kv_self->size,
|
577
|
-
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
578
|
-
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
579
|
-
0);
|
580
|
-
|
581
|
-
lm_ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
|
582
|
-
|
583
|
-
lm_ggml_build_forward_expand(gf, cur);
|
584
|
-
}
|
585
|
-
|
586
|
-
res->add_input(std::move(inp));
|
587
|
-
|
588
|
-
return res;
|
589
|
-
}
|
590
|
-
|
591
|
-
llm_graph_result_ptr llama_context::build_kv_self_defrag(
|
592
|
-
lm_ggml_context * ctx0,
|
593
|
-
lm_ggml_cgraph * gf) const {
|
594
|
-
auto res = std::make_unique<llm_graph_result>();
|
595
|
-
|
596
|
-
const auto & hparams = model.hparams;
|
597
|
-
|
598
|
-
const auto & ids = kv_self->defrag_info.ids;
|
599
|
-
|
600
|
-
#if 0
|
601
|
-
// CPU defrag
|
602
|
-
//
|
603
|
-
// TODO: optimizations are possible:
|
604
|
-
// - multiple threads
|
605
|
-
// - avoid copying to the host memory when already there
|
606
|
-
//
|
607
|
-
// likely not worth the effort, as we have lm_ggml_graph based defrag
|
608
|
-
//
|
609
|
-
|
610
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
611
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
612
|
-
|
613
|
-
const uint32_t kv_size = size;
|
614
|
-
|
615
|
-
std::vector<uint8_t> buf_k;
|
616
|
-
std::vector<uint8_t> buf_v;
|
617
|
-
|
618
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
619
|
-
const size_t k_size_row = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
620
|
-
const size_t k_size = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
|
621
|
-
|
622
|
-
const size_t v_size_el = lm_ggml_type_size(v_l[il]->type);
|
623
|
-
const size_t v_size = lm_ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
|
624
|
-
|
625
|
-
buf_k.resize(k_size);
|
626
|
-
buf_v.resize(v_size);
|
627
|
-
|
628
|
-
lm_ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
|
629
|
-
lm_ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
|
630
|
-
|
631
|
-
// batch move [i, i+nm) to [id, id+nm)
|
632
|
-
// note: cells can move only to a lower index
|
633
|
-
for (uint32_t i = 0; i < n_kv; ++i) {
|
634
|
-
const uint32_t id = ids[i];
|
635
|
-
|
636
|
-
if (i == id || id == n_kv) {
|
637
|
-
continue;
|
638
|
-
}
|
639
|
-
|
640
|
-
uint32_t nm = 1;
|
641
|
-
|
642
|
-
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
643
|
-
nm++;
|
644
|
-
}
|
645
|
-
|
646
|
-
// move keys
|
647
|
-
{
|
648
|
-
const int64_t os = i*k_size_row;
|
649
|
-
const int64_t od = id*k_size_row;
|
650
|
-
|
651
|
-
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
652
|
-
}
|
653
|
-
|
654
|
-
// move values (note: they are transposed)
|
655
|
-
{
|
656
|
-
const int64_t os = i;
|
657
|
-
const int64_t od = id;
|
658
|
-
|
659
|
-
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
660
|
-
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
661
|
-
}
|
662
|
-
}
|
663
|
-
|
664
|
-
i += nm - 1;
|
665
|
-
}
|
666
|
-
|
667
|
-
lm_ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
|
668
|
-
lm_ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
669
|
-
}
|
670
|
-
#else
|
671
|
-
for (uint32_t i = 0; i < ids.size(); ++i) {
|
672
|
-
const uint32_t id = ids[i];
|
673
|
-
|
674
|
-
if (i == id || id == ids.size()) {
|
675
|
-
continue;
|
676
|
-
}
|
677
|
-
|
678
|
-
uint32_t nm = 1;
|
679
|
-
|
680
|
-
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
681
|
-
nm++;
|
682
|
-
}
|
683
|
-
|
684
|
-
for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
|
685
|
-
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
686
|
-
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
687
|
-
|
688
|
-
lm_ggml_tensor * view_k_src = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
|
689
|
-
n_embd_k_gqa, nm,
|
690
|
-
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
691
|
-
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
|
692
|
-
|
693
|
-
lm_ggml_tensor * view_k_dst = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
|
694
|
-
n_embd_k_gqa, nm,
|
695
|
-
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
696
|
-
lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
|
697
|
-
|
698
|
-
lm_ggml_tensor * view_v_src;
|
699
|
-
lm_ggml_tensor * view_v_dst;
|
700
|
-
|
701
|
-
if (cparams.flash_attn) {
|
702
|
-
// NOTE: the V cache is not transposed when using flash attention
|
703
|
-
view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
|
704
|
-
n_embd_v_gqa, nm,
|
705
|
-
lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
706
|
-
lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
|
707
|
-
|
708
|
-
view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
|
709
|
-
n_embd_v_gqa, nm,
|
710
|
-
lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
711
|
-
lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
|
712
|
-
} else {
|
713
|
-
view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
|
714
|
-
nm, n_embd_v_gqa,
|
715
|
-
lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
716
|
-
lm_ggml_row_size(kv_self->v_l[il]->type, i));
|
717
|
-
|
718
|
-
view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
|
719
|
-
nm, n_embd_v_gqa,
|
720
|
-
lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
721
|
-
lm_ggml_row_size(kv_self->v_l[il]->type, id));
|
722
|
-
}
|
723
|
-
|
724
|
-
lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_k_src, view_k_dst));
|
725
|
-
lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_v_src, view_v_dst));
|
726
|
-
}
|
727
|
-
|
728
|
-
i += nm - 1;
|
729
|
-
}
|
730
|
-
|
731
|
-
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
732
|
-
#endif
|
733
|
-
|
734
|
-
return res;
|
448
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
449
|
+
return kv_self;
|
735
450
|
}
|
736
451
|
|
737
452
|
void llama_context::kv_self_update() {
|
738
|
-
auto & kv = kv_self;
|
739
|
-
|
740
453
|
bool need_reserve = false;
|
741
454
|
|
742
|
-
|
743
|
-
if (!kv->get_can_shift()) {
|
744
|
-
LM_GGML_ABORT("The current context does not support K-shift");
|
745
|
-
}
|
746
|
-
|
747
|
-
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
|
748
|
-
|
749
|
-
// apply K-shift if needed
|
750
|
-
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
751
|
-
lm_ggml_backend_sched_reset(sched.get());
|
752
|
-
|
753
|
-
auto * gf = graph_init();
|
754
|
-
|
755
|
-
auto res = build_kv_self_shift(ctx_compute.get(), gf);
|
756
|
-
|
757
|
-
lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
|
758
|
-
|
759
|
-
res->set_inputs(nullptr);
|
760
|
-
|
761
|
-
graph_compute(gf, false);
|
762
|
-
|
763
|
-
need_reserve = true;
|
764
|
-
}
|
765
|
-
|
766
|
-
{
|
767
|
-
kv->has_shift = false;
|
768
|
-
|
769
|
-
for (uint32_t i = 0; i < kv->size; ++i) {
|
770
|
-
kv->cells[i].delta = 0;
|
771
|
-
}
|
772
|
-
}
|
773
|
-
}
|
774
|
-
|
775
|
-
// defragment the KV cache if needed
|
776
|
-
if (kv->do_defrag) {
|
777
|
-
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
778
|
-
|
779
|
-
if (kv->defrag_prepare(graph_max_nodes())) {
|
780
|
-
lm_ggml_backend_sched_reset(sched.get());
|
781
|
-
|
782
|
-
auto * gf = graph_init();
|
783
|
-
|
784
|
-
auto res = build_kv_self_defrag(ctx_compute.get(), gf);
|
785
|
-
|
786
|
-
lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
|
455
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
787
456
|
|
788
|
-
|
789
|
-
|
790
|
-
graph_compute(gf, false);
|
791
|
-
|
792
|
-
need_reserve = true;
|
793
|
-
}
|
794
|
-
|
795
|
-
kv->do_defrag = false;
|
796
|
-
}
|
457
|
+
need_reserve = kv_self->update(*this);
|
797
458
|
|
798
459
|
// reserve a worst case graph if needed
|
799
460
|
if (need_reserve) {
|
@@ -804,7 +465,7 @@ void llama_context::kv_self_update() {
|
|
804
465
|
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
805
466
|
|
806
467
|
// simulate full KV cache
|
807
|
-
kv_self->
|
468
|
+
kv_self->set_full();
|
808
469
|
|
809
470
|
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
810
471
|
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
@@ -825,9 +486,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
|
|
825
486
|
}
|
826
487
|
|
827
488
|
float * llama_context::get_logits() {
|
828
|
-
// reorder logits for backward compatibility
|
829
|
-
output_reorder();
|
830
|
-
|
831
489
|
return logits;
|
832
490
|
}
|
833
491
|
|
@@ -870,9 +528,6 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
870
528
|
}
|
871
529
|
|
872
530
|
float * llama_context::get_embeddings() {
|
873
|
-
// reorder embeddings for backward compatibility
|
874
|
-
output_reorder();
|
875
|
-
|
876
531
|
return embd;
|
877
532
|
}
|
878
533
|
|
@@ -1024,8 +679,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
1024
679
|
}
|
1025
680
|
|
1026
681
|
// temporary allocate memory for the input batch if needed
|
1027
|
-
//
|
1028
|
-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 :
|
682
|
+
// note: during encode, we always pass the full sequence starting from pos = 0
|
683
|
+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
|
1029
684
|
|
1030
685
|
const llama_batch & batch = batch_allocr.batch;
|
1031
686
|
const int32_t n_tokens = batch.n_tokens;
|
@@ -1050,11 +705,13 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
1050
705
|
t_compute_start_us = lm_ggml_time_us();
|
1051
706
|
}
|
1052
707
|
|
708
|
+
embd_seq.clear();
|
709
|
+
|
1053
710
|
n_queued_tokens += n_tokens;
|
1054
711
|
|
1055
712
|
const int64_t n_embd = hparams.n_embd;
|
1056
713
|
|
1057
|
-
sbatch
|
714
|
+
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
1058
715
|
|
1059
716
|
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
1060
717
|
|
@@ -1111,12 +768,12 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
1111
768
|
lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
1112
769
|
LM_GGML_ASSERT(backend_embd != nullptr);
|
1113
770
|
|
1114
|
-
LM_GGML_ASSERT(embd != nullptr);
|
1115
|
-
|
1116
771
|
switch (cparams.pooling_type) {
|
1117
772
|
case LLAMA_POOLING_TYPE_NONE:
|
1118
773
|
{
|
1119
774
|
// extract token embeddings
|
775
|
+
LM_GGML_ASSERT(embd != nullptr);
|
776
|
+
|
1120
777
|
LM_GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
|
1121
778
|
lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
|
1122
779
|
} break;
|
@@ -1141,11 +798,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
1141
798
|
} break;
|
1142
799
|
case LLAMA_POOLING_TYPE_RANK:
|
1143
800
|
{
|
1144
|
-
//
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
801
|
+
// extract the rerank score - a single float per sequence
|
802
|
+
auto & embd_seq_out = embd_seq;
|
803
|
+
|
804
|
+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
805
|
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
806
|
+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
807
|
+
continue;
|
808
|
+
}
|
809
|
+
embd_seq_out[seq_id].resize(1);
|
810
|
+
lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
811
|
+
}
|
812
|
+
} break;
|
1149
813
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
1150
814
|
{
|
1151
815
|
LM_GGML_ABORT("unknown pooling type");
|
@@ -1183,14 +847,27 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
1183
847
|
}
|
1184
848
|
|
1185
849
|
int llama_context::decode(llama_batch & inp_batch) {
|
850
|
+
if (!memory) {
|
851
|
+
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
|
852
|
+
return encode(inp_batch);
|
853
|
+
}
|
854
|
+
|
1186
855
|
if (inp_batch.n_tokens == 0) {
|
1187
856
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
1188
857
|
return -1;
|
1189
858
|
}
|
1190
859
|
|
860
|
+
if (!inp_batch.pos) {
|
861
|
+
if (inp_batch.seq_id) {
|
862
|
+
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
|
863
|
+
return -1;
|
864
|
+
}
|
865
|
+
}
|
866
|
+
|
867
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
868
|
+
|
1191
869
|
// temporary allocate memory for the input batch if needed
|
1192
|
-
|
1193
|
-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
870
|
+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
|
1194
871
|
|
1195
872
|
const llama_batch & batch = batch_allocr.batch;
|
1196
873
|
|
@@ -1202,7 +879,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1202
879
|
const int64_t n_tokens_all = batch.n_tokens;
|
1203
880
|
const int64_t n_embd = hparams.n_embd;
|
1204
881
|
|
1205
|
-
llama_kv_cache_guard kv_guard(kv_self
|
882
|
+
llama_kv_cache_guard kv_guard(kv_self);
|
1206
883
|
|
1207
884
|
LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
1208
885
|
|
@@ -1236,18 +913,14 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1236
913
|
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
1237
914
|
n_outputs_all += batch.logits[i] != 0;
|
1238
915
|
}
|
1239
|
-
} else if (
|
916
|
+
} else if (embd_pooled) {
|
1240
917
|
n_outputs_all = n_tokens_all;
|
1241
918
|
} else {
|
1242
919
|
// keep last output only
|
1243
920
|
n_outputs_all = 1;
|
1244
921
|
}
|
1245
922
|
|
1246
|
-
|
1247
|
-
|
1248
|
-
sbatch.from_batch(batch, n_embd,
|
1249
|
-
/* simple_split */ !kv_self->recurrent,
|
1250
|
-
/* logits_all */ logits_all);
|
923
|
+
llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
|
1251
924
|
|
1252
925
|
// reserve output buffer
|
1253
926
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
@@ -1261,22 +934,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1261
934
|
int64_t n_outputs_prev = 0;
|
1262
935
|
|
1263
936
|
while (sbatch.n_tokens > 0) {
|
1264
|
-
llama_ubatch ubatch =
|
1265
|
-
|
1266
|
-
const auto & n_ubatch = cparams.n_ubatch;
|
1267
|
-
|
1268
|
-
if (kv_self->recurrent) {
|
1269
|
-
if (embd_pooled) {
|
1270
|
-
// Pooled embeddings cannot be split across ubatches (yet)
|
1271
|
-
ubatch = sbatch.split_seq(cparams.n_ubatch);
|
1272
|
-
} else {
|
1273
|
-
// recurrent model architectures are easier to implement
|
1274
|
-
// with equal-length sequences
|
1275
|
-
ubatch = sbatch.split_equal(cparams.n_ubatch);
|
1276
|
-
}
|
1277
|
-
} else {
|
1278
|
-
ubatch = sbatch.split_simple(n_ubatch);
|
1279
|
-
}
|
937
|
+
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
1280
938
|
|
1281
939
|
// count the outputs in this u_batch
|
1282
940
|
{
|
@@ -1296,24 +954,10 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1296
954
|
}
|
1297
955
|
|
1298
956
|
// find KV slot
|
1299
|
-
{
|
1300
|
-
|
1301
|
-
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
1302
|
-
|
1303
|
-
return 1;
|
1304
|
-
}
|
1305
|
-
|
1306
|
-
if (!kv_self->recurrent) {
|
1307
|
-
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
1308
|
-
// after enough generations, the benefit from this heuristic disappears
|
1309
|
-
// if we start defragmenting the cache, the benefit from this will be more important
|
1310
|
-
const uint32_t pad = kv_self->get_padding(cparams);
|
1311
|
-
kv_self->n = std::min(kv_self->size, std::max(pad, LM_GGML_PAD(kv_self->cell_max(), pad)));
|
1312
|
-
}
|
957
|
+
if (!kv_self->find_slot(ubatch)) {
|
958
|
+
return 1;
|
1313
959
|
}
|
1314
960
|
|
1315
|
-
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
|
1316
|
-
|
1317
961
|
lm_ggml_backend_sched_reset(sched.get());
|
1318
962
|
lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
1319
963
|
|
@@ -1427,43 +1071,68 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
1427
1071
|
// finalize the batch processing
|
1428
1072
|
kv_guard.commit();
|
1429
1073
|
|
1074
|
+
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
1075
|
+
n_outputs = n_outputs_all;
|
1076
|
+
|
1430
1077
|
// set output mappings
|
1431
1078
|
{
|
1432
1079
|
bool sorted_output = true;
|
1433
1080
|
|
1434
|
-
|
1081
|
+
auto & out_ids = sbatch.out_ids;
|
1082
|
+
|
1083
|
+
LM_GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
|
1435
1084
|
|
1436
1085
|
for (int64_t i = 0; i < n_outputs_all; ++i) {
|
1437
|
-
int64_t out_id =
|
1086
|
+
int64_t out_id = out_ids[i];
|
1438
1087
|
output_ids[out_id] = i;
|
1439
1088
|
if (out_id != i) {
|
1440
1089
|
sorted_output = false;
|
1441
1090
|
}
|
1442
1091
|
}
|
1443
1092
|
|
1444
|
-
|
1445
|
-
|
1093
|
+
// make the outputs have the same order they had in the user-provided batch
|
1094
|
+
// note: this is mostly relevant for recurrent models atm
|
1095
|
+
if (!sorted_output) {
|
1096
|
+
const uint32_t n_vocab = model.vocab.n_tokens();
|
1097
|
+
const uint32_t n_embd = model.hparams.n_embd;
|
1098
|
+
|
1099
|
+
LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
1100
|
+
|
1101
|
+
// TODO: is there something more efficient which also minimizes swaps?
|
1102
|
+
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
1103
|
+
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
1104
|
+
int32_t j_min = i;
|
1105
|
+
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
1106
|
+
if (out_ids[j] < out_ids[j_min]) {
|
1107
|
+
j_min = j;
|
1108
|
+
}
|
1109
|
+
}
|
1110
|
+
if (j_min == i) { continue; }
|
1111
|
+
std::swap(out_ids[i], out_ids[j_min]);
|
1112
|
+
if (logits_size > 0) {
|
1113
|
+
for (uint32_t k = 0; k < n_vocab; k++) {
|
1114
|
+
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
|
1115
|
+
}
|
1116
|
+
}
|
1117
|
+
if (embd_size > 0) {
|
1118
|
+
for (uint32_t k = 0; k < n_embd; k++) {
|
1119
|
+
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
1120
|
+
}
|
1121
|
+
}
|
1122
|
+
}
|
1123
|
+
std::fill(output_ids.begin(), output_ids.end(), -1);
|
1124
|
+
for (int32_t i = 0; i < n_outputs; ++i) {
|
1125
|
+
output_ids[out_ids[i]] = i;
|
1126
|
+
}
|
1446
1127
|
}
|
1447
1128
|
}
|
1448
1129
|
|
1449
|
-
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
1450
|
-
n_outputs = n_outputs_all;
|
1451
|
-
|
1452
1130
|
// wait for the computation to finish (automatically done when obtaining the model output)
|
1453
1131
|
//synchronize();
|
1454
1132
|
|
1455
1133
|
// decide if we need to defrag the kv cache
|
1456
|
-
if (cparams.
|
1457
|
-
|
1458
|
-
// - count the padding towards the number of used tokens
|
1459
|
-
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
|
1460
|
-
|
1461
|
-
// queue defragmentation for next llama_kv_cache_update
|
1462
|
-
if (fragmentation > cparams.defrag_thold) {
|
1463
|
-
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
1464
|
-
|
1465
|
-
kv_self->defrag();
|
1466
|
-
}
|
1134
|
+
if (cparams.defrag_thold > 0.0f) {
|
1135
|
+
kv_self->defrag_sched(cparams.defrag_thold);
|
1467
1136
|
}
|
1468
1137
|
|
1469
1138
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
@@ -1543,52 +1212,12 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
1543
1212
|
// set all ids as invalid (negative)
|
1544
1213
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
1545
1214
|
|
1546
|
-
lm_ggml_backend_buffer_clear(buf_output.get(), 0);
|
1547
|
-
|
1548
1215
|
this->n_outputs = 0;
|
1549
1216
|
this->n_outputs_max = n_outputs_max;
|
1550
1217
|
|
1551
1218
|
return n_outputs_max;
|
1552
1219
|
}
|
1553
1220
|
|
1554
|
-
void llama_context::output_reorder() {
|
1555
|
-
auto & out_ids = sbatch.out_ids;
|
1556
|
-
if (!out_ids.empty()) {
|
1557
|
-
const uint32_t n_vocab = model.vocab.n_tokens();
|
1558
|
-
const uint32_t n_embd = model.hparams.n_embd;
|
1559
|
-
|
1560
|
-
LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
1561
|
-
|
1562
|
-
// TODO: is there something more efficient which also minimizes swaps?
|
1563
|
-
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
1564
|
-
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
1565
|
-
int32_t j_min = i;
|
1566
|
-
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
1567
|
-
if (out_ids[j] < out_ids[j_min]) {
|
1568
|
-
j_min = j;
|
1569
|
-
}
|
1570
|
-
}
|
1571
|
-
if (j_min == i) { continue; }
|
1572
|
-
std::swap(out_ids[i], out_ids[j_min]);
|
1573
|
-
if (logits_size > 0) {
|
1574
|
-
for (uint32_t k = 0; k < n_vocab; k++) {
|
1575
|
-
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
|
1576
|
-
}
|
1577
|
-
}
|
1578
|
-
if (embd_size > 0) {
|
1579
|
-
for (uint32_t k = 0; k < n_embd; k++) {
|
1580
|
-
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
1581
|
-
}
|
1582
|
-
}
|
1583
|
-
}
|
1584
|
-
std::fill(output_ids.begin(), output_ids.end(), -1);
|
1585
|
-
for (int32_t i = 0; i < n_outputs; ++i) {
|
1586
|
-
output_ids[out_ids[i]] = i;
|
1587
|
-
}
|
1588
|
-
out_ids.clear();
|
1589
|
-
}
|
1590
|
-
}
|
1591
|
-
|
1592
1221
|
//
|
1593
1222
|
// graph
|
1594
1223
|
//
|
@@ -1625,7 +1254,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
|
1625
1254
|
/*.backend_cpu =*/ backend_cpu,
|
1626
1255
|
/*.cvec =*/ &cvec,
|
1627
1256
|
/*.loras =*/ &loras,
|
1628
|
-
/*.memory =*/
|
1257
|
+
/*.memory =*/ memory.get(),
|
1629
1258
|
/*.cross =*/ &cross,
|
1630
1259
|
/*.n_outputs =*/ n_outputs,
|
1631
1260
|
/*.cb =*/ graph_get_cb(),
|
@@ -2029,8 +1658,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
2029
1658
|
{
|
2030
1659
|
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
|
2031
1660
|
|
2032
|
-
output_reorder();
|
2033
|
-
|
2034
1661
|
const auto n_outputs = this->n_outputs;
|
2035
1662
|
const auto & output_ids = this->output_ids;
|
2036
1663
|
|
@@ -2083,8 +1710,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
2083
1710
|
}
|
2084
1711
|
}
|
2085
1712
|
|
2086
|
-
|
2087
|
-
|
1713
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
1714
|
+
|
1715
|
+
if (kv_self != nullptr) {
|
1716
|
+
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
1717
|
+
kv_self->state_write(io);
|
1718
|
+
}
|
2088
1719
|
|
2089
1720
|
return io.n_bytes();
|
2090
1721
|
}
|
@@ -2167,8 +1798,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
2167
1798
|
}
|
2168
1799
|
}
|
2169
1800
|
|
2170
|
-
|
2171
|
-
|
1801
|
+
if (memory) {
|
1802
|
+
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
1803
|
+
|
1804
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
1805
|
+
|
1806
|
+
kv_self->state_read(io);
|
1807
|
+
}
|
2172
1808
|
|
2173
1809
|
return io.n_bytes();
|
2174
1810
|
}
|
@@ -2176,7 +1812,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
2176
1812
|
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
2177
1813
|
LM_GGML_UNUSED(seq_id);
|
2178
1814
|
|
2179
|
-
|
1815
|
+
if (memory) {
|
1816
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
1817
|
+
|
1818
|
+
kv_self->state_write(io, seq_id);
|
1819
|
+
}
|
2180
1820
|
|
2181
1821
|
return io.n_bytes();
|
2182
1822
|
}
|
@@ -2184,7 +1824,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
|
|
2184
1824
|
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
2185
1825
|
LM_GGML_UNUSED(seq_id);
|
2186
1826
|
|
2187
|
-
|
1827
|
+
if (memory) {
|
1828
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
1829
|
+
|
1830
|
+
kv_self->state_read(io, seq_id);
|
1831
|
+
}
|
2188
1832
|
|
2189
1833
|
return io.n_bytes();
|
2190
1834
|
}
|
@@ -2212,6 +1856,215 @@ void llama_context::perf_reset() {
     t_p_eval_us = n_p_eval = 0;
 }
 
+//
+// training
+//
+
+static void llama_set_param(struct lm_ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+    if (!tensor || tensor->type != LM_GGML_TYPE_F32) {
+        return;
+    }
+    if (!param_filter(tensor, userdata)) {
+        return;
+    }
+    if (strcmp(tensor->name, "token_embd.weight") == 0) {
+        return; // FIXME
+    }
+    if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+        return; // FIXME
+    }
+    lm_ggml_set_param(tensor);
+}
+
+void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+    LM_GGML_ASSERT(!opt_ctx);
+    model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+    const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+    LM_GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
+    LM_GGML_ASSERT(n_batch % n_ubatch == 0);
+
+    lm_ggml_opt_params opt_params = lm_ggml_opt_default_params(sched.get(), LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+    opt_params.opt_period = n_batch / n_ubatch;
+    opt_params.get_opt_pars = lopt_params.get_opt_pars;
+    opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+    opt_ctx = lm_ggml_opt_init(opt_params);
+
+    llama_opt_param_filter param_filter = lopt_params.param_filter;
+    void * param_filter_ud = lopt_params.param_filter_ud;
+
+    //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
+    llama_set_param(model->type_embd, param_filter, param_filter_ud);
+    llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output, param_filter, param_filter_ud);
+    llama_set_param(model->output_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+    llama_set_param(model->cls, param_filter, param_filter_ud);
+    llama_set_param(model->cls_b, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+    for (struct llama_layer & layer : model->layers) {
+        for (size_t i = 0; i < sizeof(layer)/sizeof(struct lm_ggml_tensor *); ++i) {
+            llama_set_param(reinterpret_cast<struct lm_ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+        }
+    }
+}
+
+void llama_context::opt_epoch_iter(
+        lm_ggml_opt_dataset_t dataset,
+        lm_ggml_opt_result_t result,
+        const std::vector<llama_token> & tokens,
+        const std::vector<llama_token> & labels_sparse,
+        llama_batch & batch,
+        lm_ggml_opt_epoch_callback callback,
+        bool train,
+        int64_t idata_in_loop,
+        int64_t ndata_in_loop,
+        int64_t t_loop_start) {
+    LM_GGML_ASSERT(opt_ctx);
+    const uint32_t n_ctx = llama_model_n_ctx_train(&model);
+    const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+    kv_self->clear();
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+        batch.n_tokens = n_batch;
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+            batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
+            batch.pos [pos_batch] = pos_ctx + pos_batch;
+            batch.n_seq_id[pos_batch] = 1;
+            batch.seq_id [pos_batch][0] = 0;
+            batch.logits [pos_batch] = true;
+        }
+
+        const auto n_tokens_all = batch.n_tokens;
+
+        n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        embd_seq.clear();
+
+        int64_t n_outputs_all = n_tokens_all;
+
+        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+        // reserve output buffer
+        if (output_reserve(n_outputs_all) < n_outputs_all) {
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+            LM_GGML_ABORT("TODO: handle this error");
+        };
+
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+            n_outputs = ubatch.n_tokens;
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                LM_GGML_ABORT("TODO: handle this error");
+            }
+
+            auto * gf = graph_init();
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+            struct lm_ggml_context * ctx_compute_opt;
+            {
+                const size_t size_gf = lm_ggml_graph_size(gf);
+                const size_t size_meta = 4*size_gf*lm_ggml_tensor_overhead() + 2*lm_ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+                struct lm_ggml_init_params params = {
+                    /*.mem_size =*/ size_meta,
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc =*/ true,
+                };
+                ctx_compute_opt = lm_ggml_init(params);
+            }
+            lm_ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+            lm_ggml_opt_alloc(opt_ctx, train);
+            res->set_inputs(&ubatch);
+            {
+                struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
+                LM_GGML_ASSERT(labels->ne[1] == n_ubatch);
+                lm_ggml_set_zero(labels);
+                const float onef = 1.0f;
+                for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+                    const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+                    LM_GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+                    lm_ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+                }
+            }
+            lm_ggml_opt_eval(opt_ctx, result);
+            if (callback) {
+                callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+            }
+            lm_ggml_free(ctx_compute_opt);
+        }
+    }
+
+    kv_guard.commit();
+}
+
+void llama_context::opt_epoch(
+        lm_ggml_opt_dataset_t dataset,
+        lm_ggml_opt_result_t result_train,
+        lm_ggml_opt_result_t result_eval,
+        int64_t idata_split,
+        lm_ggml_opt_epoch_callback callback_train,
+        lm_ggml_opt_epoch_callback callback_eval) {
+    const uint32_t n_ctx = this->n_ctx();
+    const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
+    const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+    const int64_t ndata = lm_ggml_opt_dataset_ndata(dataset);
+
+    LM_GGML_ASSERT(idata_split >= 0);
+    LM_GGML_ASSERT(idata_split <= ndata);
+
+    const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    std::vector<llama_token> tokens(n_ctx);
+    std::vector<llama_token> labels_sparse(n_ctx);
+
+    int64_t idata = 0;
+
+    int64_t t_loop_start = lm_ggml_time_us();
+    int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+    for (; idata < idata_split; ++idata) {
+        constexpr bool train = true;
+        const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+        lm_ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+            callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    t_loop_start = lm_ggml_time_us();
+    ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+    for (; idata < ndata; ++idata) {
+        constexpr bool train = false;
+        const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+        lm_ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+            callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    llama_batch_free(batch);
+}
+
 //
 // interface implementation
 //
@@ -2239,13 +2092,14 @@ llama_context_params llama_context_default_params() {
         /*.cb_eval_user_data =*/ nullptr,
         /*.type_k =*/ LM_GGML_TYPE_F16,
         /*.type_v =*/ LM_GGML_TYPE_F16,
-        /*.
+        /*.abort_callback =*/ nullptr,
+        /*.abort_callback_data =*/ nullptr,
         /*.embeddings =*/ false,
        /*.offload_kqv =*/ true,
         /*.flash_attn =*/ false,
         /*.no_perf =*/ true,
-        /*.
-        /*.
+        /*.op_offload =*/ true,
+        /*.swa_full =*/ true,
     };
 
     return result;
@@ -2440,65 +2294,51 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return {};
-    }
-
-    return llama_kv_cache_view_init(*kv, n_seq_max);
-}
-
-void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return;
-    }
-
-    llama_kv_cache_view_update(view, kv);
-}
-
 //
 // kv cache
 //
 
 // deprecated
-int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
-    return llama_kv_self_n_tokens(ctx);
-}
-
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-
-}
+    int32_t res = 0;
 
-
-
-
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
+// deprecated
+// note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-
-}
+    int32_t res = 0;
 
-
-
-
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
@@ -2510,15 +2350,6 @@ void llama_kv_self_clear(llama_context * ctx) {
     kv->clear();
 }
 
-// deprecated
-bool llama_kv_cache_seq_rm(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1) {
-    return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
-}
-
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
         llama_seq_id seq_id,
@@ -2532,16 +2363,6 @@ bool llama_kv_self_seq_rm(
     return kv->seq_rm(seq_id, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_cp(
-        llama_context * ctx,
-        llama_seq_id seq_id_src,
-        llama_seq_id seq_id_dst,
-        llama_pos p0,
-        llama_pos p1) {
-    return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
-}
-
 void llama_kv_self_seq_cp(
         llama_context * ctx,
         llama_seq_id seq_id_src,
@@ -2553,14 +2374,7 @@ void llama_kv_self_seq_cp(
         return;
     }
 
-
-}
-
-// deprecated
-void llama_kv_cache_seq_keep(
-        llama_context * ctx,
-        llama_seq_id seq_id) {
-    return llama_kv_self_seq_keep(ctx, seq_id);
+    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@@ -2569,17 +2383,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
         return;
     }
 
-
-}
-
-// deprecated
-void llama_kv_cache_seq_add(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        llama_pos delta) {
-    return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
+    kv->seq_keep(seq_id);
 }
 
 void llama_kv_self_seq_add(
@@ -2593,17 +2397,7 @@ void llama_kv_self_seq_add(
         return;
     }
 
-
-}
-
-// deprecated
-void llama_kv_cache_seq_div(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        int d) {
-    return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
+    kv->seq_add(seq_id, p0, p1, delta);
 }
 
 void llama_kv_self_seq_div(
@@ -2617,40 +2411,35 @@ void llama_kv_self_seq_div(
         return;
     }
 
-
+    kv->seq_div(seq_id, p0, p1, d);
 }
 
-
-
-
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return
+        return -1;
     }
 
     return kv->seq_pos_max(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_defrag(llama_context * ctx) {
-    return llama_kv_self_defrag(ctx);
-}
-
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
         return;
    }
 
-
-
-
-// deprecated
-bool llama_kv_cache_can_shift(const llama_context * ctx) {
-    return llama_kv_self_can_shift(ctx);
+    // force defrag
+    kv->defrag_sched(-1.0f);
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
@@ -2662,11 +2451,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
     return kv->get_can_shift();
 }
 
-// deprecated
-void llama_kv_cache_update(llama_context * ctx) {
-    llama_kv_self_update(ctx);
-}
-
 // llama state API
 
 // deprecated
@@ -2789,7 +2573,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
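The `llama_decode` hunk above makes the wrapper retry once after a forced defrag when the context reports 1 (no free KV slot). Below is a sketch of how a caller might treat the return codes after this change; only `llama_decode` and its 0 / 1 / negative convention are assumed, and the helper name is illustrative:

```cpp
#include "llama.h"

// Illustrative only: with the retry above, a return value of 1 already implies
// that one defrag + retry happened inside llama_decode, so the KV cache is
// effectively full and the caller should free sequences or shrink the batch.
static bool decode_checked(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;   // success
    }
    if (ret == 1) {
        return false;  // KV cache full even after the internal defrag retry
    }
    return false;      // ret < 0: hard error (invalid batch, allocation failure, ...)
}
```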
@@ -2829,3 +2627,34 @@ void llama_perf_context_print(const llama_context * ctx) {
 void llama_perf_context_reset(llama_context * ctx) {
     ctx->perf_reset();
 }
+
+//
+// training
+//
+
+bool llama_opt_param_filter_all(const struct lm_ggml_tensor * tensor, void * userdata) {
+    LM_GGML_UNUSED(tensor);
+    LM_GGML_UNUSED(userdata);
+    return true;
+}
+
+void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+    ctx->opt_init(model, lopt_params);
+}
+
+void llama_opt_epoch(
+        struct llama_context * ctx,
+        lm_ggml_opt_dataset_t dataset,
+        lm_ggml_opt_result_t result_train,
+        lm_ggml_opt_result_t result_eval,
+        int64_t idata_split,
+        lm_ggml_opt_epoch_callback callback_train,
+        lm_ggml_opt_epoch_callback callback_eval) {
+    ctx->opt_epoch(
+        dataset,
+        result_train,
+        result_eval,
+        idata_split,
+        callback_train,
+        callback_eval);
+}
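The final hunk exposes the new training entry points (`llama_opt_param_filter_all`, `llama_opt_init`, `llama_opt_epoch`) at the C API level. Below is a minimal sketch of driving one epoch from C++, assuming a dataset already built with the `lm_ggml_opt_dataset` API and that the `lm_`-prefixed ggml-opt helpers (`lm_ggml_opt_get_default_optimizer_params`, `lm_ggml_opt_result_init`/`free`, `lm_ggml_opt_epoch_callback_progress_bar`) are available in this package as in upstream ggml; the helper name `run_finetune_epoch` is illustrative:

```cpp
#include "llama.h"
#include "ggml-opt.h"

// Illustrative only: run one training/eval epoch over a pre-built dataset.
static void run_finetune_epoch(llama_context * ctx, llama_model * model,
                               lm_ggml_opt_dataset_t dataset, int64_t idata_split) {
    llama_opt_params lopt_params = {};
    lopt_params.n_ctx_train     = 0;                           // 0 -> use the context size
    lopt_params.param_filter    = llama_opt_param_filter_all;  // accept all trainable tensors
    lopt_params.param_filter_ud = nullptr;
    lopt_params.get_opt_pars    = lm_ggml_opt_get_default_optimizer_params;
    lopt_params.get_opt_pars_ud = nullptr;

    llama_opt_init(ctx, model, lopt_params);

    lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init();
    lm_ggml_opt_result_t result_eval  = lm_ggml_opt_result_init();

    // data chunks before idata_split are used for training, the rest for evaluation
    llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                    lm_ggml_opt_epoch_callback_progress_bar, /*callback_eval =*/ nullptr);

    lm_ggml_opt_result_free(result_train);
    lm_ggml_opt_result_free(result_eval);
}
```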