cui-llama.rn 1.5.0 → 1.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +20 -20
- package/README.md +317 -319
- package/android/build.gradle +116 -116
- package/android/gradle.properties +5 -5
- package/android/src/main/AndroidManifest.xml +4 -4
- package/android/src/main/CMakeLists.txt +124 -124
- package/android/src/main/java/com/rnllama/LlamaContext.java +645 -645
- package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
- package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
- package/android/src/main/jni-utils.h +100 -100
- package/android/src/main/jni.cpp +1263 -1263
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
- package/cpp/README.md +4 -4
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +597 -597
- package/cpp/ggml-metal.m +4 -0
- package/cpp/ggml.h +1 -1
- package/cpp/rn-llama.cpp +873 -873
- package/cpp/rn-llama.h +138 -138
- package/cpp/sampling.h +107 -107
- package/cpp/unicode-data.cpp +7034 -7034
- package/cpp/unicode-data.h +20 -20
- package/cpp/unicode.cpp +849 -849
- package/cpp/unicode.h +66 -66
- package/ios/CMakeLists.txt +116 -108
- package/ios/RNLlama.h +7 -7
- package/ios/RNLlama.mm +418 -405
- package/ios/RNLlamaContext.h +57 -57
- package/ios/RNLlamaContext.mm +835 -835
- package/ios/rnllama.xcframework/Info.plist +74 -74
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +677 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2222 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +265 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +409 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1434 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +802 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +677 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2222 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +265 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +409 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1434 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +802 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +203 -203
- package/lib/commonjs/NativeRNLlama.js +1 -2
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/chat.js.map +1 -1
- package/lib/commonjs/grammar.js +12 -31
- package/lib/commonjs/grammar.js.map +1 -1
- package/lib/commonjs/index.js +47 -47
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/package.json +1 -0
- package/lib/module/NativeRNLlama.js +2 -0
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/chat.js +2 -0
- package/lib/module/chat.js.map +1 -1
- package/lib/module/grammar.js +14 -31
- package/lib/module/grammar.js.map +1 -1
- package/lib/module/index.js +47 -45
- package/lib/module/index.js.map +1 -1
- package/lib/module/package.json +1 -0
- package/lib/typescript/NativeRNLlama.d.ts +6 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/llama-rn.podspec +48 -48
- package/package.json +233 -233
- package/src/NativeRNLlama.ts +426 -426
- package/src/chat.ts +44 -44
- package/src/grammar.ts +854 -854
- package/src/index.ts +495 -487
package/android/src/main/jni.cpp
CHANGED
@@ -1,1263 +1,1263 @@
|
|
1
|
-
#include <jni.h>
|
2
|
-
// #include <android/asset_manager.h>
|
3
|
-
// #include <android/asset_manager_jni.h>
|
4
|
-
#include <android/log.h>
|
5
|
-
#include <cstdlib>
|
6
|
-
#include <ctime>
|
7
|
-
#include <ctime>
|
8
|
-
#include <sys/sysinfo.h>
|
9
|
-
#include <string>
|
10
|
-
#include <thread>
|
11
|
-
#include <unordered_map>
|
12
|
-
#include "json-schema-to-grammar.h"
|
13
|
-
#include "llama.h"
|
14
|
-
#include "chat.h"
|
15
|
-
#include "llama-impl.h"
|
16
|
-
#include "ggml.h"
|
17
|
-
#include "rn-llama.h"
|
18
|
-
#include "jni-utils.h"
|
19
|
-
#define UNUSED(x) (void)(x)
|
20
|
-
#define TAG "RNLLAMA_ANDROID_JNI"
|
21
|
-
|
22
|
-
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
23
|
-
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
24
|
-
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
|
25
|
-
static inline int min(int a, int b) {
|
26
|
-
return (a < b) ? a : b;
|
27
|
-
}
|
28
|
-
|
29
|
-
static void rnllama_log_callback_default(lm_ggml_log_level level, const char * fmt, void * data) {
|
30
|
-
if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
|
31
|
-
else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
|
32
|
-
else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
|
33
|
-
else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
|
34
|
-
}
|
35
|
-
|
36
|
-
extern "C" {
|
37
|
-
|
38
|
-
// Method to create WritableMap
|
39
|
-
static inline jobject createWriteableMap(JNIEnv *env) {
|
40
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
|
41
|
-
jmethodID init = env->GetStaticMethodID(mapClass, "createMap", "()Lcom/facebook/react/bridge/WritableMap;");
|
42
|
-
jobject map = env->CallStaticObjectMethod(mapClass, init);
|
43
|
-
return map;
|
44
|
-
}
|
45
|
-
|
46
|
-
// Method to put string into WritableMap
|
47
|
-
static inline void putString(JNIEnv *env, jobject map, const char *key, const char *value) {
|
48
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
49
|
-
jmethodID putStringMethod = env->GetMethodID(mapClass, "putString", "(Ljava/lang/String;Ljava/lang/String;)V");
|
50
|
-
|
51
|
-
jstring jKey = env->NewStringUTF(key);
|
52
|
-
jstring jValue = env->NewStringUTF(value);
|
53
|
-
|
54
|
-
env->CallVoidMethod(map, putStringMethod, jKey, jValue);
|
55
|
-
}
|
56
|
-
|
57
|
-
// Method to put int into WritableMap
|
58
|
-
static inline void putInt(JNIEnv *env, jobject map, const char *key, int value) {
|
59
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
60
|
-
jmethodID putIntMethod = env->GetMethodID(mapClass, "putInt", "(Ljava/lang/String;I)V");
|
61
|
-
|
62
|
-
jstring jKey = env->NewStringUTF(key);
|
63
|
-
|
64
|
-
env->CallVoidMethod(map, putIntMethod, jKey, value);
|
65
|
-
}
|
66
|
-
|
67
|
-
// Method to put double into WritableMap
|
68
|
-
static inline void putDouble(JNIEnv *env, jobject map, const char *key, double value) {
|
69
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
70
|
-
jmethodID putDoubleMethod = env->GetMethodID(mapClass, "putDouble", "(Ljava/lang/String;D)V");
|
71
|
-
|
72
|
-
jstring jKey = env->NewStringUTF(key);
|
73
|
-
|
74
|
-
env->CallVoidMethod(map, putDoubleMethod, jKey, value);
|
75
|
-
}
|
76
|
-
|
77
|
-
// Method to put boolean into WritableMap
|
78
|
-
static inline void putBoolean(JNIEnv *env, jobject map, const char *key, bool value) {
|
79
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
80
|
-
jmethodID putBooleanMethod = env->GetMethodID(mapClass, "putBoolean", "(Ljava/lang/String;Z)V");
|
81
|
-
|
82
|
-
jstring jKey = env->NewStringUTF(key);
|
83
|
-
|
84
|
-
env->CallVoidMethod(map, putBooleanMethod, jKey, value);
|
85
|
-
}
|
86
|
-
|
87
|
-
// Method to put WriteableMap into WritableMap
|
88
|
-
static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
|
89
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
90
|
-
jmethodID putMapMethod = env->GetMethodID(mapClass, "putMap", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableMap;)V");
|
91
|
-
|
92
|
-
jstring jKey = env->NewStringUTF(key);
|
93
|
-
|
94
|
-
env->CallVoidMethod(map, putMapMethod, jKey, value);
|
95
|
-
}
|
96
|
-
|
97
|
-
// Method to create WritableArray
|
98
|
-
static inline jobject createWritableArray(JNIEnv *env) {
|
99
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
|
100
|
-
jmethodID init = env->GetStaticMethodID(mapClass, "createArray", "()Lcom/facebook/react/bridge/WritableArray;");
|
101
|
-
jobject map = env->CallStaticObjectMethod(mapClass, init);
|
102
|
-
return map;
|
103
|
-
}
|
104
|
-
|
105
|
-
// Method to push int into WritableArray
|
106
|
-
static inline void pushInt(JNIEnv *env, jobject arr, int value) {
|
107
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
108
|
-
jmethodID pushIntMethod = env->GetMethodID(mapClass, "pushInt", "(I)V");
|
109
|
-
|
110
|
-
env->CallVoidMethod(arr, pushIntMethod, value);
|
111
|
-
}
|
112
|
-
|
113
|
-
// Method to push double into WritableArray
|
114
|
-
static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
|
115
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
116
|
-
jmethodID pushDoubleMethod = env->GetMethodID(mapClass, "pushDouble", "(D)V");
|
117
|
-
|
118
|
-
env->CallVoidMethod(arr, pushDoubleMethod, value);
|
119
|
-
}
|
120
|
-
|
121
|
-
// Method to push string into WritableArray
|
122
|
-
static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
|
123
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
124
|
-
jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
|
125
|
-
|
126
|
-
jstring jValue = env->NewStringUTF(value);
|
127
|
-
env->CallVoidMethod(arr, pushStringMethod, jValue);
|
128
|
-
}
|
129
|
-
|
130
|
-
// Method to push WritableMap into WritableArray
|
131
|
-
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
|
132
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
133
|
-
jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
|
134
|
-
|
135
|
-
env->CallVoidMethod(arr, pushMapMethod, value);
|
136
|
-
}
|
137
|
-
|
138
|
-
// Method to put WritableArray into WritableMap
|
139
|
-
static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
|
140
|
-
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
141
|
-
jmethodID putArrayMethod = env->GetMethodID(mapClass, "putArray", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableArray;)V");
|
142
|
-
|
143
|
-
jstring jKey = env->NewStringUTF(key);
|
144
|
-
|
145
|
-
env->CallVoidMethod(map, putArrayMethod, jKey, value);
|
146
|
-
}
|
147
|
-
|
148
|
-
/**
 * Reads GGUF header information and key/value metadata from a model file.
 *
 * The returned map contains "version", "alignment" and "data_offset" plus one
 * string entry per GGUF key that is not listed in `skip`.
 *
 * @param env            JNI environment of the calling thread.
 * @param thiz           Unused.
 * @param model_path_str Path of the GGUF file to inspect.
 * @param skip           Keys to leave out of the result; may be null.
 * @return A WritableMap with the collected information, or nullptr when the
 *         file cannot be opened as GGUF.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_modelInfo(
    JNIEnv *env,
    jobject thiz,
    jstring model_path_str,
    jobjectArray skip
) {
    UNUSED(thiz);

    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);

    // Copy the skip list into native strings so no JNI string stays pinned
    // while the file is parsed.
    std::vector<std::string> skip_vec;
    const int skip_len = skip != nullptr ? env->GetArrayLength(skip) : 0;
    for (int i = 0; i < skip_len; i++) {
        jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
        const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
        skip_vec.push_back(skip_chars);
        env->ReleaseStringUTFChars(skip_str, skip_chars);
        env->DeleteLocalRef(skip_str);
    }

    struct lm_gguf_init_params params = {
        /*.no_alloc = */ false,
        /*.ctx      = */ NULL,
    };
    struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);

    if (!ctx) {
        LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
        // The path buffer must be released on the failure path as well.
        env->ReleaseStringUTFChars(model_path_str, model_path_chars);
        return nullptr;
    }

    auto info = createWriteableMap(env);
    putInt(env, info, "version", lm_gguf_get_version(ctx));
    putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
    putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
    {
        const int n_kv = lm_gguf_get_n_kv(ctx);

        for (int i = 0; i < n_kv; ++i) {
            const char * key = lm_gguf_get_key(ctx, i);

            bool skipped = false;
            for (const auto &skip_key : skip_vec) {
                if (skip_key == key) {
                    skipped = true;
                    break;
                }
            }

            if (skipped) {
                continue;
            }

            const std::string value = lm_gguf_kv_to_str(ctx, i);
            putString(env, info, key, value.c_str());
        }
    }

    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
    lm_gguf_free(ctx);

    return reinterpret_cast<jobject>(info);
}
|
213
|
-
|
214
|
-
struct callback_context {
|
215
|
-
JNIEnv *env;
|
216
|
-
rnllama::llama_rn_context *llama;
|
217
|
-
jobject callback;
|
218
|
-
};
|
219
|
-
|
220
|
-
// Registry of live contexts, keyed by the llama context pointer value that is
// handed to Java as the `context_ptr` handle.
std::unordered_map<long, rnllama::llama_rn_context *> context_map;
|
221
|
-
|
222
|
-
struct CallbackContext {
|
223
|
-
JNIEnv * env;
|
224
|
-
jobject thiz;
|
225
|
-
jmethodID sendProgressMethod;
|
226
|
-
unsigned current;
|
227
|
-
};
|
228
|
-
|
229
|
-
/**
 * Creates a llama_rn_context, loads the model and applies LoRA adapters.
 *
 * On success the context is registered in `context_map` and the llama context
 * pointer is returned as an opaque handle. Every failure path (model load
 * failure, embeddings requested on an encoder-decoder model, LoRA failure)
 * returns -1 and leaves no entry for the freed context in `context_map`.
 *
 * @param model_path_str          Path of the model file.
 * @param chat_template           Chat template override, copied into the params.
 * @param reasoning_format        "deepseek" selects the DeepSeek reasoning format;
 *                                anything else selects none.
 * @param embedding               Enables embedding mode; forces n_ubatch = n_batch.
 * @param embd_normalize          Embedding normalization, -1 keeps the default.
 * @param n_threads               Thread count; <= 0 selects a default derived from
 *                                the hardware concurrency.
 * @param n_gpu_layers            Currently ignored.
 * @param lora_str / lora_scaled  Optional single adapter path and its scale.
 * @param lora_list               Optional ReadableArray of {path, scaled} maps.
 * @param pooling_type            Pooling type, -1 keeps the default.
 * @param load_progress_callback  Optional object with onLoadProgress(int).
 * @return The context handle, or -1 on failure.
 */
JNIEXPORT jlong JNICALL
Java_com_rnllama_LlamaContext_initContext(
    JNIEnv *env,
    jobject thiz,
    jstring model_path_str,
    jstring chat_template,
    jstring reasoning_format,
    jboolean embedding,
    jint embd_normalize,
    jint n_ctx,
    jint n_batch,
    jint n_ubatch,
    jint n_threads,
    jint n_gpu_layers, // TODO: Support this
    jboolean flash_attn,
    jstring cache_type_k,
    jstring cache_type_v,
    jboolean use_mlock,
    jboolean use_mmap,
    jboolean vocab_only,
    jstring lora_str,
    jfloat lora_scaled,
    jobject lora_list,
    jfloat rope_freq_base,
    jfloat rope_freq_scale,
    jint pooling_type,
    jobject load_progress_callback
) {
    UNUSED(thiz);

    common_params defaultParams;

    defaultParams.vocab_only = vocab_only;
    if(vocab_only) {
        defaultParams.warmup = false;
    }

    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
    defaultParams.model = { model_path_chars };

    const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
    defaultParams.chat_template = chat_template_chars;

    const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
    if (strcmp(reasoning_format_chars, "deepseek") == 0) {
        defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
    } else {
        defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
    }

    defaultParams.n_ctx = n_ctx;
    defaultParams.n_batch = n_batch;
    defaultParams.n_ubatch = n_ubatch;

    if (pooling_type != -1) {
        defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
    }

    defaultParams.embedding = embedding;
    if (embd_normalize != -1) {
        defaultParams.embd_normalize = embd_normalize;
    }
    if (embedding) {
        // For non-causal models, batch size must be equal to ubatch size
        defaultParams.n_ubatch = defaultParams.n_batch;
    }

    int max_threads = std::thread::hardware_concurrency();
    // Use 2 threads by default on 4-core devices, 4 threads on more cores
    int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
    defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

    // defaultParams.n_gpu_layers = n_gpu_layers;
    defaultParams.flash_attn = flash_attn;

    const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
    const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
    defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
    defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);

    defaultParams.use_mlock = use_mlock;
    defaultParams.use_mmap = use_mmap;

    defaultParams.rope_freq_base = rope_freq_base;
    defaultParams.rope_freq_scale = rope_freq_scale;

    auto llama = new rnllama::llama_rn_context();
    llama->is_load_interrupted = false;
    llama->loading_progress = 0;

    if (load_progress_callback != nullptr) {
        // Reports whole-percent increases to Java; returning false aborts the load.
        defaultParams.progress_callback = [](float progress, void * user_data) {
            callback_context *cb_ctx = (callback_context *)user_data;
            JNIEnv *env = cb_ctx->env;
            auto llama = cb_ctx->llama;
            jobject callback = cb_ctx->callback;
            int percentage = (int) (100 * progress);
            if (percentage > llama->loading_progress) {
                llama->loading_progress = percentage;
                jclass callback_class = env->GetObjectClass(callback);
                jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
                env->CallVoidMethod(callback, onLoadProgress, percentage);
            }
            return !llama->is_load_interrupted;
        };

        // NOTE(review): cb_ctx and its global reference are never released in
        // this function — confirm whether they are freed elsewhere.
        callback_context *cb_ctx = new callback_context;
        cb_ctx->env = env;
        cb_ctx->llama = llama;
        cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
        defaultParams.progress_callback_user_data = cb_ctx;
    }

    bool is_model_loaded = llama->loadModel(defaultParams);

    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
    env->ReleaseStringUTFChars(chat_template, chat_template_chars);
    env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
    env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
    env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);

    LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
    if (!is_model_loaded) {
        // Stop here: continuing would apply adapters to, and return a handle
        // for, a context that has just been freed.
        llama_free(llama->ctx);
        return -1;
    }
    if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
        LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
        llama_free(llama->ctx);
        return -1;
    }
    context_map[(long) llama->ctx] = llama;

    std::vector<common_adapter_lora_info> lora;
    const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
    if (lora_chars != nullptr && lora_chars[0] != '\0') {
        common_adapter_lora_info la;
        la.path = lora_chars;
        la.scale = lora_scaled;
        lora.push_back(la);
    }

    if (lora_list != nullptr) {
        // lora_adapters: ReadableArray<ReadableMap>
        int lora_list_size = readablearray::size(env, lora_list);
        for (int i = 0; i < lora_list_size; i++) {
            jobject lora_adapter = readablearray::getMap(env, lora_list, i);
            jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
            if (path != nullptr) {
                const char *path_chars = env->GetStringUTFChars(path, nullptr);
                common_adapter_lora_info la;
                la.path = path_chars;
                la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
                lora.push_back(la);
                env->ReleaseStringUTFChars(path, path_chars);
            }
        }
    }
    env->ReleaseStringUTFChars(lora_str, lora_chars);

    int result = llama->applyLoraAdapters(lora);
    if (result != 0) {
        LOGI("[RNLlama] Failed to apply lora adapters");
        // Unregister before freeing so no lookup can reach the dead context.
        context_map.erase((long) llama->ctx);
        llama_free(llama->ctx);
        return -1;
    }

    return reinterpret_cast<jlong>(llama->ctx);
}
|
397
|
-
|
398
|
-
|
399
|
-
/**
 * Requests cancellation of a model load for the given context handle.
 *
 * Sets `is_load_interrupted` on the registered context; unknown handles are
 * ignored.
 *
 * @param env         Unused.
 * @param thiz        Unused.
 * @param context_ptr Handle returned by initContext.
 */
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_interruptLoad(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr
) {
    UNUSED(thiz);
    // find() rather than operator[]: a lookup for an unknown handle must not
    // insert a null entry into the registry.
    auto entry = context_map.find((long) context_ptr);
    if (entry != context_map.end() && entry->second) {
        entry->second->is_load_interrupted = true;
    }
}
|
411
|
-
|
412
|
-
/**
 * Builds a description of the loaded model for the JS side.
 *
 * The result map contains: "desc", "size", "nEmbd", "nParams", "metadata"
 * (all model metadata as strings), "chatTemplates" (llama-chat and minja
 * support flags plus template capabilities) and the deprecated
 * "isChatTemplateSupported" flag.
 *
 * @param env         JNI environment of the calling thread.
 * @param thiz        Unused.
 * @param context_ptr Handle returned by initContext.
 * @return A WritableMap with the fields listed above.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadModelDetails(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    int count = llama_model_meta_count(llama->model);
    auto meta = createWriteableMap(env);
    for (int i = 0; i < count; i++) {
        char key[256];
        char val[4096];
        // Skip entries whose key or value cannot be read; the buffers would
        // otherwise be passed on uninitialised.
        if (llama_model_meta_key_by_index(llama->model, i, key, sizeof(key)) < 0 ||
            llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val)) < 0) {
            continue;
        }

        putString(env, meta, key, val);
    }

    auto result = createWriteableMap(env);

    char desc[1024];
    llama_model_desc(llama->model, desc, sizeof(desc));

    putString(env, result, "desc", desc);
    putDouble(env, result, "size", llama_model_size(llama->model));
    putDouble(env, result, "nEmbd", llama_model_n_embd(llama->model));
    putDouble(env, result, "nParams", llama_model_n_params(llama->model));
    auto chat_templates = createWriteableMap(env);
    putBoolean(env, chat_templates, "llamaChat", llama->validateModelChatTemplate(false, nullptr));

    auto minja = createWriteableMap(env);
    putBoolean(env, minja, "default", llama->validateModelChatTemplate(true, nullptr));

    // Guarded like the tool-use template below; "defaultCaps" is omitted when
    // no default template is available.
    auto default_tmpl = llama->templates.get()->template_default.get();
    if (default_tmpl != nullptr) {
        auto default_caps = createWriteableMap(env);
        auto default_tmpl_caps = default_tmpl->original_caps();
        putBoolean(env, default_caps, "tools", default_tmpl_caps.supports_tools);
        putBoolean(env, default_caps, "toolCalls", default_tmpl_caps.supports_tool_calls);
        putBoolean(env, default_caps, "parallelToolCalls", default_tmpl_caps.supports_parallel_tool_calls);
        putBoolean(env, default_caps, "toolResponses", default_tmpl_caps.supports_tool_responses);
        putBoolean(env, default_caps, "systemRole", default_tmpl_caps.supports_system_role);
        putBoolean(env, default_caps, "toolCallId", default_tmpl_caps.supports_tool_call_id);
        putMap(env, minja, "defaultCaps", default_caps);
    }

    putBoolean(env, minja, "toolUse", llama->validateModelChatTemplate(true, "tool_use"));
    auto tool_use_tmpl = llama->templates.get()->template_tool_use.get();
    if (tool_use_tmpl != nullptr) {
        auto tool_use_caps = createWriteableMap(env);
        auto tool_use_tmpl_caps = tool_use_tmpl->original_caps();
        putBoolean(env, tool_use_caps, "tools", tool_use_tmpl_caps.supports_tools);
        putBoolean(env, tool_use_caps, "toolCalls", tool_use_tmpl_caps.supports_tool_calls);
        putBoolean(env, tool_use_caps, "parallelToolCalls", tool_use_tmpl_caps.supports_parallel_tool_calls);
        putBoolean(env, tool_use_caps, "systemRole", tool_use_tmpl_caps.supports_system_role);
        putBoolean(env, tool_use_caps, "toolResponses", tool_use_tmpl_caps.supports_tool_responses);
        putBoolean(env, tool_use_caps, "toolCallId", tool_use_tmpl_caps.supports_tool_call_id);
        putMap(env, minja, "toolUseCaps", tool_use_caps);
    }

    putMap(env, chat_templates, "minja", minja);
    putMap(env, result, "metadata", meta);
    putMap(env, result, "chatTemplates", chat_templates);

    // deprecated
    putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate(false, nullptr));

    return reinterpret_cast<jobject>(result);
}
|
482
|
-
|
483
|
-
/**
 * Formats chat messages through the Jinja template path.
 *
 * On success the result map holds "prompt", "chat_format", "grammar",
 * "grammar_lazy", "grammar_triggers" (array of {type, value, token}),
 * "preserved_tokens" and "additional_stops". If formatting throws, the map
 * holds "_error" with the exception message, next to whatever had already
 * been stored.
 *
 * @param env                 JNI environment of the calling thread.
 * @param thiz                Unused.
 * @param context_ptr         Handle returned by initContext.
 * @param messages            Messages passed through to the formatter.
 * @param chat_template       Template passed through to the formatter.
 * @param json_schema         JSON schema passed through to the formatter.
 * @param tools               Tools definition passed through to the formatter.
 * @param parallel_tool_calls Forwarded to the formatter.
 * @param tool_choice         Forwarded to the formatter.
 * @return A WritableMap as described above.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getFormattedChatWithJinja(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring messages,
    jstring chat_template,
    jstring json_schema,
    jstring tools,
    jboolean parallel_tool_calls,
    jstring tool_choice
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    const char *messages_chars = env->GetStringUTFChars(messages, nullptr);
    const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
    const char *json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
    const char *tools_chars = env->GetStringUTFChars(tools, nullptr);
    const char *tool_choice_chars = env->GetStringUTFChars(tool_choice, nullptr);

    auto result = createWriteableMap(env);
    try {
        auto formatted = llama->getFormattedChatWithJinja(
            messages_chars,
            tmpl_chars,
            json_schema_chars,
            tools_chars,
            parallel_tool_calls,
            tool_choice_chars
        );
        putString(env, result, "prompt", formatted.prompt.c_str());
        putInt(env, result, "chat_format", static_cast<int>(formatted.format));
        putString(env, result, "grammar", formatted.grammar.c_str());
        putBoolean(env, result, "grammar_lazy", formatted.grammar_lazy);

        auto grammar_triggers = createWritableArray(env);
        for (const auto &trigger : formatted.grammar_triggers) {
            auto trigger_map = createWriteableMap(env);
            putInt(env, trigger_map, "type", trigger.type);
            putString(env, trigger_map, "value", trigger.value.c_str());
            putInt(env, trigger_map, "token", trigger.token);
            pushMap(env, grammar_triggers, trigger_map);
            // The array now holds the map; drop the per-iteration reference.
            env->DeleteLocalRef(trigger_map);
        }
        putArray(env, result, "grammar_triggers", grammar_triggers);

        auto preserved_tokens = createWritableArray(env);
        for (const auto &token : formatted.preserved_tokens) {
            pushString(env, preserved_tokens, token.c_str());
        }
        putArray(env, result, "preserved_tokens", preserved_tokens);

        auto additional_stops = createWritableArray(env);
        for (const auto &stop : formatted.additional_stops) {
            pushString(env, additional_stops, stop.c_str());
        }
        putArray(env, result, "additional_stops", additional_stops);
    } catch (const std::exception &e) {
        // Catch the base class: a C++ exception must never unwind through the
        // JNI boundary, whatever its concrete type.
        LOGI("[RNLlama] Error: %s", e.what());
        putString(env, result, "_error", e.what());
    }
    env->ReleaseStringUTFChars(tools, tools_chars);
    env->ReleaseStringUTFChars(messages, messages_chars);
    env->ReleaseStringUTFChars(chat_template, tmpl_chars);
    env->ReleaseStringUTFChars(json_schema, json_schema_chars);
    env->ReleaseStringUTFChars(tool_choice, tool_choice_chars);
    return reinterpret_cast<jobject>(result);
}
|
548
|
-
|
549
|
-
/**
 * Formats chat messages through llama_rn_context::getFormattedChat.
 *
 * @param env           JNI environment of the calling thread.
 * @param thiz          Unused.
 * @param context_ptr   Handle returned by initContext.
 * @param messages      Messages passed through to the formatter.
 * @param chat_template Template passed through to the formatter.
 * @return A Java string holding the formatted prompt.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getFormattedChat(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring messages,
    jstring chat_template
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    const char *msg_utf = env->GetStringUTFChars(messages, nullptr);
    const char *tmpl_utf = env->GetStringUTFChars(chat_template, nullptr);

    // Format first, then hand the borrowed UTF buffers back to the VM.
    const std::string prompt = llama->getFormattedChat(msg_utf, tmpl_utf);

    env->ReleaseStringUTFChars(messages, msg_utf);
    env->ReleaseStringUTFChars(chat_template, tmpl_utf);

    return env->NewStringUTF(prompt.c_str());
}
|
570
|
-
|
571
|
-
/**
 * Restores a saved session file into the context.
 *
 * The token buffer `llama->embd` is sized to the context length, filled from
 * the file and then trimmed to the number of tokens actually read.
 *
 * @param env         JNI environment of the calling thread.
 * @param thiz        Unused.
 * @param context_ptr Handle returned by initContext.
 * @param path        Path of the session file.
 * @return A WritableMap with "tokens_loaded" and "prompt" on success, or with
 *         "error" when the file cannot be loaded.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadSession(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring path
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];
    const char *path_chars = env->GetStringUTFChars(path, nullptr);

    auto result = createWriteableMap(env);
    size_t n_token_count_out = 0;
    llama->embd.resize(llama->params.n_ctx);
    // Advertise size(), not capacity(): only the first size() elements exist,
    // and anything written past them would be wiped by the resize() below.
    const bool loaded = llama_state_load_file(
        llama->ctx, path_chars, llama->embd.data(), llama->embd.size(), &n_token_count_out);
    env->ReleaseStringUTFChars(path, path_chars);

    if (!loaded) {
        putString(env, result, "error", "Failed to load session");
        return reinterpret_cast<jobject>(result);
    }
    llama->embd.resize(n_token_count_out);

    const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
    putInt(env, result, "tokens_loaded", n_token_count_out);
    putString(env, result, "prompt", text.c_str());
    return reinterpret_cast<jobject>(result);
}
|
599
|
-
|
600
|
-
/**
 * Writes the context state and its tokens to a session file.
 *
 * @param env         JNI environment of the calling thread.
 * @param thiz        Unused.
 * @param context_ptr Handle returned by initContext.
 * @param path        Destination file path.
 * @param size        Number of tokens to save; values outside
 *                    1..token count save all tokens.
 * @return The total number of tokens held by the context, or -1 when the
 *         file cannot be written.
 */
JNIEXPORT jint JNICALL
Java_com_rnllama_LlamaContext_saveSession(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring path,
    jint size
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    const char *path_chars = env->GetStringUTFChars(path, nullptr);

    std::vector<llama_token> session_tokens = llama->embd;
    const int token_total = session_tokens.size();

    // Honour the requested size only when it selects a proper prefix.
    int save_size = token_total;
    if (size > 0 && size <= token_total) {
        save_size = size;
    }

    const bool saved = llama_state_save_file(llama->ctx, path_chars, session_tokens.data(), save_size);
    env->ReleaseStringUTFChars(path, path_chars);

    if (!saved) {
        return -1;
    }
    return session_tokens.size();
}
|
624
|
-
|
625
|
-
/**
 * Converts per-token probability data into a React Native WritableArray.
 *
 * Each element is a map with "content" (the token text) and "probs", an array
 * of {tok_str, prob} maps for that token's candidates.
 *
 * @param env   JNI environment of the calling thread.
 * @param llama Context used to turn token ids into text.
 * @param probs Tokens with their candidate probabilities.
 * @return The populated WritableArray.
 */
static inline jobject tokenProbsToMap(
    JNIEnv *env,
    rnllama::llama_rn_context *llama,
    std::vector<rnllama::completion_token_output> probs
) {
    auto result = createWritableArray(env);
    for (const auto &prob : probs) {
        auto probsForToken = createWritableArray(env);
        for (const auto &p : prob.probs) {
            std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
            auto probResult = createWriteableMap(env);
            putString(env, probResult, "tok_str", tokStr.c_str());
            putDouble(env, probResult, "prob", p.prob);
            pushMap(env, probsForToken, probResult);
            // The array holds the map now; without this the references would
            // accumulate for every candidate of every token.
            env->DeleteLocalRef(probResult);
        }
        std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
        auto tokenResult = createWriteableMap(env);
        putString(env, tokenResult, "content", tokStr.c_str());
        putArray(env, tokenResult, "probs", probsForToken);
        pushMap(env, result, tokenResult);
        // Both objects are reachable from `result`; drop the per-token references.
        env->DeleteLocalRef(probsForToken);
        env->DeleteLocalRef(tokenResult);
    }
    return result;
}
|
648
|
-
|
649
|
-
JNIEXPORT jobject JNICALL
|
650
|
-
Java_com_rnllama_LlamaContext_doCompletion(
|
651
|
-
JNIEnv *env,
|
652
|
-
jobject thiz,
|
653
|
-
jlong context_ptr,
|
654
|
-
jstring prompt,
|
655
|
-
jint chat_format,
|
656
|
-
jstring grammar,
|
657
|
-
jstring json_schema,
|
658
|
-
jboolean grammar_lazy,
|
659
|
-
jobject grammar_triggers,
|
660
|
-
jobject preserved_tokens,
|
661
|
-
jfloat temperature,
|
662
|
-
jint n_threads,
|
663
|
-
jint n_predict,
|
664
|
-
jint n_probs,
|
665
|
-
jint penalty_last_n,
|
666
|
-
jfloat penalty_repeat,
|
667
|
-
jfloat penalty_freq,
|
668
|
-
jfloat penalty_present,
|
669
|
-
jfloat mirostat,
|
670
|
-
jfloat mirostat_tau,
|
671
|
-
jfloat mirostat_eta,
|
672
|
-
jint top_k,
|
673
|
-
jfloat top_p,
|
674
|
-
jfloat min_p,
|
675
|
-
jfloat xtc_threshold,
|
676
|
-
jfloat xtc_probability,
|
677
|
-
jfloat typical_p,
|
678
|
-
jint seed,
|
679
|
-
jobjectArray stop,
|
680
|
-
jboolean ignore_eos,
|
681
|
-
jobjectArray logit_bias,
|
682
|
-
jfloat dry_multiplier,
|
683
|
-
jfloat dry_base,
|
684
|
-
jint dry_allowed_length,
|
685
|
-
jint dry_penalty_last_n,
|
686
|
-
jfloat top_n_sigma,
|
687
|
-
jobjectArray dry_sequence_breakers,
|
688
|
-
jobject partial_completion_callback
|
689
|
-
) {
|
690
|
-
UNUSED(thiz);
|
691
|
-
auto llama = context_map[(long) context_ptr];
|
692
|
-
|
693
|
-
llama->rewind();
|
694
|
-
|
695
|
-
//llama_reset_timings(llama->ctx);
|
696
|
-
|
697
|
-
auto prompt_chars = env->GetStringUTFChars(prompt, nullptr);
|
698
|
-
llama->params.prompt = prompt_chars;
|
699
|
-
llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;
|
700
|
-
|
701
|
-
int max_threads = std::thread::hardware_concurrency();
|
702
|
-
// Use 2 threads by default on 4-core devices, 4 threads on more cores
|
703
|
-
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
|
704
|
-
llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
|
705
|
-
|
706
|
-
llama->params.n_predict = n_predict;
|
707
|
-
llama->params.sampling.ignore_eos = ignore_eos;
|
708
|
-
|
709
|
-
auto & sparams = llama->params.sampling;
|
710
|
-
sparams.temp = temperature;
|
711
|
-
sparams.penalty_last_n = penalty_last_n;
|
712
|
-
sparams.penalty_repeat = penalty_repeat;
|
713
|
-
sparams.penalty_freq = penalty_freq;
|
714
|
-
sparams.penalty_present = penalty_present;
|
715
|
-
sparams.mirostat = mirostat;
|
716
|
-
sparams.mirostat_tau = mirostat_tau;
|
717
|
-
sparams.mirostat_eta = mirostat_eta;
|
718
|
-
sparams.top_k = top_k;
|
719
|
-
sparams.top_p = top_p;
|
720
|
-
sparams.min_p = min_p;
|
721
|
-
sparams.typ_p = typical_p;
|
722
|
-
sparams.n_probs = n_probs;
|
723
|
-
sparams.xtc_threshold = xtc_threshold;
|
724
|
-
sparams.xtc_probability = xtc_probability;
|
725
|
-
sparams.dry_multiplier = dry_multiplier;
|
726
|
-
sparams.dry_base = dry_base;
|
727
|
-
sparams.dry_allowed_length = dry_allowed_length;
|
728
|
-
sparams.dry_penalty_last_n = dry_penalty_last_n;
|
729
|
-
sparams.top_n_sigma = top_n_sigma;
|
730
|
-
|
731
|
-
// grammar
|
732
|
-
auto grammar_chars = env->GetStringUTFChars(grammar, nullptr);
|
733
|
-
if (grammar_chars && grammar_chars[0] != '\0') {
|
734
|
-
sparams.grammar = grammar_chars;
|
735
|
-
}
|
736
|
-
sparams.grammar_lazy = grammar_lazy;
|
737
|
-
|
738
|
-
if (preserved_tokens != nullptr) {
|
739
|
-
int preserved_tokens_size = readablearray::size(env, preserved_tokens);
|
740
|
-
for (int i = 0; i < preserved_tokens_size; i++) {
|
741
|
-
jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
|
742
|
-
auto ids = common_tokenize(llama->ctx, env->GetStringUTFChars(preserved_token, nullptr), /* add_special= */ false, /* parse_special= */ true);
|
743
|
-
if (ids.size() == 1) {
|
744
|
-
sparams.preserved_tokens.insert(ids[0]);
|
745
|
-
} else {
|
746
|
-
LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", env->GetStringUTFChars(preserved_token, nullptr));
|
747
|
-
}
|
748
|
-
}
|
749
|
-
}
|
750
|
-
|
751
|
-
if (grammar_triggers != nullptr) {
|
752
|
-
int grammar_triggers_size = readablearray::size(env, grammar_triggers);
|
753
|
-
for (int i = 0; i < grammar_triggers_size; i++) {
|
754
|
-
auto trigger_map = readablearray::getMap(env, grammar_triggers, i);
|
755
|
-
const auto type = static_cast<common_grammar_trigger_type>(readablemap::getInt(env, trigger_map, "type", 0));
|
756
|
-
jstring trigger_word = readablemap::getString(env, trigger_map, "value", nullptr);
|
757
|
-
auto word = env->GetStringUTFChars(trigger_word, nullptr);
|
758
|
-
|
759
|
-
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
760
|
-
auto ids = common_tokenize(llama->ctx, word, /* add_special= */ false, /* parse_special= */ true);
|
761
|
-
if (ids.size() == 1) {
|
762
|
-
auto token = ids[0];
|
763
|
-
if (std::find(sparams.preserved_tokens.begin(), sparams.preserved_tokens.end(), (llama_token) token) == sparams.preserved_tokens.end()) {
|
764
|
-
throw std::runtime_error("Grammar trigger word should be marked as preserved token");
|
765
|
-
}
|
766
|
-
common_grammar_trigger trigger;
|
767
|
-
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
768
|
-
trigger.value = word;
|
769
|
-
trigger.token = token;
|
770
|
-
sparams.grammar_triggers.push_back(std::move(trigger));
|
771
|
-
} else {
|
772
|
-
sparams.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
|
773
|
-
}
|
774
|
-
} else {
|
775
|
-
common_grammar_trigger trigger;
|
776
|
-
trigger.type = type;
|
777
|
-
trigger.value = word;
|
778
|
-
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
779
|
-
const auto token = (llama_token) readablemap::getInt(env, trigger_map, "token", 0);
|
780
|
-
trigger.token = token;
|
781
|
-
}
|
782
|
-
sparams.grammar_triggers.push_back(std::move(trigger));
|
783
|
-
}
|
784
|
-
}
|
785
|
-
}
|
786
|
-
|
787
|
-
auto json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
|
788
|
-
if ((!grammar_chars || grammar_chars[0] == '\0') && json_schema_chars && json_schema_chars[0] != '\0') {
|
789
|
-
auto schema = json::parse(json_schema_chars);
|
790
|
-
sparams.grammar = json_schema_to_grammar(schema);
|
791
|
-
}
|
792
|
-
env->ReleaseStringUTFChars(json_schema, json_schema_chars);
|
793
|
-
|
794
|
-
|
795
|
-
const llama_model * model = llama_get_model(llama->ctx);
|
796
|
-
const llama_vocab * vocab = llama_model_get_vocab(model);
|
797
|
-
|
798
|
-
sparams.logit_bias.clear();
|
799
|
-
if (ignore_eos) {
|
800
|
-
sparams.logit_bias[llama_vocab_eos(vocab)].bias = -INFINITY;
|
801
|
-
}
|
802
|
-
|
803
|
-
// dry break seq
|
804
|
-
|
805
|
-
jint size = env->GetArrayLength(dry_sequence_breakers);
|
806
|
-
std::vector<std::string> dry_sequence_breakers_vector;
|
807
|
-
|
808
|
-
for (jint i = 0; i < size; i++) {
|
809
|
-
jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
|
810
|
-
const char *nativeString = env->GetStringUTFChars(javaString, 0);
|
811
|
-
dry_sequence_breakers_vector.push_back(std::string(nativeString));
|
812
|
-
env->ReleaseStringUTFChars(javaString, nativeString);
|
813
|
-
env->DeleteLocalRef(javaString);
|
814
|
-
}
|
815
|
-
|
816
|
-
sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
|
817
|
-
|
818
|
-
// logit bias
|
819
|
-
const int n_vocab = llama_vocab_n_tokens(vocab);
|
820
|
-
jsize logit_bias_len = env->GetArrayLength(logit_bias);
|
821
|
-
|
822
|
-
for (jsize i = 0; i < logit_bias_len; i++) {
|
823
|
-
jdoubleArray el = (jdoubleArray) env->GetObjectArrayElement(logit_bias, i);
|
824
|
-
if (el && env->GetArrayLength(el) == 2) {
|
825
|
-
jdouble* doubleArray = env->GetDoubleArrayElements(el, 0);
|
826
|
-
|
827
|
-
llama_token tok = static_cast<llama_token>(doubleArray[0]);
|
828
|
-
if (tok >= 0 && tok < n_vocab) {
|
829
|
-
if (doubleArray[1] != 0) { // If the second element is not false (0)
|
830
|
-
sparams.logit_bias[tok].bias = doubleArray[1];
|
831
|
-
} else {
|
832
|
-
sparams.logit_bias[tok].bias = -INFINITY;
|
833
|
-
}
|
834
|
-
}
|
835
|
-
|
836
|
-
env->ReleaseDoubleArrayElements(el, doubleArray, 0);
|
837
|
-
}
|
838
|
-
env->DeleteLocalRef(el);
|
839
|
-
}
|
840
|
-
|
841
|
-
llama->params.antiprompt.clear();
|
842
|
-
int stop_len = env->GetArrayLength(stop);
|
843
|
-
for (int i = 0; i < stop_len; i++) {
|
844
|
-
jstring stop_str = (jstring) env->GetObjectArrayElement(stop, i);
|
845
|
-
const char *stop_chars = env->GetStringUTFChars(stop_str, nullptr);
|
846
|
-
llama->params.antiprompt.push_back(stop_chars);
|
847
|
-
env->ReleaseStringUTFChars(stop_str, stop_chars);
|
848
|
-
}
|
849
|
-
|
850
|
-
if (!llama->initSampling()) {
|
851
|
-
auto result = createWriteableMap(env);
|
852
|
-
putString(env, result, "error", "Failed to initialize sampling");
|
853
|
-
return reinterpret_cast<jobject>(result);
|
854
|
-
}
|
855
|
-
llama->beginCompletion();
|
856
|
-
llama->loadPrompt();
|
857
|
-
|
858
|
-
size_t sent_count = 0;
|
859
|
-
size_t sent_token_probs_index = 0;
|
860
|
-
|
861
|
-
while (llama->has_next_token && !llama->is_interrupted) {
|
862
|
-
const rnllama::completion_token_output token_with_probs = llama->doCompletion();
|
863
|
-
if (token_with_probs.tok == -1 || llama->incomplete) {
|
864
|
-
continue;
|
865
|
-
}
|
866
|
-
const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);
|
867
|
-
|
868
|
-
size_t pos = std::min(sent_count, llama->generated_text.size());
|
869
|
-
|
870
|
-
const std::string str_test = llama->generated_text.substr(pos);
|
871
|
-
bool is_stop_full = false;
|
872
|
-
size_t stop_pos =
|
873
|
-
llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
|
874
|
-
if (stop_pos != std::string::npos) {
|
875
|
-
is_stop_full = true;
|
876
|
-
llama->generated_text.erase(
|
877
|
-
llama->generated_text.begin() + pos + stop_pos,
|
878
|
-
llama->generated_text.end());
|
879
|
-
pos = std::min(sent_count, llama->generated_text.size());
|
880
|
-
} else {
|
881
|
-
is_stop_full = false;
|
882
|
-
stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
|
883
|
-
rnllama::STOP_PARTIAL);
|
884
|
-
}
|
885
|
-
|
886
|
-
if (
|
887
|
-
stop_pos == std::string::npos ||
|
888
|
-
// Send rest of the text if we are at the end of the generation
|
889
|
-
(!llama->has_next_token && !is_stop_full && stop_pos > 0)
|
890
|
-
) {
|
891
|
-
const std::string to_send = llama->generated_text.substr(pos, std::string::npos);
|
892
|
-
|
893
|
-
sent_count += to_send.size();
|
894
|
-
|
895
|
-
std::vector<rnllama::completion_token_output> probs_output = {};
|
896
|
-
|
897
|
-
auto tokenResult = createWriteableMap(env);
|
898
|
-
putString(env, tokenResult, "token", to_send.c_str());
|
899
|
-
|
900
|
-
if (llama->params.sampling.n_probs > 0) {
|
901
|
-
const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
|
902
|
-
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
|
903
|
-
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
|
904
|
-
if (probs_pos < probs_stop_pos) {
|
905
|
-
probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
|
906
|
-
}
|
907
|
-
sent_token_probs_index = probs_stop_pos;
|
908
|
-
|
909
|
-
putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
|
910
|
-
}
|
911
|
-
|
912
|
-
jclass cb_class = env->GetObjectClass(partial_completion_callback);
|
913
|
-
jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
|
914
|
-
env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
|
915
|
-
}
|
916
|
-
}
|
917
|
-
|
918
|
-
env->ReleaseStringUTFChars(grammar, grammar_chars);
|
919
|
-
env->ReleaseStringUTFChars(prompt, prompt_chars);
|
920
|
-
llama_perf_context_print(llama->ctx);
|
921
|
-
llama->is_predicting = false;
|
922
|
-
|
923
|
-
auto toolCalls = createWritableArray(env);
|
924
|
-
std::string reasoningContent = "";
|
925
|
-
std::string content;
|
926
|
-
auto toolCallsSize = 0;
|
927
|
-
if (!llama->is_interrupted) {
|
928
|
-
try {
|
929
|
-
common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
|
930
|
-
if (!message.reasoning_content.empty()) {
|
931
|
-
reasoningContent = message.reasoning_content;
|
932
|
-
}
|
933
|
-
content = message.content;
|
934
|
-
for (const auto &tc : message.tool_calls) {
|
935
|
-
auto toolCall = createWriteableMap(env);
|
936
|
-
putString(env, toolCall, "type", "function");
|
937
|
-
auto functionMap = createWriteableMap(env);
|
938
|
-
putString(env, functionMap, "name", tc.name.c_str());
|
939
|
-
putString(env, functionMap, "arguments", tc.arguments.c_str());
|
940
|
-
putMap(env, toolCall, "function", functionMap);
|
941
|
-
if (!tc.id.empty()) {
|
942
|
-
putString(env, toolCall, "id", tc.id.c_str());
|
943
|
-
}
|
944
|
-
pushMap(env, toolCalls, toolCall);
|
945
|
-
toolCallsSize++;
|
946
|
-
}
|
947
|
-
} catch (const std::exception &e) {
|
948
|
-
// LOGI("Error parsing tool calls: %s", e.what());
|
949
|
-
}
|
950
|
-
}
|
951
|
-
|
952
|
-
auto result = createWriteableMap(env);
|
953
|
-
putString(env, result, "text", llama->generated_text.c_str());
|
954
|
-
if (!content.empty()) {
|
955
|
-
putString(env, result, "content", content.c_str());
|
956
|
-
}
|
957
|
-
if (!reasoningContent.empty()) {
|
958
|
-
putString(env, result, "reasoning_content", reasoningContent.c_str());
|
959
|
-
}
|
960
|
-
if (toolCallsSize > 0) {
|
961
|
-
putArray(env, result, "tool_calls", toolCalls);
|
962
|
-
}
|
963
|
-
putArray(env, result, "completion_probabilities", tokenProbsToMap(env, llama, llama->generated_token_probs));
|
964
|
-
putInt(env, result, "tokens_predicted", llama->num_tokens_predicted);
|
965
|
-
putInt(env, result, "tokens_evaluated", llama->num_prompt_tokens);
|
966
|
-
putInt(env, result, "truncated", llama->truncated);
|
967
|
-
putInt(env, result, "stopped_eos", llama->stopped_eos);
|
968
|
-
putInt(env, result, "stopped_word", llama->stopped_word);
|
969
|
-
putInt(env, result, "stopped_limit", llama->stopped_limit);
|
970
|
-
putString(env, result, "stopping_word", llama->stopping_word.c_str());
|
971
|
-
putInt(env, result, "tokens_cached", llama->n_past);
|
972
|
-
|
973
|
-
const auto timings_token = llama_perf_context(llama -> ctx);
|
974
|
-
|
975
|
-
auto timingsResult = createWriteableMap(env);
|
976
|
-
putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
|
977
|
-
putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
|
978
|
-
putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
|
979
|
-
putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
|
980
|
-
putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
|
981
|
-
putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
|
982
|
-
putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
|
983
|
-
putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);
|
984
|
-
|
985
|
-
putMap(env, result, "timings", timingsResult);
|
986
|
-
|
987
|
-
return reinterpret_cast<jobject>(result);
|
988
|
-
}
|
989
|
-
|
990
|
-
// JNI entry point: raises the is_interrupted flag on the context registered
// under `context_ptr` in context_map.
// `env` and `thiz` are not used.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_stopCompletion(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    UNUSED(env);
    rnllama::llama_rn_context *rn_ctx = context_map[(long) context_ptr];
    rn_ctx->is_interrupted = true;
}
|
998
|
-
|
999
|
-
// JNI entry point: reports the is_predicting flag of the context registered
// under `context_ptr` in context_map.
// `env` and `thiz` are not used.
JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isPredicting(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    UNUSED(env);
    return context_map[(long) context_ptr]->is_predicting;
}
|
1007
|
-
|
1008
|
-
// JNI entry point: tokenizes `text` with the context registered under
// `context_ptr` and returns the token ids as a WritableArray of ints.
// common_tokenize is called with `false` as its third argument.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_tokenize(
    JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
    UNUSED(thiz);
    rnllama::llama_rn_context *rn_ctx = context_map[(long) context_ptr];

    const char *utf8 = env->GetStringUTFChars(text, nullptr);
    const std::vector<llama_token> token_ids = common_tokenize(rn_ctx->ctx, utf8, false);

    jobject out = createWritableArray(env);
    for (size_t idx = 0; idx < token_ids.size(); ++idx) {
        pushInt(env, out, token_ids[idx]);
    }

    // The UTF-8 buffer is handed back only after the array has been filled.
    env->ReleaseStringUTFChars(text, utf8);
    return out;
}
|
1030
|
-
|
1031
|
-
// JNI entry point: converts the token ids in `tokens` back into text using the
// context registered under `context_ptr`, and returns it as a Java string.
JNIEXPORT jstring JNICALL
Java_com_rnllama_LlamaContext_detokenize(
    JNIEnv *env, jobject thiz, jlong context_ptr, jintArray tokens) {
    UNUSED(thiz);
    rnllama::llama_rn_context *rn_ctx = context_map[(long) context_ptr];

    const jsize count = env->GetArrayLength(tokens);
    jint *raw = env->GetIntArrayElements(tokens, 0);

    // Copy the Java ints into a native token vector in one step.
    std::vector<llama_token> ids(raw, raw + count);

    auto decoded = rnllama::tokens_to_str(rn_ctx->ctx, ids.cbegin(), ids.cend());

    env->ReleaseIntArrayElements(tokens, raw, 0);

    return env->NewStringUTF(decoded.c_str());
}
|
1050
|
-
|
1051
|
-
// JNI entry point: reports params.embedding of the context registered under
// `context_ptr` in context_map.
// `env` and `thiz` are not used.
JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    UNUSED(env);
    return context_map[(long) context_ptr]->params.embedding;
}
|
1059
|
-
|
1060
|
-
// JNI entry point: computes an embedding for `text` with the context registered
// under `context_ptr`.
//
// Parameters:
//   text           - input to embed; it is stored as the context's prompt.
//   embd_normalize - normalization mode; -1 keeps the context's own
//                    params.embd_normalize value.
//
// Returns a WritableMap holding either
//   "error"                        - when sampling could not be initialized, or
//   "embedding" / "prompt_tokens"  - the embedding values (as doubles) and the
//                                    text piece of every token in llama->embd.
//
// The UTF-8 buffer obtained from `text` is released on every return path.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_embedding(
    JNIEnv *env, jobject thiz,
    jlong context_ptr,
    jstring text,
    jint embd_normalize
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    common_params embdParams;
    embdParams.embedding = true;
    embdParams.embd_normalize = llama->params.embd_normalize;
    if (embd_normalize != -1) {
        embdParams.embd_normalize = embd_normalize;
    }

    const char *text_chars = env->GetStringUTFChars(text, nullptr);

    llama->rewind();

    llama_perf_context_reset(llama->ctx);

    llama->params.prompt = text_chars;

    // Only the prompt is evaluated; no tokens are to be generated.
    llama->params.n_predict = 0;

    auto result = createWriteableMap(env);
    if (!llama->initSampling()) {
        putString(env, result, "error", "Failed to initialize sampling");
        // Release the prompt buffer here too, otherwise this path leaks it.
        env->ReleaseStringUTFChars(text, text_chars);
        return reinterpret_cast<jobject>(result);
    }

    llama->beginCompletion();
    llama->loadPrompt();
    llama->doCompletion();

    std::vector<float> embedding = llama->getEmbedding(embdParams);

    auto embeddings = createWritableArray(env);
    for (const auto &val : embedding) {
        pushDouble(env, embeddings, (double) val);
    }
    putArray(env, result, "embedding", embeddings);

    auto promptTokens = createWritableArray(env);
    for (const auto &tok : llama->embd) {
        pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
    }
    putArray(env, result, "prompt_tokens", promptTokens);

    env->ReleaseStringUTFChars(text, text_chars);
    return result;
}
|
1114
|
-
|
1115
|
-
// JNI entry point: forwards `pp`, `tg`, `pl` and `nr` to bench() on the context
// registered under `context_ptr` and returns its textual result as a Java string.
JNIEXPORT jstring JNICALL
Java_com_rnllama_LlamaContext_bench(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jint pp,
    jint tg,
    jint pl,
    jint nr
) {
    UNUSED(thiz);
    rnllama::llama_rn_context *rn_ctx = context_map[(long) context_ptr];
    const std::string report = rn_ctx->bench(pp, tg, pl, nr);
    return env->NewStringUTF(report.c_str());
}
|
1130
|
-
|
1131
|
-
// JNI entry point: builds a list of LoRA adapters from `loraAdapters`
// (a ReadableArray of ReadableMap) and applies it to the context registered
// under `context_ptr`.
//
// Each map contributes one adapter when it has a "path" entry; "scaled" gives
// the adapter scale and defaults to 1.0. Maps without "path" are ignored.
//
// Returns whatever llama->applyLoraAdapters() returns.
JNIEXPORT jint JNICALL
Java_com_rnllama_LlamaContext_applyLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    // lora_adapters: ReadableArray<ReadableMap>
    std::vector<common_adapter_lora_info> lora_adapters;
    int lora_adapters_size = readablearray::size(env, loraAdapters);
    for (int i = 0; i < lora_adapters_size; i++) {
        jobject lora_adapter = readablearray::getMap(env, loraAdapters, i);
        jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
        if (path != nullptr) {
            const char *path_chars = env->GetStringUTFChars(path, nullptr);
            float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
            common_adapter_lora_info la;
            la.path = path_chars;
            la.scale = scaled;
            // Release only after the path has been copied into `la`; reading
            // path_chars after the release would touch freed memory.
            env->ReleaseStringUTFChars(path, path_chars);
            lora_adapters.push_back(la);
        }
    }
    return llama->applyLoraAdapters(lora_adapters);
}
|
1155
|
-
|
1156
|
-
// JNI entry point: calls removeLoraAdapters() on the context registered under
// `context_ptr` in context_map.
// `env` and `thiz` are not used.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_removeLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    UNUSED(env);
    context_map[(long) context_ptr]->removeLoraAdapters();
}
|
1164
|
-
|
1165
|
-
// JNI entry point: returns the adapters reported by getLoadedLoraAdapters() on
// the context registered under `context_ptr`, as a WritableArray of maps.
// Each map has a "path" string and a "scaled" double.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    rnllama::llama_rn_context *rn_ctx = context_map[(long) context_ptr];
    auto adapters = rn_ctx->getLoadedLoraAdapters();

    jobject out = createWritableArray(env);
    for (auto it = adapters.begin(); it != adapters.end(); ++it) {
        common_adapter_lora_info &adapter = *it;
        jobject entry = createWriteableMap(env);
        putString(env, entry, "path", adapter.path.c_str());
        putDouble(env, entry, "scaled", adapter.scale);
        pushMap(env, out, entry);
    }
    return out;
}
|
1180
|
-
|
1181
|
-
// JNI entry point: unregisters and deletes the context registered under
// `context_ptr`.
//
// An unknown handle (never registered, or already freed) is ignored. Looking
// it up with operator[] would insert a null entry into context_map and then
// dereference it.
// `env` and `thiz` are not used.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_freeContext(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(env);
    UNUSED(thiz);
    auto entry = context_map.find((long) context_ptr);
    if (entry == context_map.end() || entry->second == nullptr) {
        return;
    }
    auto llama = entry->second;
    // Contexts are registered under their llama->ctx pointer value.
    context_map.erase((long) llama->ctx);
    delete llama;
}
|
1190
|
-
|
1191
|
-
struct log_callback_context {
|
1192
|
-
JavaVM *jvm;
|
1193
|
-
jobject callback;
|
1194
|
-
};
|
1195
|
-
|
1196
|
-
// llama.cpp log sink that mirrors every message to logcat and forwards it to
// Java.
//
// Parameters:
//   level - log level; mapped to "error", "info" or "warn" for Java, and to ""
//           for any other level.
//   text  - the message. It is printed through a fixed "%s" format so that '%'
//           characters in it are never interpreted as conversion specifiers.
//   data  - a log_callback_context* carrying the JavaVM and the Java callback.
//
// If the calling thread is not attached to the VM it is attached for the
// duration of the call and detached afterwards. If no JNIEnv can be obtained
// the message is only written to logcat.
static void rnllama_log_callback_to_j(lm_ggml_log_level level, const char * text, void * data) {
    auto level_c = "";
    if (level == LM_GGML_LOG_LEVEL_ERROR) {
        __android_log_print(ANDROID_LOG_ERROR, TAG, "%s", text);
        level_c = "error";
    } else if (level == LM_GGML_LOG_LEVEL_INFO) {
        __android_log_print(ANDROID_LOG_INFO, TAG, "%s", text);
        level_c = "info";
    } else if (level == LM_GGML_LOG_LEVEL_WARN) {
        __android_log_print(ANDROID_LOG_WARN, TAG, "%s", text);
        level_c = "warn";
    } else {
        __android_log_print(ANDROID_LOG_DEFAULT, TAG, "%s", text);
    }

    log_callback_context *cb_ctx = (log_callback_context *) data;

    JNIEnv *env;
    bool need_detach = false;
    int getEnvResult = cb_ctx->jvm->GetEnv((void**)&env, JNI_VERSION_1_6);

    if (getEnvResult == JNI_EDETACHED) {
        if (cb_ctx->jvm->AttachCurrentThread(&env, nullptr) == JNI_OK) {
            need_detach = true;
        } else {
            return;
        }
    } else if (getEnvResult != JNI_OK) {
        return;
    }

    jobject callback = cb_ctx->callback;
    jclass cb_class = env->GetObjectClass(callback);
    jmethodID emitNativeLog = env->GetMethodID(cb_class, "emitNativeLog", "(Ljava/lang/String;Ljava/lang/String;)V");

    jstring level_str = env->NewStringUTF(level_c);
    jstring text_str = env->NewStringUTF(text);
    env->CallVoidMethod(callback, emitNativeLog, level_str, text_str);
    env->DeleteLocalRef(level_str);
    env->DeleteLocalRef(text_str);
    // This runs once per log line, so the class reference must not pile up
    // on threads that stay attached.
    env->DeleteLocalRef(cb_class);

    if (need_detach) {
        cb_ctx->jvm->DetachCurrentThread();
    }
}
|
1241
|
-
|
1242
|
-
// JNI entry point: installs rnllama_log_callback_to_j as the llama log handler.
// The handler receives a newly allocated log_callback_context holding the
// JavaVM and a global reference to `logCallback`.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_setupLog(JNIEnv *env, jobject thiz, jobject logCallback) {
    UNUSED(thiz);

    JavaVM *vm;
    env->GetJavaVM(&vm);

    log_callback_context *sink = new log_callback_context;
    sink->jvm = vm;
    sink->callback = env->NewGlobalRef(logCallback);

    llama_log_set(rnllama_log_callback_to_j, sink);
}
|
1255
|
-
|
1256
|
-
// JNI entry point: switches the llama log handler back to
// rnllama_log_callback_default with no user data.
// `env` and `thiz` are not used.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_unsetLog(JNIEnv *env, jobject thiz) {
    UNUSED(thiz);
    UNUSED(env);
    llama_log_set(rnllama_log_callback_default, nullptr);
}
|
1262
|
-
|
1263
|
-
} // extern "C"
|
1
|
+
#include <jni.h>
|
2
|
+
// #include <android/asset_manager.h>
|
3
|
+
// #include <android/asset_manager_jni.h>
|
4
|
+
#include <android/log.h>
|
5
|
+
#include <cstdlib>
|
6
|
+
#include <ctime>
|
7
|
+
#include <ctime>
|
8
|
+
#include <sys/sysinfo.h>
|
9
|
+
#include <string>
|
10
|
+
#include <thread>
|
11
|
+
#include <unordered_map>
|
12
|
+
#include "json-schema-to-grammar.h"
|
13
|
+
#include "llama.h"
|
14
|
+
#include "chat.h"
|
15
|
+
#include "llama-impl.h"
|
16
|
+
#include "ggml.h"
|
17
|
+
#include "rn-llama.h"
|
18
|
+
#include "jni-utils.h"
|
19
|
+
#define UNUSED(x) (void)(x)
|
20
|
+
#define TAG "RNLLAMA_ANDROID_JNI"
|
21
|
+
|
22
|
+
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
23
|
+
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
24
|
+
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
|
25
|
+
// Returns the smaller of the two ints; when they are equal, `b` is returned.
static inline int min(int a, int b) {
    if (a < b) {
        return a;
    }
    return b;
}
|
28
|
+
|
29
|
+
// Default llama log sink: writes each message to logcat under TAG, at the
// Android priority matching `level` (ANDROID_LOG_DEFAULT for any level other
// than error, info or warn).
//
// The callback receives no variadic arguments, so `fmt` is printed as plain
// text through a fixed "%s" format. Using it as the format string would make
// any '%' in a message read arguments that do not exist.
// `data` (the callback user data) is not used.
static void rnllama_log_callback_default(lm_ggml_log_level level, const char * fmt, void * data) {
    (void) data;

    int priority = ANDROID_LOG_DEFAULT;
    if (level == LM_GGML_LOG_LEVEL_ERROR) {
        priority = ANDROID_LOG_ERROR;
    } else if (level == LM_GGML_LOG_LEVEL_INFO) {
        priority = ANDROID_LOG_INFO;
    } else if (level == LM_GGML_LOG_LEVEL_WARN) {
        priority = ANDROID_LOG_WARN;
    }

    __android_log_print(priority, TAG, "%s", fmt);
}
|
35
|
+
|
36
|
+
extern "C" {
|
37
|
+
|
38
|
+
// Method to create WritableMap
|
39
|
+
static inline jobject createWriteableMap(JNIEnv *env) {
|
40
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
|
41
|
+
jmethodID init = env->GetStaticMethodID(mapClass, "createMap", "()Lcom/facebook/react/bridge/WritableMap;");
|
42
|
+
jobject map = env->CallStaticObjectMethod(mapClass, init);
|
43
|
+
return map;
|
44
|
+
}
|
45
|
+
|
46
|
+
// Method to put string into WritableMap
|
47
|
+
static inline void putString(JNIEnv *env, jobject map, const char *key, const char *value) {
|
48
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
49
|
+
jmethodID putStringMethod = env->GetMethodID(mapClass, "putString", "(Ljava/lang/String;Ljava/lang/String;)V");
|
50
|
+
|
51
|
+
jstring jKey = env->NewStringUTF(key);
|
52
|
+
jstring jValue = env->NewStringUTF(value);
|
53
|
+
|
54
|
+
env->CallVoidMethod(map, putStringMethod, jKey, jValue);
|
55
|
+
}
|
56
|
+
|
57
|
+
// Method to put int into WritableMap
|
58
|
+
static inline void putInt(JNIEnv *env, jobject map, const char *key, int value) {
|
59
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
60
|
+
jmethodID putIntMethod = env->GetMethodID(mapClass, "putInt", "(Ljava/lang/String;I)V");
|
61
|
+
|
62
|
+
jstring jKey = env->NewStringUTF(key);
|
63
|
+
|
64
|
+
env->CallVoidMethod(map, putIntMethod, jKey, value);
|
65
|
+
}
|
66
|
+
|
67
|
+
// Method to put double into WritableMap
|
68
|
+
static inline void putDouble(JNIEnv *env, jobject map, const char *key, double value) {
|
69
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
70
|
+
jmethodID putDoubleMethod = env->GetMethodID(mapClass, "putDouble", "(Ljava/lang/String;D)V");
|
71
|
+
|
72
|
+
jstring jKey = env->NewStringUTF(key);
|
73
|
+
|
74
|
+
env->CallVoidMethod(map, putDoubleMethod, jKey, value);
|
75
|
+
}
|
76
|
+
|
77
|
+
// Method to put boolean into WritableMap
|
78
|
+
static inline void putBoolean(JNIEnv *env, jobject map, const char *key, bool value) {
|
79
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
80
|
+
jmethodID putBooleanMethod = env->GetMethodID(mapClass, "putBoolean", "(Ljava/lang/String;Z)V");
|
81
|
+
|
82
|
+
jstring jKey = env->NewStringUTF(key);
|
83
|
+
|
84
|
+
env->CallVoidMethod(map, putBooleanMethod, jKey, value);
|
85
|
+
}
|
86
|
+
|
87
|
+
// Method to put WriteableMap into WritableMap
|
88
|
+
static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
|
89
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
90
|
+
jmethodID putMapMethod = env->GetMethodID(mapClass, "putMap", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableMap;)V");
|
91
|
+
|
92
|
+
jstring jKey = env->NewStringUTF(key);
|
93
|
+
|
94
|
+
env->CallVoidMethod(map, putMapMethod, jKey, value);
|
95
|
+
}
|
96
|
+
|
97
|
+
// Method to create WritableArray
|
98
|
+
static inline jobject createWritableArray(JNIEnv *env) {
|
99
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
|
100
|
+
jmethodID init = env->GetStaticMethodID(mapClass, "createArray", "()Lcom/facebook/react/bridge/WritableArray;");
|
101
|
+
jobject map = env->CallStaticObjectMethod(mapClass, init);
|
102
|
+
return map;
|
103
|
+
}
|
104
|
+
|
105
|
+
// Method to push int into WritableArray
|
106
|
+
static inline void pushInt(JNIEnv *env, jobject arr, int value) {
|
107
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
108
|
+
jmethodID pushIntMethod = env->GetMethodID(mapClass, "pushInt", "(I)V");
|
109
|
+
|
110
|
+
env->CallVoidMethod(arr, pushIntMethod, value);
|
111
|
+
}
|
112
|
+
|
113
|
+
// Method to push double into WritableArray
|
114
|
+
static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
|
115
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
116
|
+
jmethodID pushDoubleMethod = env->GetMethodID(mapClass, "pushDouble", "(D)V");
|
117
|
+
|
118
|
+
env->CallVoidMethod(arr, pushDoubleMethod, value);
|
119
|
+
}
|
120
|
+
|
121
|
+
// Method to push string into WritableArray
|
122
|
+
static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
|
123
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
124
|
+
jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
|
125
|
+
|
126
|
+
jstring jValue = env->NewStringUTF(value);
|
127
|
+
env->CallVoidMethod(arr, pushStringMethod, jValue);
|
128
|
+
}
|
129
|
+
|
130
|
+
// Method to push WritableMap into WritableArray
|
131
|
+
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
|
132
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
133
|
+
jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
|
134
|
+
|
135
|
+
env->CallVoidMethod(arr, pushMapMethod, value);
|
136
|
+
}
|
137
|
+
|
138
|
+
// Method to put WritableArray into WritableMap
|
139
|
+
static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
|
140
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
141
|
+
jmethodID putArrayMethod = env->GetMethodID(mapClass, "putArray", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableArray;)V");
|
142
|
+
|
143
|
+
jstring jKey = env->NewStringUTF(key);
|
144
|
+
|
145
|
+
env->CallVoidMethod(map, putArrayMethod, jKey, value);
|
146
|
+
}
|
147
|
+
|
148
|
+
// JNI entry point: reads the GGUF metadata of the file at `model_path_str`.
//
// Parameters:
//   model_path_str - path of the GGUF file.
//   skip           - metadata keys to leave out of the result.
//
// Returns a WritableMap with "version", "alignment" and "data_offset" (ints)
// plus one string entry per metadata key that is not listed in `skip`, or
// nullptr when the file cannot be loaded.
//
// The UTF-8 path buffer is released on both return paths, and the per-element
// string references taken from `skip` are deleted inside the loop.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_modelInfo(
    JNIEnv *env,
    jobject thiz,
    jstring model_path_str,
    jobjectArray skip
) {
    UNUSED(thiz);

    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);

    std::vector<std::string> skip_vec;
    int skip_len = env->GetArrayLength(skip);
    for (int i = 0; i < skip_len; i++) {
        jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
        const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
        skip_vec.push_back(skip_chars);
        env->ReleaseStringUTFChars(skip_str, skip_chars);
        env->DeleteLocalRef(skip_str);
    }

    struct lm_gguf_init_params params = {
        /*.no_alloc = */ false,
        /*.ctx      = */ NULL,
    };
    struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);

    if (!ctx) {
        LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
        // Release after logging: the message above still reads the buffer.
        env->ReleaseStringUTFChars(model_path_str, model_path_chars);
        return nullptr;
    }

    auto info = createWriteableMap(env);
    putInt(env, info, "version", lm_gguf_get_version(ctx));
    putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
    putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
    {
        const int n_kv = lm_gguf_get_n_kv(ctx);

        for (int i = 0; i < n_kv; ++i) {
            const char * key = lm_gguf_get_key(ctx, i);

            bool skipped = false;
            for (int j = 0; j < skip_len; j++) {
                if (skip_vec[j] == key) {
                    skipped = true;
                    break;
                }
            }

            if (skipped) {
                continue;
            }

            const std::string value = lm_gguf_kv_to_str(ctx, i);
            putString(env, info, key, value.c_str());
        }
    }

    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
    lm_gguf_free(ctx);

    return reinterpret_cast<jobject>(info);
}
|
213
|
+
|
214
|
+
struct callback_context {
|
215
|
+
JNIEnv *env;
|
216
|
+
rnllama::llama_rn_context *llama;
|
217
|
+
jobject callback;
|
218
|
+
};
|
219
|
+
|
220
|
+
// Registry of contexts, keyed by the context handle the JNI entry points receive as `context_ptr` (cast to long).
std::unordered_map<long, rnllama::llama_rn_context *> context_map;
|
221
|
+
|
222
|
+
struct CallbackContext {
|
223
|
+
JNIEnv * env;
|
224
|
+
jobject thiz;
|
225
|
+
jmethodID sendProgressMethod;
|
226
|
+
unsigned current;
|
227
|
+
};
|
228
|
+
|
229
|
+
JNIEXPORT jlong JNICALL
|
230
|
+
Java_com_rnllama_LlamaContext_initContext(
|
231
|
+
JNIEnv *env,
|
232
|
+
jobject thiz,
|
233
|
+
jstring model_path_str,
|
234
|
+
jstring chat_template,
|
235
|
+
jstring reasoning_format,
|
236
|
+
jboolean embedding,
|
237
|
+
jint embd_normalize,
|
238
|
+
jint n_ctx,
|
239
|
+
jint n_batch,
|
240
|
+
jint n_ubatch,
|
241
|
+
jint n_threads,
|
242
|
+
jint n_gpu_layers, // TODO: Support this
|
243
|
+
jboolean flash_attn,
|
244
|
+
jstring cache_type_k,
|
245
|
+
jstring cache_type_v,
|
246
|
+
jboolean use_mlock,
|
247
|
+
jboolean use_mmap,
|
248
|
+
jboolean vocab_only,
|
249
|
+
jstring lora_str,
|
250
|
+
jfloat lora_scaled,
|
251
|
+
jobject lora_list,
|
252
|
+
jfloat rope_freq_base,
|
253
|
+
jfloat rope_freq_scale,
|
254
|
+
jint pooling_type,
|
255
|
+
jobject load_progress_callback
|
256
|
+
) {
|
257
|
+
UNUSED(thiz);
|
258
|
+
|
259
|
+
common_params defaultParams;
|
260
|
+
|
261
|
+
defaultParams.vocab_only = vocab_only;
|
262
|
+
if(vocab_only) {
|
263
|
+
defaultParams.warmup = false;
|
264
|
+
}
|
265
|
+
|
266
|
+
const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
|
267
|
+
defaultParams.model = { model_path_chars };
|
268
|
+
|
269
|
+
const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
|
270
|
+
defaultParams.chat_template = chat_template_chars;
|
271
|
+
|
272
|
+
const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
|
273
|
+
if (strcmp(reasoning_format_chars, "deepseek") == 0) {
|
274
|
+
defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
275
|
+
} else {
|
276
|
+
defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
277
|
+
}
|
278
|
+
|
279
|
+
defaultParams.n_ctx = n_ctx;
|
280
|
+
defaultParams.n_batch = n_batch;
|
281
|
+
defaultParams.n_ubatch = n_ubatch;
|
282
|
+
|
283
|
+
if (pooling_type != -1) {
|
284
|
+
defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
|
285
|
+
}
|
286
|
+
|
287
|
+
defaultParams.embedding = embedding;
|
288
|
+
if (embd_normalize != -1) {
|
289
|
+
defaultParams.embd_normalize = embd_normalize;
|
290
|
+
}
|
291
|
+
if (embedding) {
|
292
|
+
// For non-causal models, batch size must be equal to ubatch size
|
293
|
+
defaultParams.n_ubatch = defaultParams.n_batch;
|
294
|
+
}
|
295
|
+
|
296
|
+
int max_threads = std::thread::hardware_concurrency();
|
297
|
+
// Use 2 threads by default on 4-core devices, 4 threads on more cores
|
298
|
+
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
|
299
|
+
defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
|
300
|
+
|
301
|
+
// defaultParams.n_gpu_layers = n_gpu_layers;
|
302
|
+
defaultParams.flash_attn = flash_attn;
|
303
|
+
|
304
|
+
const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
|
305
|
+
const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
|
306
|
+
defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
|
307
|
+
defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);
|
308
|
+
|
309
|
+
defaultParams.use_mlock = use_mlock;
|
310
|
+
defaultParams.use_mmap = use_mmap;
|
311
|
+
|
312
|
+
defaultParams.rope_freq_base = rope_freq_base;
|
313
|
+
defaultParams.rope_freq_scale = rope_freq_scale;
|
314
|
+
|
315
|
+
auto llama = new rnllama::llama_rn_context();
|
316
|
+
llama->is_load_interrupted = false;
|
317
|
+
llama->loading_progress = 0;
|
318
|
+
|
319
|
+
if (load_progress_callback != nullptr) {
|
320
|
+
defaultParams.progress_callback = [](float progress, void * user_data) {
|
321
|
+
callback_context *cb_ctx = (callback_context *)user_data;
|
322
|
+
JNIEnv *env = cb_ctx->env;
|
323
|
+
auto llama = cb_ctx->llama;
|
324
|
+
jobject callback = cb_ctx->callback;
|
325
|
+
int percentage = (int) (100 * progress);
|
326
|
+
if (percentage > llama->loading_progress) {
|
327
|
+
llama->loading_progress = percentage;
|
328
|
+
jclass callback_class = env->GetObjectClass(callback);
|
329
|
+
jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
|
330
|
+
env->CallVoidMethod(callback, onLoadProgress, percentage);
|
331
|
+
}
|
332
|
+
return !llama->is_load_interrupted;
|
333
|
+
};
|
334
|
+
|
335
|
+
callback_context *cb_ctx = new callback_context;
|
336
|
+
cb_ctx->env = env;
|
337
|
+
cb_ctx->llama = llama;
|
338
|
+
cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
|
339
|
+
defaultParams.progress_callback_user_data = cb_ctx;
|
340
|
+
}
|
341
|
+
|
342
|
+
bool is_model_loaded = llama->loadModel(defaultParams);
|
343
|
+
|
344
|
+
env->ReleaseStringUTFChars(model_path_str, model_path_chars);
|
345
|
+
env->ReleaseStringUTFChars(chat_template, chat_template_chars);
|
346
|
+
env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
|
347
|
+
env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
|
348
|
+
env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
|
349
|
+
|
350
|
+
LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
|
351
|
+
if (is_model_loaded) {
|
352
|
+
if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
|
353
|
+
LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
|
354
|
+
llama_free(llama->ctx);
|
355
|
+
return -1;
|
356
|
+
}
|
357
|
+
context_map[(long) llama->ctx] = llama;
|
358
|
+
} else {
|
359
|
+
llama_free(llama->ctx);
|
360
|
+
}
|
361
|
+
|
362
|
+
std::vector<common_adapter_lora_info> lora;
|
363
|
+
const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
|
364
|
+
if (lora_chars != nullptr && lora_chars[0] != '\0') {
|
365
|
+
common_adapter_lora_info la;
|
366
|
+
la.path = lora_chars;
|
367
|
+
la.scale = lora_scaled;
|
368
|
+
lora.push_back(la);
|
369
|
+
}
|
370
|
+
|
371
|
+
if (lora_list != nullptr) {
|
372
|
+
// lora_adapters: ReadableArray<ReadableMap>
|
373
|
+
int lora_list_size = readablearray::size(env, lora_list);
|
374
|
+
for (int i = 0; i < lora_list_size; i++) {
|
375
|
+
jobject lora_adapter = readablearray::getMap(env, lora_list, i);
|
376
|
+
jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
|
377
|
+
if (path != nullptr) {
|
378
|
+
const char *path_chars = env->GetStringUTFChars(path, nullptr);
|
379
|
+
common_adapter_lora_info la;
|
380
|
+
la.path = path_chars;
|
381
|
+
la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
|
382
|
+
lora.push_back(la);
|
383
|
+
env->ReleaseStringUTFChars(path, path_chars);
|
384
|
+
}
|
385
|
+
}
|
386
|
+
}
|
387
|
+
env->ReleaseStringUTFChars(lora_str, lora_chars);
|
388
|
+
int result = llama->applyLoraAdapters(lora);
|
389
|
+
if (result != 0) {
|
390
|
+
LOGI("[RNLlama] Failed to apply lora adapters");
|
391
|
+
llama_free(llama->ctx);
|
392
|
+
return -1;
|
393
|
+
}
|
394
|
+
|
395
|
+
return reinterpret_cast<jlong>(llama->ctx);
|
396
|
+
}
|
397
|
+
|
398
|
+
|
399
|
+
/**
 * Flags the context registered under `context_ptr` so that its model load is
 * abandoned: the load progress callback installed by initContext returns
 * `!is_load_interrupted`.
 *
 * Unknown handles are ignored.
 *
 * NOTE(review): initContext only registers the context in `context_map` after
 * loadModel() has returned, so it looks like a lookup made while the load is
 * still running cannot succeed — confirm against the Java caller.
 */
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_interruptLoad(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr
) {
    UNUSED(env);
    UNUSED(thiz);
    // find() rather than operator[]: operator[] would insert a null entry for
    // every unknown handle that is queried.
    auto entry = context_map.find((long) context_ptr);
    if (entry != context_map.end() && entry->second) {
        entry->second->is_load_interrupted = true;
    }
}
|
411
|
+
|
412
|
+
/**
 * Builds a map describing the model loaded in the context `context_ptr`.
 *
 * Keys written to the returned map:
 *   - "desc", "size", "nEmbd", "nParams": values reported by llama_model_*.
 *   - "metadata": every model metadata key/value pair, as strings.
 *   - "chatTemplates": { "llamaChat": bool,
 *                        "minja": { "default", "defaultCaps",
 *                                   "toolUse", "toolUseCaps" } }.
 *   - "isChatTemplateSupported": deprecated duplicate of "llamaChat".
 *
 * "defaultCaps" / "toolUseCaps" are only present when the corresponding
 * template object exists.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadModelDetails(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    // Serialises one template's capability flags; shared by the default and
    // the tool_use template so both maps always expose the same keys.
    const auto caps_to_map = [env](const auto &caps) {
        auto caps_map = createWriteableMap(env);
        putBoolean(env, caps_map, "tools", caps.supports_tools);
        putBoolean(env, caps_map, "toolCalls", caps.supports_tool_calls);
        putBoolean(env, caps_map, "parallelToolCalls", caps.supports_parallel_tool_calls);
        putBoolean(env, caps_map, "toolResponses", caps.supports_tool_responses);
        putBoolean(env, caps_map, "systemRole", caps.supports_system_role);
        putBoolean(env, caps_map, "toolCallId", caps.supports_tool_call_id);
        return caps_map;
    };

    // Raw model metadata, copied through fixed-size stack buffers.
    int count = llama_model_meta_count(llama->model);
    auto meta = createWriteableMap(env);
    for (int i = 0; i < count; i++) {
        char key[256];
        llama_model_meta_key_by_index(llama->model, i, key, sizeof(key));
        char val[4096];
        llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val));

        putString(env, meta, key, val);
    }

    auto result = createWriteableMap(env);

    char desc[1024];
    llama_model_desc(llama->model, desc, sizeof(desc));

    putString(env, result, "desc", desc);
    putDouble(env, result, "size", llama_model_size(llama->model));
    putDouble(env, result, "nEmbd", llama_model_n_embd(llama->model));
    putDouble(env, result, "nParams", llama_model_n_params(llama->model));

    auto chat_templates = createWriteableMap(env);
    putBoolean(env, chat_templates, "llamaChat", llama->validateModelChatTemplate(false, nullptr));

    auto minja = createWriteableMap(env);
    putBoolean(env, minja, "default", llama->validateModelChatTemplate(true, nullptr));

    // Guarded like the tool_use template below instead of being dereferenced
    // unconditionally.
    auto default_tmpl = llama->templates.get()->template_default.get();
    if (default_tmpl != nullptr) {
        putMap(env, minja, "defaultCaps", caps_to_map(default_tmpl->original_caps()));
    }

    putBoolean(env, minja, "toolUse", llama->validateModelChatTemplate(true, "tool_use"));
    auto tool_use_tmpl = llama->templates.get()->template_tool_use.get();
    if (tool_use_tmpl != nullptr) {
        putMap(env, minja, "toolUseCaps", caps_to_map(tool_use_tmpl->original_caps()));
    }

    putMap(env, chat_templates, "minja", minja);
    putMap(env, result, "metadata", meta);
    putMap(env, result, "chatTemplates", chat_templates);

    // deprecated
    putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate(false, nullptr));

    return reinterpret_cast<jobject>(result);
}
|
482
|
+
|
483
|
+
/**
 * Formats a chat through the context's Jinja template path and returns the
 * result as a map.
 *
 * On success the map holds "prompt", "chat_format", "grammar",
 * "grammar_lazy", "grammar_triggers" (array of {type, value, token}),
 * "preserved_tokens" and "additional_stops".
 * On failure the exception message is stored under "_error"; keys written
 * before the failure may also be present.
 *
 * All five jstring arguments are read with GetStringUTFChars and released
 * before returning.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getFormattedChatWithJinja(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring messages,
    jstring chat_template,
    jstring json_schema,
    jstring tools,
    jboolean parallel_tool_calls,
    jstring tool_choice
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    const char *messages_chars = env->GetStringUTFChars(messages, nullptr);
    const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
    const char *json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
    const char *tools_chars = env->GetStringUTFChars(tools, nullptr);
    const char *tool_choice_chars = env->GetStringUTFChars(tool_choice, nullptr);

    auto result = createWriteableMap(env);
    try {
        auto formatted = llama->getFormattedChatWithJinja(
            messages_chars,
            tmpl_chars,
            json_schema_chars,
            tools_chars,
            parallel_tool_calls,
            tool_choice_chars
        );
        putString(env, result, "prompt", formatted.prompt.c_str());
        putInt(env, result, "chat_format", static_cast<int>(formatted.format));
        putString(env, result, "grammar", formatted.grammar.c_str());
        putBoolean(env, result, "grammar_lazy", formatted.grammar_lazy);

        auto grammar_triggers = createWritableArray(env);
        for (const auto &trigger : formatted.grammar_triggers) {
            auto trigger_map = createWriteableMap(env);
            putInt(env, trigger_map, "type", trigger.type);
            putString(env, trigger_map, "value", trigger.value.c_str());
            putInt(env, trigger_map, "token", trigger.token);
            pushMap(env, grammar_triggers, trigger_map);
        }
        putArray(env, result, "grammar_triggers", grammar_triggers);

        auto preserved_tokens = createWritableArray(env);
        for (const auto &token : formatted.preserved_tokens) {
            pushString(env, preserved_tokens, token.c_str());
        }
        putArray(env, result, "preserved_tokens", preserved_tokens);

        auto additional_stops = createWritableArray(env);
        for (const auto &stop : formatted.additional_stops) {
            pushString(env, additional_stops, stop.c_str());
        }
        putArray(env, result, "additional_stops", additional_stops);
    } catch (const std::exception &e) {
        // Catch the whole std::exception hierarchy, not only
        // std::runtime_error: a C++ exception must never unwind through a
        // JNI frame, and the strings acquired above still have to be released.
        LOGI("[RNLlama] Error: %s", e.what());
        putString(env, result, "_error", e.what());
    }
    env->ReleaseStringUTFChars(tools, tools_chars);
    env->ReleaseStringUTFChars(messages, messages_chars);
    env->ReleaseStringUTFChars(chat_template, tmpl_chars);
    env->ReleaseStringUTFChars(json_schema, json_schema_chars);
    env->ReleaseStringUTFChars(tool_choice, tool_choice_chars);
    return reinterpret_cast<jobject>(result);
}
|
548
|
+
|
549
|
+
/**
 * Renders `messages` with `chat_template` through the context's
 * getFormattedChat() and hands the resulting text back as a Java string.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getFormattedChat(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring messages,
    jstring chat_template
) {
    UNUSED(thiz);
    auto wrapper = context_map[(long) context_ptr];

    // Borrow both Java strings only for the duration of the formatting call.
    const char *messages_utf = env->GetStringUTFChars(messages, nullptr);
    const char *template_utf = env->GetStringUTFChars(chat_template, nullptr);
    const std::string rendered = wrapper->getFormattedChat(messages_utf, template_utf);
    env->ReleaseStringUTFChars(messages, messages_utf);
    env->ReleaseStringUTFChars(chat_template, template_utf);

    jstring rendered_jstr = env->NewStringUTF(rendered.c_str());
    return rendered_jstr;
}
|
570
|
+
|
571
|
+
/**
 * Restores a saved session file into the context `context_ptr`.
 *
 * `llama->embd` is resized to `params.n_ctx` to receive the stored tokens and
 * then trimmed to the number actually loaded.
 *
 * Returns a map with "tokens_loaded" and "prompt" (the loaded tokens rendered
 * as text) on success, or a map with "error" when llama_state_load_file
 * fails.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadSession(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring path
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];
    const char *path_chars = env->GetStringUTFChars(path, nullptr);

    auto result = createWriteableMap(env);
    size_t n_token_count_out = 0;
    llama->embd.resize(llama->params.n_ctx);
    // Advertise size(), not capacity(): tokens written past size() are not
    // vector elements, and the resize() below would overwrite them with
    // value-initialised entries whenever the loaded count exceeded size().
    const bool loaded = llama_state_load_file(
        llama->ctx,
        path_chars,
        llama->embd.data(),
        llama->embd.size(),
        &n_token_count_out
    );
    // The path is not needed past this point on either outcome.
    env->ReleaseStringUTFChars(path, path_chars);

    if (!loaded) {
        putString(env, result, "error", "Failed to load session");
        return reinterpret_cast<jobject>(result);
    }
    llama->embd.resize(n_token_count_out);

    const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
    putInt(env, result, "tokens_loaded", n_token_count_out);
    putString(env, result, "prompt", text.c_str());
    return reinterpret_cast<jobject>(result);
}
|
599
|
+
|
600
|
+
/**
 * Writes the context state plus a prefix of its token list to `path`.
 *
 * `size` selects how many tokens are stored; any value outside
 * (0, token count] stores every token. Returns -1 when
 * llama_state_save_file fails, otherwise the length of the full token list.
 */
JNIEXPORT jint JNICALL
Java_com_rnllama_LlamaContext_saveSession(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring path,
    jint size
) {
    UNUSED(thiz);
    auto wrapper = context_map[(long) context_ptr];

    const char *file_path = env->GetStringUTFChars(path, nullptr);

    // Work on a snapshot of the token list.
    std::vector<llama_token> snapshot = wrapper->embd;
    int available = snapshot.size();
    int to_store = available;
    if (size > 0 && size <= available) {
        to_store = size;
    }

    const bool saved = llama_state_save_file(wrapper->ctx, file_path, snapshot.data(), to_store);
    env->ReleaseStringUTFChars(path, file_path);

    if (!saved) {
        return -1;
    }
    return snapshot.size();
}
|
624
|
+
|
625
|
+
/**
 * Converts per-token probability records into a writable array.
 *
 * Each array element is a map with:
 *   - "content": the formatted text of the token (`prob.tok`);
 *   - "probs": an array of { "tok_str", "prob" } maps, one per entry in
 *     `prob.probs`.
 *
 * `probs` is taken by const reference so that callers do not pay for a deep
 * copy of the vector on every call; it is only read.
 */
static inline jobject tokenProbsToMap(
    JNIEnv *env,
    rnllama::llama_rn_context *llama,
    const std::vector<rnllama::completion_token_output> &probs
) {
    auto result = createWritableArray(env);
    for (const auto &prob : probs) {
        auto probsForToken = createWritableArray(env);
        for (const auto &p : prob.probs) {
            std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
            auto probResult = createWriteableMap(env);
            putString(env, probResult, "tok_str", tokStr.c_str());
            putDouble(env, probResult, "prob", p.prob);
            pushMap(env, probsForToken, probResult);
        }
        std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
        auto tokenResult = createWriteableMap(env);
        putString(env, tokenResult, "content", tokStr.c_str());
        putArray(env, tokenResult, "probs", probsForToken);
        pushMap(env, result, tokenResult);
    }
    return result;
}
|
648
|
+
|
649
|
+
/**
 * Runs a blocking completion on the context `context_ptr`.
 *
 * Steps, in order:
 *   1. copy the prompt and every sampling parameter into `llama->params`;
 *   2. resolve the grammar (explicit `grammar`, else one derived from
 *      `json_schema`), preserved tokens and grammar triggers;
 *   3. rebuild logit biases, DRY sequence breakers and stop strings;
 *   4. generate tokens until the context reports no next token or is
 *      interrupted, calling `partial_completion_callback.onPartialCompletion`
 *      with each chunk of text that cannot be part of a stop string;
 *   5. parse the generated text for reasoning content / tool calls and return
 *      a map with the text, counters, stop flags and timings.
 *
 * Failures detected before generation starts (trigger word not preserved,
 * invalid JSON schema, sampling initialisation) are reported as a map holding
 * a single "error" string; no C++ exception is allowed to leave this function
 * because unwinding through a JNI frame is not supported.
 *
 * Every UTF string obtained from the JVM is released as soon as its contents
 * have been copied, so all return paths are leak-free.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_doCompletion(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring prompt,
    jint chat_format,
    jstring grammar,
    jstring json_schema,
    jboolean grammar_lazy,
    jobject grammar_triggers,
    jobject preserved_tokens,
    jfloat temperature,
    jint n_threads,
    jint n_predict,
    jint n_probs,
    jint penalty_last_n,
    jfloat penalty_repeat,
    jfloat penalty_freq,
    jfloat penalty_present,
    jfloat mirostat,
    jfloat mirostat_tau,
    jfloat mirostat_eta,
    jint top_k,
    jfloat top_p,
    jfloat min_p,
    jfloat xtc_threshold,
    jfloat xtc_probability,
    jfloat typical_p,
    jint seed,
    jobjectArray stop,
    jboolean ignore_eos,
    jobjectArray logit_bias,
    jfloat dry_multiplier,
    jfloat dry_base,
    jint dry_allowed_length,
    jint dry_penalty_last_n,
    jfloat top_n_sigma,
    jobjectArray dry_sequence_breakers,
    jobject partial_completion_callback
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    // Builds the { "error": message } map used by every early-failure path.
    const auto make_error = [env](const char *message) {
        auto error_result = createWriteableMap(env);
        putString(env, error_result, "error", message);
        return reinterpret_cast<jobject>(error_result);
    };

    llama->rewind();

    // The prompt is copied into params, so the JVM buffer can go back at once.
    const char *prompt_chars = env->GetStringUTFChars(prompt, nullptr);
    llama->params.prompt = prompt_chars;
    env->ReleaseStringUTFChars(prompt, prompt_chars);

    llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;

    int max_threads = std::thread::hardware_concurrency();
    // Use 2 threads by default on 4-core devices, 4 threads on more cores
    int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
    llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

    llama->params.n_predict = n_predict;
    llama->params.sampling.ignore_eos = ignore_eos;

    auto & sparams = llama->params.sampling;
    sparams.temp = temperature;
    sparams.penalty_last_n = penalty_last_n;
    sparams.penalty_repeat = penalty_repeat;
    sparams.penalty_freq = penalty_freq;
    sparams.penalty_present = penalty_present;
    sparams.mirostat = mirostat;
    sparams.mirostat_tau = mirostat_tau;
    sparams.mirostat_eta = mirostat_eta;
    sparams.top_k = top_k;
    sparams.top_p = top_p;
    sparams.min_p = min_p;
    sparams.typ_p = typical_p;
    sparams.n_probs = n_probs;
    sparams.xtc_threshold = xtc_threshold;
    sparams.xtc_probability = xtc_probability;
    sparams.dry_multiplier = dry_multiplier;
    sparams.dry_base = dry_base;
    sparams.dry_allowed_length = dry_allowed_length;
    sparams.dry_penalty_last_n = dry_penalty_last_n;
    sparams.top_n_sigma = top_n_sigma;

    // grammar: remember whether one was supplied, copy it, release the buffer
    const char *grammar_chars = env->GetStringUTFChars(grammar, nullptr);
    const bool has_grammar = grammar_chars && grammar_chars[0] != '\0';
    if (has_grammar) {
        sparams.grammar = grammar_chars;
    }
    if (grammar_chars) {
        env->ReleaseStringUTFChars(grammar, grammar_chars);
    }
    sparams.grammar_lazy = grammar_lazy;

    if (preserved_tokens != nullptr) {
        int preserved_tokens_size = readablearray::size(env, preserved_tokens);
        for (int i = 0; i < preserved_tokens_size; i++) {
            jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
            // Fetch the characters once and release them at the end of the
            // iteration (they used to be fetched twice and never released).
            const char *preserved_chars = env->GetStringUTFChars(preserved_token, nullptr);
            auto ids = common_tokenize(llama->ctx, preserved_chars, /* add_special= */ false, /* parse_special= */ true);
            if (ids.size() == 1) {
                sparams.preserved_tokens.insert(ids[0]);
            } else {
                LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", preserved_chars);
            }
            env->ReleaseStringUTFChars(preserved_token, preserved_chars);
        }
    }

    if (grammar_triggers != nullptr) {
        int grammar_triggers_size = readablearray::size(env, grammar_triggers);
        for (int i = 0; i < grammar_triggers_size; i++) {
            auto trigger_map = readablearray::getMap(env, grammar_triggers, i);
            const auto type = static_cast<common_grammar_trigger_type>(readablemap::getInt(env, trigger_map, "type", 0));
            jstring trigger_word = readablemap::getString(env, trigger_map, "value", nullptr);

            // Copy the trigger text into an owned string so the JVM buffer
            // can be released on every path; a missing "value" becomes "".
            std::string word;
            if (trigger_word != nullptr) {
                const char *word_chars = env->GetStringUTFChars(trigger_word, nullptr);
                if (word_chars != nullptr) {
                    word = word_chars;
                    env->ReleaseStringUTFChars(trigger_word, word_chars);
                }
            }

            if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
                auto ids = common_tokenize(llama->ctx, word, /* add_special= */ false, /* parse_special= */ true);
                if (ids.size() == 1) {
                    auto token = ids[0];
                    if (std::find(sparams.preserved_tokens.begin(), sparams.preserved_tokens.end(), (llama_token) token) == sparams.preserved_tokens.end()) {
                        // Reported through the result map instead of a throw.
                        return make_error("Grammar trigger word should be marked as preserved token");
                    }
                    // A single-token word is promoted to a token trigger.
                    common_grammar_trigger trigger;
                    trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
                    trigger.value = word;
                    trigger.token = token;
                    sparams.grammar_triggers.push_back(std::move(trigger));
                } else {
                    sparams.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
                }
            } else {
                common_grammar_trigger trigger;
                trigger.type = type;
                trigger.value = word;
                if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
                    const auto token = (llama_token) readablemap::getInt(env, trigger_map, "token", 0);
                    trigger.token = token;
                }
                sparams.grammar_triggers.push_back(std::move(trigger));
            }
        }
    }

    // A JSON schema is only consulted when no explicit grammar was supplied.
    const char *json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
    bool schema_failed = false;
    std::string schema_error;
    if (!has_grammar && json_schema_chars && json_schema_chars[0] != '\0') {
        try {
            auto schema = json::parse(json_schema_chars);
            sparams.grammar = json_schema_to_grammar(schema);
        } catch (const std::exception &e) {
            schema_failed = true;
            schema_error = e.what();
        }
    }
    env->ReleaseStringUTFChars(json_schema, json_schema_chars);
    if (schema_failed) {
        LOGI("[RNLlama] Error: %s", schema_error.c_str());
        return make_error(schema_error.c_str());
    }

    const llama_model * model = llama_get_model(llama->ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    sparams.logit_bias.clear();
    if (ignore_eos) {
        sparams.logit_bias[llama_vocab_eos(vocab)].bias = -INFINITY;
    }

    // dry break seq
    jint size = env->GetArrayLength(dry_sequence_breakers);
    std::vector<std::string> dry_sequence_breakers_vector;

    for (jint i = 0; i < size; i++) {
        jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
        const char *nativeString = env->GetStringUTFChars(javaString, 0);
        dry_sequence_breakers_vector.push_back(std::string(nativeString));
        env->ReleaseStringUTFChars(javaString, nativeString);
        env->DeleteLocalRef(javaString);
    }

    sparams.dry_sequence_breakers = dry_sequence_breakers_vector;

    // logit bias: each element is a [token, bias] pair; a bias of 0 bans the token
    const int n_vocab = llama_vocab_n_tokens(vocab);
    jsize logit_bias_len = env->GetArrayLength(logit_bias);

    for (jsize i = 0; i < logit_bias_len; i++) {
        jdoubleArray el = (jdoubleArray) env->GetObjectArrayElement(logit_bias, i);
        if (el && env->GetArrayLength(el) == 2) {
            jdouble* doubleArray = env->GetDoubleArrayElements(el, 0);

            llama_token tok = static_cast<llama_token>(doubleArray[0]);
            if (tok >= 0 && tok < n_vocab) {
                if (doubleArray[1] != 0) {  // If the second element is not false (0)
                    sparams.logit_bias[tok].bias = doubleArray[1];
                } else {
                    sparams.logit_bias[tok].bias = -INFINITY;
                }
            }

            env->ReleaseDoubleArrayElements(el, doubleArray, 0);
        }
        env->DeleteLocalRef(el);
    }

    llama->params.antiprompt.clear();
    int stop_len = env->GetArrayLength(stop);
    for (int i = 0; i < stop_len; i++) {
        jstring stop_str = (jstring) env->GetObjectArrayElement(stop, i);
        const char *stop_chars = env->GetStringUTFChars(stop_str, nullptr);
        llama->params.antiprompt.push_back(stop_chars);
        env->ReleaseStringUTFChars(stop_str, stop_chars);
        // Drop the element reference like the other array loops above do.
        env->DeleteLocalRef(stop_str);
    }

    if (!llama->initSampling()) {
        return make_error("Failed to initialize sampling");
    }
    llama->beginCompletion();
    llama->loadPrompt();

    size_t sent_count = 0;
    size_t sent_token_probs_index = 0;
    // Resolved lazily on the first chunk and reused for the rest of the run.
    jmethodID onPartialCompletion = nullptr;

    while (llama->has_next_token && !llama->is_interrupted) {
        const rnllama::completion_token_output token_with_probs = llama->doCompletion();
        if (token_with_probs.tok == -1 || llama->incomplete) {
            continue;
        }
        const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);

        size_t pos = std::min(sent_count, llama->generated_text.size());

        // Only the not-yet-sent tail is searched for stop strings.
        const std::string str_test = llama->generated_text.substr(pos);
        bool is_stop_full = false;
        size_t stop_pos =
            llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
        if (stop_pos != std::string::npos) {
            // A complete stop string was produced: cut it off the output.
            is_stop_full = true;
            llama->generated_text.erase(
                llama->generated_text.begin() + pos + stop_pos,
                llama->generated_text.end());
            pos = std::min(sent_count, llama->generated_text.size());
        } else {
            is_stop_full = false;
            stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
                rnllama::STOP_PARTIAL);
        }

        if (
            stop_pos == std::string::npos ||
            // Send rest of the text if we are at the end of the generation
            (!llama->has_next_token && !is_stop_full && stop_pos > 0)
        ) {
            const std::string to_send = llama->generated_text.substr(pos, std::string::npos);

            sent_count += to_send.size();

            std::vector<rnllama::completion_token_output> probs_output = {};

            auto tokenResult = createWriteableMap(env);
            putString(env, tokenResult, "token", to_send.c_str());

            if (llama->params.sampling.n_probs > 0) {
                const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
                size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
                if (probs_pos < probs_stop_pos) {
                    probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
                }
                sent_token_probs_index = probs_stop_pos;

                putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
            }

            if (onPartialCompletion == nullptr) {
                jclass cb_class = env->GetObjectClass(partial_completion_callback);
                onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
                env->DeleteLocalRef(cb_class);
            }
            env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
        }
    }

    llama_perf_context_print(llama->ctx);
    llama->is_predicting = false;

    auto toolCalls = createWritableArray(env);
    std::string reasoningContent = "";
    std::string content;
    auto toolCallsSize = 0;
    if (!llama->is_interrupted) {
        try {
            common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
            if (!message.reasoning_content.empty()) {
                reasoningContent = message.reasoning_content;
            }
            content = message.content;
            for (const auto &tc : message.tool_calls) {
                auto toolCall = createWriteableMap(env);
                putString(env, toolCall, "type", "function");
                auto functionMap = createWriteableMap(env);
                putString(env, functionMap, "name", tc.name.c_str());
                putString(env, functionMap, "arguments", tc.arguments.c_str());
                putMap(env, toolCall, "function", functionMap);
                if (!tc.id.empty()) {
                    putString(env, toolCall, "id", tc.id.c_str());
                }
                pushMap(env, toolCalls, toolCall);
                toolCallsSize++;
            }
        } catch (const std::exception &e) {
            // Best effort: text that cannot be parsed is still returned as
            // the raw "text" below, without content / tool_calls.
        }
    }

    auto result = createWriteableMap(env);
    putString(env, result, "text", llama->generated_text.c_str());
    if (!content.empty()) {
        putString(env, result, "content", content.c_str());
    }
    if (!reasoningContent.empty()) {
        putString(env, result, "reasoning_content", reasoningContent.c_str());
    }
    if (toolCallsSize > 0) {
        putArray(env, result, "tool_calls", toolCalls);
    }
    putArray(env, result, "completion_probabilities", tokenProbsToMap(env, llama, llama->generated_token_probs));
    putInt(env, result, "tokens_predicted", llama->num_tokens_predicted);
    putInt(env, result, "tokens_evaluated", llama->num_prompt_tokens);
    putInt(env, result, "truncated", llama->truncated);
    putInt(env, result, "stopped_eos", llama->stopped_eos);
    putInt(env, result, "stopped_word", llama->stopped_word);
    putInt(env, result, "stopped_limit", llama->stopped_limit);
    putString(env, result, "stopping_word", llama->stopping_word.c_str());
    putInt(env, result, "tokens_cached", llama->n_past);

    const auto timings_token = llama_perf_context(llama -> ctx);

    auto timingsResult = createWriteableMap(env);
    putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
    putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
    putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
    putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
    putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
    putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);

    putMap(env, result, "timings", timingsResult);

    return reinterpret_cast<jobject>(result);
}
|
989
|
+
|
990
|
+
/**
 * Asks the running completion of the context `context_ptr` to stop by setting
 * its `is_interrupted` flag, which the generation loop in doCompletion checks
 * before producing each token.
 *
 * Unknown handles are ignored, matching interruptLoad.
 */
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_stopCompletion(
        JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(env);
    UNUSED(thiz);
    // find() avoids inserting a null entry for an unknown handle, and the
    // null check avoids dereferencing one.
    auto entry = context_map.find((long) context_ptr);
    if (entry != context_map.end() && entry->second) {
        entry->second->is_interrupted = true;
    }
}
|
998
|
+
|
999
|
+
/**
 * Reports the `is_predicting` flag of the context `context_ptr`.
 *
 * Returns false for a handle that is not registered in `context_map` instead
 * of dereferencing a null context.
 */
JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isPredicting(
        JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(env);
    UNUSED(thiz);
    auto entry = context_map.find((long) context_ptr);
    if (entry == context_map.end() || !entry->second) {
        return false;
    }
    return entry->second->is_predicting;
}
|
1007
|
+
|
1008
|
+
/**
 * Tokenizes `text` with the context `context_ptr` (common_tokenize with its
 * third argument set to false) and returns the token ids as a writable array
 * of integers.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_tokenize(
        JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
    UNUSED(thiz);
    auto wrapper = context_map[(long) context_ptr];

    const char *utf8_text = env->GetStringUTFChars(text, nullptr);
    const std::vector<llama_token> ids = common_tokenize(wrapper->ctx, utf8_text, false);

    // Copy the ids into a bridge array, preserving their order.
    jobject id_array = createWritableArray(env);
    for (size_t idx = 0; idx < ids.size(); ++idx) {
        pushInt(env, id_array, ids[idx]);
    }

    env->ReleaseStringUTFChars(text, utf8_text);
    return id_array;
}
|
1030
|
+
|
1031
|
+
/**
 * Converts the token ids in `tokens` back to text with the context
 * `context_ptr` and returns it as a Java string.
 *
 * Returns nullptr when the JVM cannot provide the array elements; in that
 * case GetIntArrayElements has already raised the Java exception.
 */
JNIEXPORT jstring JNICALL
Java_com_rnllama_LlamaContext_detokenize(
        JNIEnv *env, jobject thiz, jlong context_ptr, jintArray tokens) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    jsize tokens_len = env->GetArrayLength(tokens);
    jint *tokens_ptr = env->GetIntArrayElements(tokens, 0);
    if (tokens_ptr == nullptr) {
        return nullptr;
    }

    // Range construction copies every id in one step.
    std::vector<llama_token> toks(tokens_ptr, tokens_ptr + tokens_len);

    // The elements were only read, so JNI_ABORT skips the copy-back.
    env->ReleaseIntArrayElements(tokens, tokens_ptr, JNI_ABORT);

    auto text = rnllama::tokens_to_str(llama->ctx, toks.cbegin(), toks.cend());

    return env->NewStringUTF(text.c_str());
}
|
1050
|
+
|
1051
|
+
/**
 * JNI entry point for LlamaContext.isEmbeddingEnabled.
 *
 * @return the `params.embedding` flag of the context registered in
 *         `context_map` under `context_ptr`.
 */
JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    UNUSED(env);
    const auto rn_ctx = context_map[(long) context_ptr];
    return rn_ctx->params.embedding;
}
|
1059
|
+
|
1060
|
+
/**
 * JNI entry point for LlamaContext.embedding.
 *
 * Runs `text` through the context registered under `context_ptr` with
 * `n_predict` set to 0 and returns a writable map containing:
 *   - "embedding":     array of doubles from getEmbedding()
 *   - "prompt_tokens": array of token pieces for the tokens in `llama->embd`
 * On sampling-initialization failure the map contains only an "error" string.
 *
 * @param embd_normalize  overrides the context's `embd_normalize` setting
 *                        unless it is -1.
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_embedding(
    JNIEnv *env, jobject thiz,
    jlong context_ptr,
    jstring text,
    jint embd_normalize
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    common_params embdParams;
    embdParams.embedding = true;
    embdParams.embd_normalize = llama->params.embd_normalize;
    if (embd_normalize != -1) {
        embdParams.embd_normalize = embd_normalize;
    }

    const char *text_chars = env->GetStringUTFChars(text, nullptr);

    llama->rewind();

    llama_perf_context_reset(llama->ctx);

    llama->params.prompt = text_chars;

    llama->params.n_predict = 0;

    auto result = createWriteableMap(env);
    if (!llama->initSampling()) {
        // Release the borrowed UTF chars on the failure path too; the
        // early return previously leaked them.
        env->ReleaseStringUTFChars(text, text_chars);
        putString(env, result, "error", "Failed to initialize sampling");
        return reinterpret_cast<jobject>(result);
    }

    llama->beginCompletion();
    llama->loadPrompt();
    llama->doCompletion();

    std::vector<float> embedding = llama->getEmbedding(embdParams);

    auto embeddings = createWritableArray(env);
    for (const auto &val : embedding) {
        pushDouble(env, embeddings, (double) val);
    }
    putArray(env, result, "embedding", embeddings);

    auto promptTokens = createWritableArray(env);
    for (const auto &tok : llama->embd) {
        pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
    }
    putArray(env, result, "prompt_tokens", promptTokens);

    env->ReleaseStringUTFChars(text, text_chars);
    return result;
}
|
1114
|
+
|
1115
|
+
/**
 * JNI entry point for LlamaContext.bench.
 *
 * Forwards `pp`, `tg`, `pl` and `nr` to the `bench()` method of the context
 * registered under `context_ptr` and returns its string result as a Java
 * string.
 */
JNIEXPORT jstring JNICALL
Java_com_rnllama_LlamaContext_bench(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jint pp,
    jint tg,
    jint pl,
    jint nr
) {
    UNUSED(thiz);
    auto rn_ctx = context_map[(long) context_ptr];
    const std::string report = rn_ctx->bench(pp, tg, pl, nr);
    return env->NewStringUTF(report.c_str());
}
|
1130
|
+
|
1131
|
+
/**
 * JNI entry point for LlamaContext.applyLoraAdapters.
 *
 * Reads each map in `loraAdapters` ("path" string, optional "scaled" float
 * defaulting to 1.0), builds the adapter list and applies it to the context
 * registered under `context_ptr`. Entries without a "path" are skipped.
 *
 * @return the value returned by the context's applyLoraAdapters().
 */
JNIEXPORT jint JNICALL
Java_com_rnllama_LlamaContext_applyLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    // lora_adapters: ReadableArray<ReadableMap>
    std::vector<common_adapter_lora_info> lora_adapters;
    int lora_adapters_size = readablearray::size(env, loraAdapters);
    for (int i = 0; i < lora_adapters_size; i++) {
        jobject lora_adapter = readablearray::getMap(env, loraAdapters, i);
        jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
        if (path == nullptr) {
            continue;
        }

        const char *path_chars = env->GetStringUTFChars(path, nullptr);
        if (path_chars == nullptr) {
            // The JVM could not provide the characters; skip this entry.
            continue;
        }

        common_adapter_lora_info la;
        // Copy the path while the buffer is still valid. The buffer was
        // previously released before this assignment, so the copy read
        // memory already handed back to the JVM.
        la.path = path_chars;
        env->ReleaseStringUTFChars(path, path_chars);

        la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
        lora_adapters.push_back(la);
    }
    return llama->applyLoraAdapters(lora_adapters);
}
|
1155
|
+
|
1156
|
+
/**
 * JNI entry point for LlamaContext.removeLoraAdapters.
 *
 * Calls removeLoraAdapters() on the context registered in `context_map`
 * under `context_ptr`.
 */
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_removeLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    UNUSED(env);
    context_map[(long) context_ptr]->removeLoraAdapters();
}
|
1164
|
+
|
1165
|
+
/**
 * JNI entry point for LlamaContext.getLoadedLoraAdapters.
 *
 * @return a writable array with one map per adapter reported by the
 *         context's getLoadedLoraAdapters(), each holding "path" (string)
 *         and "scaled" (double).
 */
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    auto rn_ctx = context_map[(long) context_ptr];
    auto adapters = rn_ctx->getLoadedLoraAdapters();

    auto out = createWritableArray(env);
    for (auto it = adapters.begin(); it != adapters.end(); ++it) {
        common_adapter_lora_info &adapter = *it;
        auto entry = createWriteableMap(env);
        putString(env, entry, "path", adapter.path.c_str());
        putDouble(env, entry, "scaled", adapter.scale);
        pushMap(env, out, entry);
    }
    return out;
}
|
1180
|
+
|
1181
|
+
/**
 * JNI entry point for LlamaContext.freeContext.
 *
 * Removes the context registered in `context_map` under `context_ptr` and
 * deletes it.
 */
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_freeContext(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(env);
    UNUSED(thiz);
    const long key = (long) context_ptr;
    auto llama = context_map[key];
    // Erase with the same key used for the lookup so the map can never keep
    // a pointer to the object deleted below. This also avoids dereferencing
    // `llama` when the handle is unknown and the lookup yields null.
    context_map.erase(key);
    delete llama;
}
|
1190
|
+
|
1191
|
+
// State handed to rnllama_log_callback_to_j through llama_log_set's user-data
// pointer.
struct log_callback_context {
    // VM used by the callback to obtain (or attach) a JNIEnv for the calling thread.
    JavaVM *jvm;
    // Java object whose emitNativeLog(String, String) method receives each message;
    // setupLog stores a global reference here.
    jobject callback;
};
|
1195
|
+
|
1196
|
+
/**
 * Log callback that mirrors each message to the Android log and forwards it
 * to Java.
 *
 * @param level  log level; mapped to an Android priority and to the level
 *               string "error", "info", "warn" or "" passed to Java.
 * @param text   message text.
 * @param data   pointer to a log_callback_context holding the JavaVM and the
 *               Java callback object.
 *
 * The Java side is reached through `emitNativeLog(String, String)` on the
 * callback object. If the calling thread is not attached to the VM it is
 * attached for the duration of the call and detached afterwards.
 */
static void rnllama_log_callback_to_j(lm_ggml_log_level level, const char * text, void * data) {
    auto level_c = "";
    int android_prio = ANDROID_LOG_DEFAULT;
    if (level == LM_GGML_LOG_LEVEL_ERROR) {
        android_prio = ANDROID_LOG_ERROR;
        level_c = "error";
    } else if (level == LM_GGML_LOG_LEVEL_INFO) {
        android_prio = ANDROID_LOG_INFO;
        level_c = "info";
    } else if (level == LM_GGML_LOG_LEVEL_WARN) {
        android_prio = ANDROID_LOG_WARN;
        level_c = "warn";
    }
    // Never pass the message itself as the printf format string: a '%' in
    // the log text would be interpreted as a conversion specifier.
    __android_log_print(android_prio, TAG, "%s", text);

    log_callback_context *cb_ctx = (log_callback_context *) data;
    if (cb_ctx == nullptr) {
        return;
    }

    JNIEnv *env = nullptr;
    bool need_detach = false;
    int getEnvResult = cb_ctx->jvm->GetEnv((void**)&env, JNI_VERSION_1_6);

    if (getEnvResult == JNI_EDETACHED) {
        if (cb_ctx->jvm->AttachCurrentThread(&env, nullptr) == JNI_OK) {
            need_detach = true;
        } else {
            return;
        }
    } else if (getEnvResult != JNI_OK) {
        return;
    }

    jobject callback = cb_ctx->callback;
    jclass cb_class = env->GetObjectClass(callback);
    jmethodID emitNativeLog = env->GetMethodID(cb_class, "emitNativeLog", "(Ljava/lang/String;Ljava/lang/String;)V");

    if (emitNativeLog != nullptr) {
        jstring level_str = env->NewStringUTF(level_c);
        jstring text_str = env->NewStringUTF(text);
        env->CallVoidMethod(callback, emitNativeLog, level_str, text_str);
        env->DeleteLocalRef(level_str);
        env->DeleteLocalRef(text_str);
    } else {
        // Method lookup failed and left an exception pending; clear it so
        // later JNI calls on this thread are not made with it outstanding.
        env->ExceptionClear();
    }

    // Drop the class reference explicitly: on an already-attached thread the
    // local refs created here are not freed until the enclosing native call
    // returns, and this callback can fire many times before that.
    env->DeleteLocalRef(cb_class);

    if (need_detach) {
        cb_ctx->jvm->DetachCurrentThread();
    }
}
|
1241
|
+
|
1242
|
+
/**
 * JNI entry point for LlamaContext.setupLog.
 *
 * Allocates a log_callback_context holding the JavaVM and a global reference
 * to `logCallback`, then installs rnllama_log_callback_to_j as the llama log
 * callback with that context as user data.
 *
 * NOTE(review): this function never frees the allocated context or the
 * global reference; confirm whether that is intended.
 */
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_setupLog(JNIEnv *env, jobject thiz, jobject logCallback) {
    UNUSED(thiz);

    JavaVM *vm;
    env->GetJavaVM(&vm);

    auto *log_ctx = new log_callback_context;
    log_ctx->jvm = vm;
    // A global reference keeps the callback object usable after this call returns.
    log_ctx->callback = env->NewGlobalRef(logCallback);

    llama_log_set(rnllama_log_callback_to_j, log_ctx);
}
|
1255
|
+
|
1256
|
+
/**
 * JNI entry point for LlamaContext.unsetLog.
 *
 * Reinstalls rnllama_log_callback_default as the llama log callback, with no
 * user data.
 */
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_unsetLog(JNIEnv *env, jobject thiz) {
    UNUSED(thiz);
    UNUSED(env);
    llama_log_set(rnllama_log_callback_default, nullptr);
}
|
1262
|
+
|
1263
|
+
} // extern "C"
|