cui-llama.rn 1.5.0 → 1.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (309)
  1. package/LICENSE +20 -20
  2. package/README.md +317 -319
  3. package/android/build.gradle +116 -116
  4. package/android/gradle.properties +5 -5
  5. package/android/src/main/AndroidManifest.xml +4 -4
  6. package/android/src/main/CMakeLists.txt +124 -124
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +645 -645
  8. package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
  10. package/android/src/main/jni-utils.h +100 -100
  11. package/android/src/main/jni.cpp +1263 -1263
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  14. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  15. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  16. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  17. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  20. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
  21. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
  22. package/cpp/README.md +4 -4
  23. package/cpp/ggml-llama-sim.metallib +0 -0
  24. package/cpp/ggml-llama.metallib +0 -0
  25. package/cpp/ggml-metal-impl.h +597 -597
  26. package/cpp/ggml-metal.m +4 -0
  27. package/cpp/ggml.h +1 -1
  28. package/cpp/rn-llama.cpp +873 -873
  29. package/cpp/rn-llama.h +138 -138
  30. package/cpp/sampling.h +107 -107
  31. package/cpp/unicode-data.cpp +7034 -7034
  32. package/cpp/unicode-data.h +20 -20
  33. package/cpp/unicode.cpp +849 -849
  34. package/cpp/unicode.h +66 -66
  35. package/ios/CMakeLists.txt +116 -108
  36. package/ios/RNLlama.h +7 -7
  37. package/ios/RNLlama.mm +418 -405
  38. package/ios/RNLlamaContext.h +57 -57
  39. package/ios/RNLlamaContext.mm +835 -835
  40. package/ios/rnllama.xcframework/Info.plist +74 -74
  41. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
  42. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
  43. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +677 -0
  44. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  45. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  46. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  47. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  48. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  49. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  50. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  51. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  52. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  53. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  54. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
  55. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
  56. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  57. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  58. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  59. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  60. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  61. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2222 -0
  62. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
  63. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  64. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  65. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  66. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +265 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +409 -0
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1434 -0
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +128 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +14 -0
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +802 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  103. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
  104. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
  129. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
  130. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  132. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  133. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
  134. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
  135. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  136. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  137. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  138. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
  139. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  140. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  141. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
  142. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  143. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  144. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
  145. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  146. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  147. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  148. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
  149. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
  150. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  151. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
  152. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
  153. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  154. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
  155. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  156. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  162. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
  163. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
  164. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +677 -0
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  175. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
  176. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
  177. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  178. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  179. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  180. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  181. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  182. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2222 -0
  183. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
  184. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  185. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  186. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  187. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
  188. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
  189. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
  190. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +265 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +409 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1434 -0
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +128 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +14 -0
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
  218. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +802 -0
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  222. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
  223. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  224. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  259. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
  260. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  261. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  262. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
  263. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  264. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  265. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
  266. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  267. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  268. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  269. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
  270. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
  271. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
  274. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  275. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
  276. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  277. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  278. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
  279. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  280. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  281. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  282. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  283. package/jest/mock.js +203 -203
  284. package/lib/commonjs/NativeRNLlama.js +1 -2
  285. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  286. package/lib/commonjs/chat.js.map +1 -1
  287. package/lib/commonjs/grammar.js +12 -31
  288. package/lib/commonjs/grammar.js.map +1 -1
  289. package/lib/commonjs/index.js +47 -47
  290. package/lib/commonjs/index.js.map +1 -1
  291. package/lib/commonjs/package.json +1 -0
  292. package/lib/module/NativeRNLlama.js +2 -0
  293. package/lib/module/NativeRNLlama.js.map +1 -1
  294. package/lib/module/chat.js +2 -0
  295. package/lib/module/chat.js.map +1 -1
  296. package/lib/module/grammar.js +14 -31
  297. package/lib/module/grammar.js.map +1 -1
  298. package/lib/module/index.js +47 -45
  299. package/lib/module/index.js.map +1 -1
  300. package/lib/module/package.json +1 -0
  301. package/lib/typescript/NativeRNLlama.d.ts +6 -4
  302. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  303. package/lib/typescript/index.d.ts.map +1 -1
  304. package/llama-rn.podspec +48 -48
  305. package/package.json +233 -233
  306. package/src/NativeRNLlama.ts +426 -426
  307. package/src/chat.ts +44 -44
  308. package/src/grammar.ts +854 -854
  309. package/src/index.ts +495 -487
package/android/src/main/jni.cpp
@@ -1,1263 +1,1263 @@
- #include <jni.h>
- // #include <android/asset_manager.h>
- // #include <android/asset_manager_jni.h>
- #include <android/log.h>
- #include <cstdlib>
- #include <ctime>
- #include <ctime>
- #include <sys/sysinfo.h>
- #include <string>
- #include <thread>
- #include <unordered_map>
- #include "json-schema-to-grammar.h"
- #include "llama.h"
- #include "chat.h"
- #include "llama-impl.h"
- #include "ggml.h"
- #include "rn-llama.h"
- #include "jni-utils.h"
- #define UNUSED(x) (void)(x)
- #define TAG "RNLLAMA_ANDROID_JNI"
-
- #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
- #define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
- #define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
- static inline int min(int a, int b) {
-     return (a < b) ? a : b;
- }
-
- static void rnllama_log_callback_default(lm_ggml_log_level level, const char * fmt, void * data) {
-     if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
-     else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
-     else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
-     else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
- }
-
- extern "C" {
-
- // Method to create WritableMap
- static inline jobject createWriteableMap(JNIEnv *env) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
-     jmethodID init = env->GetStaticMethodID(mapClass, "createMap", "()Lcom/facebook/react/bridge/WritableMap;");
-     jobject map = env->CallStaticObjectMethod(mapClass, init);
-     return map;
- }
-
- // Method to put string into WritableMap
- static inline void putString(JNIEnv *env, jobject map, const char *key, const char *value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
-     jmethodID putStringMethod = env->GetMethodID(mapClass, "putString", "(Ljava/lang/String;Ljava/lang/String;)V");
-
-     jstring jKey = env->NewStringUTF(key);
-     jstring jValue = env->NewStringUTF(value);
-
-     env->CallVoidMethod(map, putStringMethod, jKey, jValue);
- }
-
- // Method to put int into WritableMap
- static inline void putInt(JNIEnv *env, jobject map, const char *key, int value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
-     jmethodID putIntMethod = env->GetMethodID(mapClass, "putInt", "(Ljava/lang/String;I)V");
-
-     jstring jKey = env->NewStringUTF(key);
-
-     env->CallVoidMethod(map, putIntMethod, jKey, value);
- }
-
- // Method to put double into WritableMap
- static inline void putDouble(JNIEnv *env, jobject map, const char *key, double value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
-     jmethodID putDoubleMethod = env->GetMethodID(mapClass, "putDouble", "(Ljava/lang/String;D)V");
-
-     jstring jKey = env->NewStringUTF(key);
-
-     env->CallVoidMethod(map, putDoubleMethod, jKey, value);
- }
-
- // Method to put boolean into WritableMap
- static inline void putBoolean(JNIEnv *env, jobject map, const char *key, bool value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
-     jmethodID putBooleanMethod = env->GetMethodID(mapClass, "putBoolean", "(Ljava/lang/String;Z)V");
-
-     jstring jKey = env->NewStringUTF(key);
-
-     env->CallVoidMethod(map, putBooleanMethod, jKey, value);
- }
-
- // Method to put WriteableMap into WritableMap
- static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
-     jmethodID putMapMethod = env->GetMethodID(mapClass, "putMap", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableMap;)V");
-
-     jstring jKey = env->NewStringUTF(key);
-
-     env->CallVoidMethod(map, putMapMethod, jKey, value);
- }
-
- // Method to create WritableArray
- static inline jobject createWritableArray(JNIEnv *env) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
-     jmethodID init = env->GetStaticMethodID(mapClass, "createArray", "()Lcom/facebook/react/bridge/WritableArray;");
-     jobject map = env->CallStaticObjectMethod(mapClass, init);
-     return map;
- }
-
- // Method to push int into WritableArray
- static inline void pushInt(JNIEnv *env, jobject arr, int value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
-     jmethodID pushIntMethod = env->GetMethodID(mapClass, "pushInt", "(I)V");
-
-     env->CallVoidMethod(arr, pushIntMethod, value);
- }
-
- // Method to push double into WritableArray
- static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
-     jmethodID pushDoubleMethod = env->GetMethodID(mapClass, "pushDouble", "(D)V");
-
-     env->CallVoidMethod(arr, pushDoubleMethod, value);
- }
-
- // Method to push string into WritableArray
- static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
-     jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
-
-     jstring jValue = env->NewStringUTF(value);
-     env->CallVoidMethod(arr, pushStringMethod, jValue);
- }
-
- // Method to push WritableMap into WritableArray
- static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
-     jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
-
-     env->CallVoidMethod(arr, pushMapMethod, value);
- }
-
- // Method to put WritableArray into WritableMap
- static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
-     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
-     jmethodID putArrayMethod = env->GetMethodID(mapClass, "putArray", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableArray;)V");
-
-     jstring jKey = env->NewStringUTF(key);
-
-     env->CallVoidMethod(map, putArrayMethod, jKey, value);
- }
-
- JNIEXPORT jobject JNICALL
- Java_com_rnllama_LlamaContext_modelInfo(
-     JNIEnv *env,
-     jobject thiz,
-     jstring model_path_str,
-     jobjectArray skip
- ) {
-     UNUSED(thiz);
-
-     const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
-
-     std::vector<std::string> skip_vec;
-     int skip_len = env->GetArrayLength(skip);
-     for (int i = 0; i < skip_len; i++) {
-         jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
-         const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
-         skip_vec.push_back(skip_chars);
-         env->ReleaseStringUTFChars(skip_str, skip_chars);
-     }
-
-     struct lm_gguf_init_params params = {
-         /*.no_alloc = */ false,
-         /*.ctx = */ NULL,
-     };
-     struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);
-
-     if (!ctx) {
-         LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
-         return nullptr;
-     }
-
-     auto info = createWriteableMap(env);
-     putInt(env, info, "version", lm_gguf_get_version(ctx));
-     putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
-     putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
-     {
-         const int n_kv = lm_gguf_get_n_kv(ctx);
-
-         for (int i = 0; i < n_kv; ++i) {
-             const char * key = lm_gguf_get_key(ctx, i);
-
-             bool skipped = false;
-             if (skip_len > 0) {
-                 for (int j = 0; j < skip_len; j++) {
-                     if (skip_vec[j] == key) {
-                         skipped = true;
-                         break;
-                     }
-                 }
-             }
-
-             if (skipped) {
-                 continue;
-             }
-
-             const std::string value = lm_gguf_kv_to_str(ctx, i);
-             putString(env, info, key, value.c_str());
-         }
-     }
-
-     env->ReleaseStringUTFChars(model_path_str, model_path_chars);
-     lm_gguf_free(ctx);
-
-     return reinterpret_cast<jobject>(info);
- }
-
- struct callback_context {
-     JNIEnv *env;
-     rnllama::llama_rn_context *llama;
-     jobject callback;
- };
-
- std::unordered_map<long, rnllama::llama_rn_context *> context_map;
-
- struct CallbackContext {
-     JNIEnv * env;
-     jobject thiz;
-     jmethodID sendProgressMethod;
-     unsigned current;
- };
-
- JNIEXPORT jlong JNICALL
- Java_com_rnllama_LlamaContext_initContext(
-     JNIEnv *env,
-     jobject thiz,
-     jstring model_path_str,
-     jstring chat_template,
-     jstring reasoning_format,
-     jboolean embedding,
-     jint embd_normalize,
-     jint n_ctx,
-     jint n_batch,
-     jint n_ubatch,
-     jint n_threads,
-     jint n_gpu_layers, // TODO: Support this
-     jboolean flash_attn,
-     jstring cache_type_k,
-     jstring cache_type_v,
-     jboolean use_mlock,
-     jboolean use_mmap,
-     jboolean vocab_only,
-     jstring lora_str,
-     jfloat lora_scaled,
-     jobject lora_list,
-     jfloat rope_freq_base,
-     jfloat rope_freq_scale,
-     jint pooling_type,
-     jobject load_progress_callback
- ) {
-     UNUSED(thiz);
-
-     common_params defaultParams;
-
-     defaultParams.vocab_only = vocab_only;
-     if(vocab_only) {
-         defaultParams.warmup = false;
-     }
-
-     const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
-     defaultParams.model = { model_path_chars };
-
-     const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
-     defaultParams.chat_template = chat_template_chars;
-
-     const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
-     if (strcmp(reasoning_format_chars, "deepseek") == 0) {
-         defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
-     } else {
-         defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
-     }
-
-     defaultParams.n_ctx = n_ctx;
-     defaultParams.n_batch = n_batch;
-     defaultParams.n_ubatch = n_ubatch;
-
-     if (pooling_type != -1) {
-         defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
-     }
-
-     defaultParams.embedding = embedding;
-     if (embd_normalize != -1) {
-         defaultParams.embd_normalize = embd_normalize;
-     }
-     if (embedding) {
-         // For non-causal models, batch size must be equal to ubatch size
-         defaultParams.n_ubatch = defaultParams.n_batch;
-     }
-
-     int max_threads = std::thread::hardware_concurrency();
-     // Use 2 threads by default on 4-core devices, 4 threads on more cores
-     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
-     defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
-
-     // defaultParams.n_gpu_layers = n_gpu_layers;
-     defaultParams.flash_attn = flash_attn;
-
-     const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
-     const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
-     defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
-     defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);
-
-     defaultParams.use_mlock = use_mlock;
-     defaultParams.use_mmap = use_mmap;
-
-     defaultParams.rope_freq_base = rope_freq_base;
-     defaultParams.rope_freq_scale = rope_freq_scale;
-
-     auto llama = new rnllama::llama_rn_context();
-     llama->is_load_interrupted = false;
-     llama->loading_progress = 0;
-
-     if (load_progress_callback != nullptr) {
-         defaultParams.progress_callback = [](float progress, void * user_data) {
-             callback_context *cb_ctx = (callback_context *)user_data;
-             JNIEnv *env = cb_ctx->env;
-             auto llama = cb_ctx->llama;
-             jobject callback = cb_ctx->callback;
-             int percentage = (int) (100 * progress);
-             if (percentage > llama->loading_progress) {
-                 llama->loading_progress = percentage;
-                 jclass callback_class = env->GetObjectClass(callback);
-                 jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
-                 env->CallVoidMethod(callback, onLoadProgress, percentage);
-             }
-             return !llama->is_load_interrupted;
-         };
-
-         callback_context *cb_ctx = new callback_context;
-         cb_ctx->env = env;
-         cb_ctx->llama = llama;
-         cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
-         defaultParams.progress_callback_user_data = cb_ctx;
-     }
-
-     bool is_model_loaded = llama->loadModel(defaultParams);
-
-     env->ReleaseStringUTFChars(model_path_str, model_path_chars);
-     env->ReleaseStringUTFChars(chat_template, chat_template_chars);
-     env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
-     env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
-     env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
-
-     LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
-     if (is_model_loaded) {
-         if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
-             LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
-             llama_free(llama->ctx);
-             return -1;
-         }
-         context_map[(long) llama->ctx] = llama;
-     } else {
-         llama_free(llama->ctx);
-     }
-
-     std::vector<common_adapter_lora_info> lora;
-     const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
-     if (lora_chars != nullptr && lora_chars[0] != '\0') {
-         common_adapter_lora_info la;
-         la.path = lora_chars;
-         la.scale = lora_scaled;
-         lora.push_back(la);
-     }
-
-     if (lora_list != nullptr) {
-         // lora_adapters: ReadableArray<ReadableMap>
-         int lora_list_size = readablearray::size(env, lora_list);
-         for (int i = 0; i < lora_list_size; i++) {
-             jobject lora_adapter = readablearray::getMap(env, lora_list, i);
-             jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
-             if (path != nullptr) {
-                 const char *path_chars = env->GetStringUTFChars(path, nullptr);
-                 common_adapter_lora_info la;
-                 la.path = path_chars;
-                 la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
-                 lora.push_back(la);
-                 env->ReleaseStringUTFChars(path, path_chars);
-             }
-         }
-     }
-     env->ReleaseStringUTFChars(lora_str, lora_chars);
-     int result = llama->applyLoraAdapters(lora);
-     if (result != 0) {
-         LOGI("[RNLlama] Failed to apply lora adapters");
-         llama_free(llama->ctx);
-         return -1;
-     }
-
-     return reinterpret_cast<jlong>(llama->ctx);
- }
-
-
- JNIEXPORT void JNICALL
- Java_com_rnllama_LlamaContext_interruptLoad(
-     JNIEnv *env,
-     jobject thiz,
-     jlong context_ptr
- ) {
-     UNUSED(thiz);
-     auto llama = context_map[(long) context_ptr];
-     if (llama) {
-         llama->is_load_interrupted = true;
-     }
- }
-
- JNIEXPORT jobject JNICALL
- Java_com_rnllama_LlamaContext_loadModelDetails(
-     JNIEnv *env,
-     jobject thiz,
-     jlong context_ptr
- ) {
-     UNUSED(thiz);
-     auto llama = context_map[(long) context_ptr];
-
-     int count = llama_model_meta_count(llama->model);
-     auto meta = createWriteableMap(env);
-     for (int i = 0; i < count; i++) {
-         char key[256];
-         llama_model_meta_key_by_index(llama->model, i, key, sizeof(key));
-         char val[4096];
-         llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val));
-
-         putString(env, meta, key, val);
-     }
-
-     auto result = createWriteableMap(env);
-
-     char desc[1024];
-     llama_model_desc(llama->model, desc, sizeof(desc));
-
-     putString(env, result, "desc", desc);
-     putDouble(env, result, "size", llama_model_size(llama->model));
-     putDouble(env, result, "nEmbd", llama_model_n_embd(llama->model));
-     putDouble(env, result, "nParams", llama_model_n_params(llama->model));
-     auto chat_templates = createWriteableMap(env);
-     putBoolean(env, chat_templates, "llamaChat", llama->validateModelChatTemplate(false, nullptr));
-
-     auto minja = createWriteableMap(env);
-     putBoolean(env, minja, "default", llama->validateModelChatTemplate(true, nullptr));
-
-     auto default_caps = createWriteableMap(env);
-
-     auto default_tmpl = llama->templates.get()->template_default.get();
-     auto default_tmpl_caps = default_tmpl->original_caps();
-     putBoolean(env, default_caps, "tools", default_tmpl_caps.supports_tools);
-     putBoolean(env, default_caps, "toolCalls", default_tmpl_caps.supports_tool_calls);
-     putBoolean(env, default_caps, "parallelToolCalls", default_tmpl_caps.supports_parallel_tool_calls);
-     putBoolean(env, default_caps, "toolResponses", default_tmpl_caps.supports_tool_responses);
-     putBoolean(env, default_caps, "systemRole", default_tmpl_caps.supports_system_role);
-     putBoolean(env, default_caps, "toolCallId", default_tmpl_caps.supports_tool_call_id);
-     putMap(env, minja, "defaultCaps", default_caps);
-
-     putBoolean(env, minja, "toolUse", llama->validateModelChatTemplate(true, "tool_use"));
-     auto tool_use_tmpl = llama->templates.get()->template_tool_use.get();
-     if (tool_use_tmpl != nullptr) {
-         auto tool_use_caps = createWriteableMap(env);
-         auto tool_use_tmpl_caps = tool_use_tmpl->original_caps();
-         putBoolean(env, tool_use_caps, "tools", tool_use_tmpl_caps.supports_tools);
-         putBoolean(env, tool_use_caps, "toolCalls", tool_use_tmpl_caps.supports_tool_calls);
-         putBoolean(env, tool_use_caps, "parallelToolCalls", tool_use_tmpl_caps.supports_parallel_tool_calls);
-         putBoolean(env, tool_use_caps, "systemRole", tool_use_tmpl_caps.supports_system_role);
-         putBoolean(env, tool_use_caps, "toolResponses", tool_use_tmpl_caps.supports_tool_responses);
-         putBoolean(env, tool_use_caps, "toolCallId", tool_use_tmpl_caps.supports_tool_call_id);
-         putMap(env, minja, "toolUseCaps", tool_use_caps);
-     }
-
-     putMap(env, chat_templates, "minja", minja);
-     putMap(env, result, "metadata", meta);
-     putMap(env, result, "chatTemplates", chat_templates);
-
-     // deprecated
-     putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate(false, nullptr));
-
-     return reinterpret_cast<jobject>(result);
- }
-
- JNIEXPORT jobject JNICALL
- Java_com_rnllama_LlamaContext_getFormattedChatWithJinja(
-     JNIEnv *env,
-     jobject thiz,
-     jlong context_ptr,
-     jstring messages,
-     jstring chat_template,
-     jstring json_schema,
-     jstring tools,
-     jboolean parallel_tool_calls,
-     jstring tool_choice
- ) {
-     UNUSED(thiz);
-     auto llama = context_map[(long) context_ptr];
-
-     const char *messages_chars = env->GetStringUTFChars(messages, nullptr);
-     const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
-     const char *json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
-     const char *tools_chars = env->GetStringUTFChars(tools, nullptr);
-     const char *tool_choice_chars = env->GetStringUTFChars(tool_choice, nullptr);
-
-     auto result = createWriteableMap(env);
-     try {
-         auto formatted = llama->getFormattedChatWithJinja(
-             messages_chars,
-             tmpl_chars,
-             json_schema_chars,
-             tools_chars,
-             parallel_tool_calls,
-             tool_choice_chars
-         );
-         putString(env, result, "prompt", formatted.prompt.c_str());
-         putInt(env, result, "chat_format", static_cast<int>(formatted.format));
-         putString(env, result, "grammar", formatted.grammar.c_str());
-         putBoolean(env, result, "grammar_lazy", formatted.grammar_lazy);
-         auto grammar_triggers = createWritableArray(env);
-         for (const auto &trigger : formatted.grammar_triggers) {
-             auto trigger_map = createWriteableMap(env);
-             putInt(env, trigger_map, "type", trigger.type);
-             putString(env, trigger_map, "value", trigger.value.c_str());
-             putInt(env, trigger_map, "token", trigger.token);
-             pushMap(env, grammar_triggers, trigger_map);
-         }
-         putArray(env, result, "grammar_triggers", grammar_triggers);
-         auto preserved_tokens = createWritableArray(env);
-         for (const auto &token : formatted.preserved_tokens) {
-             pushString(env, preserved_tokens, token.c_str());
-         }
-         putArray(env, result, "preserved_tokens", preserved_tokens);
-         auto additional_stops = createWritableArray(env);
-         for (const auto &stop : formatted.additional_stops) {
-             pushString(env, additional_stops, stop.c_str());
-         }
-         putArray(env, result, "additional_stops", additional_stops);
-     } catch (const std::runtime_error &e) {
-         LOGI("[RNLlama] Error: %s", e.what());
-         putString(env, result, "_error", e.what());
-     }
-     env->ReleaseStringUTFChars(tools, tools_chars);
-     env->ReleaseStringUTFChars(messages, messages_chars);
-     env->ReleaseStringUTFChars(chat_template, tmpl_chars);
-     env->ReleaseStringUTFChars(json_schema, json_schema_chars);
-     env->ReleaseStringUTFChars(tool_choice, tool_choice_chars);
-     return reinterpret_cast<jobject>(result);
- }
-
- JNIEXPORT jobject JNICALL
- Java_com_rnllama_LlamaContext_getFormattedChat(
-     JNIEnv *env,
-     jobject thiz,
-     jlong context_ptr,
-     jstring messages,
-     jstring chat_template
- ) {
-     UNUSED(thiz);
-     auto llama = context_map[(long) context_ptr];
-
-     const char *messages_chars = env->GetStringUTFChars(messages, nullptr);
-     const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
-
-     std::string formatted_chat = llama->getFormattedChat(messages_chars, tmpl_chars);
-
-     env->ReleaseStringUTFChars(messages, messages_chars);
-     env->ReleaseStringUTFChars(chat_template, tmpl_chars);
-
-     return env->NewStringUTF(formatted_chat.c_str());
- }
-
- JNIEXPORT jobject JNICALL
- Java_com_rnllama_LlamaContext_loadSession(
-     JNIEnv *env,
-     jobject thiz,
-     jlong context_ptr,
-     jstring path
- ) {
-     UNUSED(thiz);
-     auto llama = context_map[(long) context_ptr];
-     const char *path_chars = env->GetStringUTFChars(path, nullptr);
-
-     auto result = createWriteableMap(env);
-     size_t n_token_count_out = 0;
-     llama->embd.resize(llama->params.n_ctx);
-     if (!llama_state_load_file(llama->ctx, path_chars, llama->embd.data(), llama->embd.capacity(), &n_token_count_out)) {
-         env->ReleaseStringUTFChars(path, path_chars);
-
-         putString(env, result, "error", "Failed to load session");
-         return reinterpret_cast<jobject>(result);
-     }
-     llama->embd.resize(n_token_count_out);
-     env->ReleaseStringUTFChars(path, path_chars);
-
-     const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
-     putInt(env, result, "tokens_loaded", n_token_count_out);
-     putString(env, result, "prompt", text.c_str());
-     return reinterpret_cast<jobject>(result);
- }
-
- JNIEXPORT jint JNICALL
- Java_com_rnllama_LlamaContext_saveSession(
-     JNIEnv *env,
-     jobject thiz,
-     jlong context_ptr,
-     jstring path,
-     jint size
- ) {
-     UNUSED(thiz);
-     auto llama = context_map[(long) context_ptr];
-
-     const char *path_chars = env->GetStringUTFChars(path, nullptr);
-
-     std::vector<llama_token> session_tokens = llama->embd;
-     int default_size = session_tokens.size();
-     int save_size = size > 0 && size <= default_size ? size : default_size;
-     if (!llama_state_save_file(llama->ctx, path_chars, session_tokens.data(), save_size)) {
-         env->ReleaseStringUTFChars(path, path_chars);
-         return -1;
-     }
-
-     env->ReleaseStringUTFChars(path, path_chars);
-     return session_tokens.size();
- }
-
- static inline jobject tokenProbsToMap(
-     JNIEnv *env,
-     rnllama::llama_rn_context *llama,
-     std::vector<rnllama::completion_token_output> probs
- ) {
-     auto result = createWritableArray(env);
-     for (const auto &prob : probs) {
-         auto probsForToken = createWritableArray(env);
-         for (const auto &p : prob.probs) {
-             std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
-             auto probResult = createWriteableMap(env);
-             putString(env, probResult, "tok_str", tokStr.c_str());
-             putDouble(env, probResult, "prob", p.prob);
-             pushMap(env, probsForToken, probResult);
-         }
-         std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
-         auto tokenResult = createWriteableMap(env);
-         putString(env, tokenResult, "content", tokStr.c_str());
-         putArray(env, tokenResult, "probs", probsForToken);
-         pushMap(env, result, tokenResult);
-     }
-     return result;
- }
-
- JNIEXPORT jobject JNICALL
- Java_com_rnllama_LlamaContext_doCompletion(
-     JNIEnv *env,
-     jobject thiz,
-     jlong context_ptr,
-     jstring prompt,
-     jint chat_format,
-     jstring grammar,
-     jstring json_schema,
-     jboolean grammar_lazy,
-     jobject grammar_triggers,
-     jobject preserved_tokens,
-     jfloat temperature,
-     jint n_threads,
-     jint n_predict,
-     jint n_probs,
-     jint penalty_last_n,
-     jfloat penalty_repeat,
-     jfloat penalty_freq,
-     jfloat penalty_present,
-     jfloat mirostat,
-     jfloat mirostat_tau,
-     jfloat mirostat_eta,
-     jint top_k,
-     jfloat top_p,
-     jfloat min_p,
-     jfloat xtc_threshold,
-     jfloat xtc_probability,
-     jfloat typical_p,
-     jint seed,
-     jobjectArray stop,
-     jboolean ignore_eos,
-     jobjectArray logit_bias,
-     jfloat dry_multiplier,
-     jfloat dry_base,
-     jint dry_allowed_length,
-     jint dry_penalty_last_n,
-     jfloat top_n_sigma,
-     jobjectArray dry_sequence_breakers,
-     jobject partial_completion_callback
- ) {
-     UNUSED(thiz);
-     auto llama = context_map[(long) context_ptr];
-
-     llama->rewind();
-
-     //llama_reset_timings(llama->ctx);
-
-     auto prompt_chars = env->GetStringUTFChars(prompt, nullptr);
-     llama->params.prompt = prompt_chars;
-     llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;
-
-     int max_threads = std::thread::hardware_concurrency();
-     // Use 2 threads by default on 4-core devices, 4 threads on more cores
-     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
-     llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
-
-     llama->params.n_predict = n_predict;
-     llama->params.sampling.ignore_eos = ignore_eos;
-
-     auto & sparams = llama->params.sampling;
-     sparams.temp = temperature;
-     sparams.penalty_last_n = penalty_last_n;
-     sparams.penalty_repeat = penalty_repeat;
-     sparams.penalty_freq = penalty_freq;
-     sparams.penalty_present = penalty_present;
-     sparams.mirostat = mirostat;
-     sparams.mirostat_tau = mirostat_tau;
-     sparams.mirostat_eta = mirostat_eta;
-     sparams.top_k = top_k;
-     sparams.top_p = top_p;
-     sparams.min_p = min_p;
-     sparams.typ_p = typical_p;
-     sparams.n_probs = n_probs;
-     sparams.xtc_threshold = xtc_threshold;
-     sparams.xtc_probability = xtc_probability;
-     sparams.dry_multiplier = dry_multiplier;
-     sparams.dry_base = dry_base;
-     sparams.dry_allowed_length = dry_allowed_length;
-     sparams.dry_penalty_last_n = dry_penalty_last_n;
-     sparams.top_n_sigma = top_n_sigma;
-
-     // grammar
-     auto grammar_chars = env->GetStringUTFChars(grammar, nullptr);
-     if (grammar_chars && grammar_chars[0] != '\0') {
-         sparams.grammar = grammar_chars;
-     }
-     sparams.grammar_lazy = grammar_lazy;
-
-     if (preserved_tokens != nullptr) {
-         int preserved_tokens_size = readablearray::size(env, preserved_tokens);
-         for (int i = 0; i < preserved_tokens_size; i++) {
-             jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
-             auto ids = common_tokenize(llama->ctx, env->GetStringUTFChars(preserved_token, nullptr), /* add_special= */ false, /* parse_special= */ true);
-             if (ids.size() == 1) {
-                 sparams.preserved_tokens.insert(ids[0]);
-             } else {
-                 LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", env->GetStringUTFChars(preserved_token, nullptr));
-             }
-         }
-     }
-
-     if (grammar_triggers != nullptr) {
-         int grammar_triggers_size = readablearray::size(env, grammar_triggers);
-         for (int i = 0; i < grammar_triggers_size; i++) {
-             auto trigger_map = readablearray::getMap(env, grammar_triggers, i);
-             const auto type = static_cast<common_grammar_trigger_type>(readablemap::getInt(env, trigger_map, "type", 0));
-             jstring trigger_word = readablemap::getString(env, trigger_map, "value", nullptr);
-             auto word = env->GetStringUTFChars(trigger_word, nullptr);
-
-             if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
-                 auto ids = common_tokenize(llama->ctx, word, /* add_special= */ false, /* parse_special= */ true);
-                 if (ids.size() == 1) {
-                     auto token = ids[0];
-                     if (std::find(sparams.preserved_tokens.begin(), sparams.preserved_tokens.end(), (llama_token) token) == sparams.preserved_tokens.end()) {
-                         throw std::runtime_error("Grammar trigger word should be marked as preserved token");
-                     }
-                     common_grammar_trigger trigger;
-                     trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
-                     trigger.value = word;
-                     trigger.token = token;
-                     sparams.grammar_triggers.push_back(std::move(trigger));
-                 } else {
-                     sparams.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
-                 }
-             } else {
-                 common_grammar_trigger trigger;
-                 trigger.type = type;
-                 trigger.value = word;
-                 if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
-                     const auto token = (llama_token) readablemap::getInt(env, trigger_map, "token", 0);
-                     trigger.token = token;
-                 }
-                 sparams.grammar_triggers.push_back(std::move(trigger));
-             }
-         }
-     }
-
-     auto json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
-     if ((!grammar_chars || grammar_chars[0] == '\0') && json_schema_chars && json_schema_chars[0] != '\0') {
-         auto schema = json::parse(json_schema_chars);
-         sparams.grammar = json_schema_to_grammar(schema);
-     }
-     env->ReleaseStringUTFChars(json_schema, json_schema_chars);
-
-
-     const llama_model * model = llama_get_model(llama->ctx);
-     const llama_vocab * vocab = llama_model_get_vocab(model);
797
-
798
- sparams.logit_bias.clear();
799
- if (ignore_eos) {
800
- sparams.logit_bias[llama_vocab_eos(vocab)].bias = -INFINITY;
801
- }
802
-
803
- // dry break seq
804
-
805
- jint size = env->GetArrayLength(dry_sequence_breakers);
806
- std::vector<std::string> dry_sequence_breakers_vector;
807
-
808
- for (jint i = 0; i < size; i++) {
809
- jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
810
- const char *nativeString = env->GetStringUTFChars(javaString, 0);
811
- dry_sequence_breakers_vector.push_back(std::string(nativeString));
812
- env->ReleaseStringUTFChars(javaString, nativeString);
813
- env->DeleteLocalRef(javaString);
814
- }
815
-
816
- sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
817
-
818
- // logit bias
819
- const int n_vocab = llama_vocab_n_tokens(vocab);
820
- jsize logit_bias_len = env->GetArrayLength(logit_bias);
821
-
822
- for (jsize i = 0; i < logit_bias_len; i++) {
823
- jdoubleArray el = (jdoubleArray) env->GetObjectArrayElement(logit_bias, i);
824
- if (el && env->GetArrayLength(el) == 2) {
825
- jdouble* doubleArray = env->GetDoubleArrayElements(el, 0);
826
-
827
- llama_token tok = static_cast<llama_token>(doubleArray[0]);
828
- if (tok >= 0 && tok < n_vocab) {
829
- if (doubleArray[1] != 0) { // If the second element is not false (0)
830
- sparams.logit_bias[tok].bias = doubleArray[1];
831
- } else {
832
- sparams.logit_bias[tok].bias = -INFINITY;
833
- }
834
- }
835
-
836
- env->ReleaseDoubleArrayElements(el, doubleArray, 0);
837
- }
838
- env->DeleteLocalRef(el);
839
- }
840
-
841
- llama->params.antiprompt.clear();
842
- int stop_len = env->GetArrayLength(stop);
843
- for (int i = 0; i < stop_len; i++) {
844
- jstring stop_str = (jstring) env->GetObjectArrayElement(stop, i);
845
- const char *stop_chars = env->GetStringUTFChars(stop_str, nullptr);
846
- llama->params.antiprompt.push_back(stop_chars);
847
- env->ReleaseStringUTFChars(stop_str, stop_chars);
848
- }
849
-
850
- if (!llama->initSampling()) {
851
- auto result = createWriteableMap(env);
852
- putString(env, result, "error", "Failed to initialize sampling");
853
- return reinterpret_cast<jobject>(result);
854
- }
855
- llama->beginCompletion();
856
- llama->loadPrompt();
857
-
858
- size_t sent_count = 0;
859
- size_t sent_token_probs_index = 0;
860
-
861
- while (llama->has_next_token && !llama->is_interrupted) {
862
- const rnllama::completion_token_output token_with_probs = llama->doCompletion();
863
- if (token_with_probs.tok == -1 || llama->incomplete) {
864
- continue;
865
- }
866
- const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);
867
-
868
- size_t pos = std::min(sent_count, llama->generated_text.size());
869
-
870
- const std::string str_test = llama->generated_text.substr(pos);
871
- bool is_stop_full = false;
872
- size_t stop_pos =
873
- llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
874
- if (stop_pos != std::string::npos) {
875
- is_stop_full = true;
876
- llama->generated_text.erase(
877
- llama->generated_text.begin() + pos + stop_pos,
878
- llama->generated_text.end());
879
- pos = std::min(sent_count, llama->generated_text.size());
880
- } else {
881
- is_stop_full = false;
882
- stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
883
- rnllama::STOP_PARTIAL);
884
- }
885
-
886
- if (
887
- stop_pos == std::string::npos ||
888
- // Send rest of the text if we are at the end of the generation
889
- (!llama->has_next_token && !is_stop_full && stop_pos > 0)
890
- ) {
891
- const std::string to_send = llama->generated_text.substr(pos, std::string::npos);
892
-
893
- sent_count += to_send.size();
894
-
895
- std::vector<rnllama::completion_token_output> probs_output = {};
896
-
897
- auto tokenResult = createWriteableMap(env);
898
- putString(env, tokenResult, "token", to_send.c_str());
899
-
900
- if (llama->params.sampling.n_probs > 0) {
901
- const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
902
- size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
903
- size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
904
- if (probs_pos < probs_stop_pos) {
905
- probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
906
- }
907
- sent_token_probs_index = probs_stop_pos;
908
-
909
- putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
910
- }
911
-
912
- jclass cb_class = env->GetObjectClass(partial_completion_callback);
913
- jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
914
- env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
915
- }
916
- }
917
-
918
- env->ReleaseStringUTFChars(grammar, grammar_chars);
919
- env->ReleaseStringUTFChars(prompt, prompt_chars);
920
- llama_perf_context_print(llama->ctx);
921
- llama->is_predicting = false;
922
-
923
- auto toolCalls = createWritableArray(env);
924
- std::string reasoningContent = "";
925
- std::string content;
926
- auto toolCallsSize = 0;
927
- if (!llama->is_interrupted) {
928
- try {
929
- common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
930
- if (!message.reasoning_content.empty()) {
931
- reasoningContent = message.reasoning_content;
932
- }
933
- content = message.content;
934
- for (const auto &tc : message.tool_calls) {
935
- auto toolCall = createWriteableMap(env);
936
- putString(env, toolCall, "type", "function");
937
- auto functionMap = createWriteableMap(env);
938
- putString(env, functionMap, "name", tc.name.c_str());
939
- putString(env, functionMap, "arguments", tc.arguments.c_str());
940
- putMap(env, toolCall, "function", functionMap);
941
- if (!tc.id.empty()) {
942
- putString(env, toolCall, "id", tc.id.c_str());
943
- }
944
- pushMap(env, toolCalls, toolCall);
945
- toolCallsSize++;
946
- }
947
- } catch (const std::exception &e) {
948
- // LOGI("Error parsing tool calls: %s", e.what());
949
- }
950
- }
951
-
952
- auto result = createWriteableMap(env);
953
- putString(env, result, "text", llama->generated_text.c_str());
954
- if (!content.empty()) {
955
- putString(env, result, "content", content.c_str());
956
- }
957
- if (!reasoningContent.empty()) {
958
- putString(env, result, "reasoning_content", reasoningContent.c_str());
959
- }
960
- if (toolCallsSize > 0) {
961
- putArray(env, result, "tool_calls", toolCalls);
962
- }
963
- putArray(env, result, "completion_probabilities", tokenProbsToMap(env, llama, llama->generated_token_probs));
964
- putInt(env, result, "tokens_predicted", llama->num_tokens_predicted);
965
- putInt(env, result, "tokens_evaluated", llama->num_prompt_tokens);
966
- putInt(env, result, "truncated", llama->truncated);
967
- putInt(env, result, "stopped_eos", llama->stopped_eos);
968
- putInt(env, result, "stopped_word", llama->stopped_word);
969
- putInt(env, result, "stopped_limit", llama->stopped_limit);
970
- putString(env, result, "stopping_word", llama->stopping_word.c_str());
971
- putInt(env, result, "tokens_cached", llama->n_past);
972
-
973
- const auto timings_token = llama_perf_context(llama->ctx);
974
-
975
- auto timingsResult = createWriteableMap(env);
976
- putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
977
- putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
978
- putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
979
- putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
980
- putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
981
- putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
982
- putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
983
- putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);
984
-
985
- putMap(env, result, "timings", timingsResult);
986
-
987
- return reinterpret_cast<jobject>(result);
988
- }
989
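The logit_bias rows consumed in doCompletion above are two-element [token, bias] arrays, where a bias of 0 (false) bans the token outright. A minimal sketch of that mapping; the sampling-params type is left generic because only the logit_bias[tok].bias field used above is assumed:

// Illustrative only: apply one [tok, b] row the way the loop above does.
// b != 0 sets a soft bias; b == 0 bans the token.
template <typename SamplingParams>
static void apply_logit_bias_row(SamplingParams &sparams, llama_token tok, double b) {
    sparams.logit_bias[tok].bias = (b != 0) ? b : -INFINITY;
}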
-
990
- JNIEXPORT void JNICALL
991
- Java_com_rnllama_LlamaContext_stopCompletion(
992
- JNIEnv *env, jobject thiz, jlong context_ptr) {
993
- UNUSED(env);
994
- UNUSED(thiz);
995
- auto llama = context_map[(long) context_ptr];
996
- llama->is_interrupted = true;
997
- }
998
-
999
- JNIEXPORT jboolean JNICALL
1000
- Java_com_rnllama_LlamaContext_isPredicting(
1001
- JNIEnv *env, jobject thiz, jlong context_ptr) {
1002
- UNUSED(env);
1003
- UNUSED(thiz);
1004
- auto llama = context_map[(long) context_ptr];
1005
- return llama->is_predicting;
1006
- }
1007
-
1008
- JNIEXPORT jobject JNICALL
1009
- Java_com_rnllama_LlamaContext_tokenize(
1010
- JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
1011
- UNUSED(thiz);
1012
- auto llama = context_map[(long) context_ptr];
1013
-
1014
- const char *text_chars = env->GetStringUTFChars(text, nullptr);
1015
-
1016
- const std::vector<llama_token> toks = common_tokenize(
1017
- llama->ctx,
1018
- text_chars,
1019
- false
1020
- );
1021
-
1022
- jobject result = createWritableArray(env);
1023
- for (const auto &tok : toks) {
1024
- pushInt(env, result, tok);
1025
- }
1026
-
1027
- env->ReleaseStringUTFChars(text, text_chars);
1028
- return result;
1029
- }
1030
-
1031
- JNIEXPORT jstring JNICALL
1032
- Java_com_rnllama_LlamaContext_detokenize(
1033
- JNIEnv *env, jobject thiz, jlong context_ptr, jintArray tokens) {
1034
- UNUSED(thiz);
1035
- auto llama = context_map[(long) context_ptr];
1036
-
1037
- jsize tokens_len = env->GetArrayLength(tokens);
1038
- jint *tokens_ptr = env->GetIntArrayElements(tokens, 0);
1039
- std::vector<llama_token> toks;
1040
- for (int i = 0; i < tokens_len; i++) {
1041
- toks.push_back(tokens_ptr[i]);
1042
- }
1043
-
1044
- auto text = rnllama::tokens_to_str(llama->ctx, toks.cbegin(), toks.cend());
1045
-
1046
- env->ReleaseIntArrayElements(tokens, tokens_ptr, 0);
1047
-
1048
- return env->NewStringUTF(text.c_str());
1049
- }
1050
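tokenize and detokenize above are inverses up to the vocabulary's whitespace handling; a round-trip sketch, assuming only a valid llama_context *ctx (roundtrip is illustrative, not part of this file):

// Illustrative: encode then decode with the same helpers as above.
static std::string roundtrip(llama_context *ctx, const std::string &input) {
    const std::vector<llama_token> toks = common_tokenize(ctx, input, false);
    // matches input for typical vocabularies, modulo leading-space normalization
    return rnllama::tokens_to_str(ctx, toks.cbegin(), toks.cend());
}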
-
1051
- JNIEXPORT jboolean JNICALL
1052
- Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
1053
- JNIEnv *env, jobject thiz, jlong context_ptr) {
1054
- UNUSED(env);
1055
- UNUSED(thiz);
1056
- auto llama = context_map[(long) context_ptr];
1057
- return llama->params.embedding;
1058
- }
1059
-
1060
- JNIEXPORT jobject JNICALL
1061
- Java_com_rnllama_LlamaContext_embedding(
1062
- JNIEnv *env, jobject thiz,
1063
- jlong context_ptr,
1064
- jstring text,
1065
- jint embd_normalize
1066
- ) {
1067
- UNUSED(thiz);
1068
- auto llama = context_map[(long) context_ptr];
1069
-
1070
- common_params embdParams;
1071
- embdParams.embedding = true;
1072
- embdParams.embd_normalize = llama->params.embd_normalize;
1073
- if (embd_normalize != -1) {
1074
- embdParams.embd_normalize = embd_normalize;
1075
- }
1076
-
1077
- const char *text_chars = env->GetStringUTFChars(text, nullptr);
1078
-
1079
- llama->rewind();
1080
-
1081
- llama_perf_context_reset(llama->ctx);
1082
-
1083
- llama->params.prompt = text_chars;
1084
-
1085
- llama->params.n_predict = 0;
1086
-
1087
- auto result = createWriteableMap(env);
1088
- if (!llama->initSampling()) {
1089
- putString(env, result, "error", "Failed to initialize sampling");
1090
- return reinterpret_cast<jobject>(result);
1091
- }
1092
-
1093
- llama->beginCompletion();
1094
- llama->loadPrompt();
1095
- llama->doCompletion();
1096
-
1097
- std::vector<float> embedding = llama->getEmbedding(embdParams);
1098
-
1099
- auto embeddings = createWritableArray(env);
1100
- for (const auto &val : embedding) {
1101
- pushDouble(env, embeddings, (double) val);
1102
- }
1103
- putArray(env, result, "embedding", embeddings);
1104
-
1105
- auto promptTokens = createWritableArray(env);
1106
- for (const auto &tok : llama->embd) {
1107
- pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
1108
- }
1109
- putArray(env, result, "prompt_tokens", promptTokens);
1110
-
1111
- env->ReleaseStringUTFChars(text, text_chars);
1112
- return result;
1113
- }
1114
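The "embedding" array returned above is a plain float vector; a common consumer-side follow-up is cosine similarity between two such vectors. A sketch, not part of this file (needs <cmath>); note that with embd_normalize enabled the dot product alone already equals the cosine:

static float cosine_similarity(const std::vector<float> &a, const std::vector<float> &b) {
    float dot = 0.0f, na = 0.0f, nb = 0.0f;
    const size_t n = std::min(a.size(), b.size());
    for (size_t i = 0; i < n; i++) {
        dot += a[i] * b[i]; // accumulate dot product and squared norms in one pass
        na  += a[i] * a[i];
        nb  += b[i] * b[i];
    }
    return dot / (std::sqrt(na) * std::sqrt(nb) + 1e-9f); // epsilon guards zero vectors
}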
-
1115
- JNIEXPORT jstring JNICALL
1116
- Java_com_rnllama_LlamaContext_bench(
1117
- JNIEnv *env,
1118
- jobject thiz,
1119
- jlong context_ptr,
1120
- jint pp,
1121
- jint tg,
1122
- jint pl,
1123
- jint nr
1124
- ) {
1125
- UNUSED(thiz);
1126
- auto llama = context_map[(long) context_ptr];
1127
- std::string result = llama->bench(pp, tg, pl, nr);
1128
- return env->NewStringUTF(result.c_str());
1129
- }
1130
-
1131
- JNIEXPORT jint JNICALL
1132
- Java_com_rnllama_LlamaContext_applyLoraAdapters(
1133
- JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
1134
- UNUSED(thiz);
1135
- auto llama = context_map[(long) context_ptr];
1136
-
1137
- // lora_adapters: ReadableArray<ReadableMap>
1138
- std::vector<common_adapter_lora_info> lora_adapters;
1139
- int lora_adapters_size = readablearray::size(env, loraAdapters);
1140
- for (int i = 0; i < lora_adapters_size; i++) {
1141
- jobject lora_adapter = readablearray::getMap(env, loraAdapters, i);
1142
- jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
1143
- if (path != nullptr) {
1144
- const char *path_chars = env->GetStringUTFChars(path, nullptr);
1145
- common_adapter_lora_info la;
1146
- la.path = path_chars;
1147
- la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
1148
- lora_adapters.push_back(la);
1149
- // release only after la.path has copied the characters out
- env->ReleaseStringUTFChars(path, path_chars);
1150
1151
- }
1152
- }
1153
- return llama->applyLoraAdapters(lora_adapters);
1154
- }
1155
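String lifetime is the hazard in this loop: a pointer from GetStringUTFChars must outlive every use, including the copy into la.path. A small RAII guard, illustrative rather than part of this file, makes the get/release pairing mechanical:

// Holds the UTF-8 view of a jstring and releases it on scope exit.
struct JStringChars {
    JNIEnv *env; jstring s; const char *chars;
    JStringChars(JNIEnv *e, jstring str)
        : env(e), s(str), chars(e->GetStringUTFChars(str, nullptr)) {}
    ~JStringChars() { env->ReleaseStringUTFChars(s, chars); }
    JStringChars(const JStringChars &) = delete;
    JStringChars &operator=(const JStringChars &) = delete;
};
// Usage: JStringChars path_utf(env, path); la.path = path_utf.chars;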
-
1156
- JNIEXPORT void JNICALL
1157
- Java_com_rnllama_LlamaContext_removeLoraAdapters(
1158
- JNIEnv *env, jobject thiz, jlong context_ptr) {
1159
- UNUSED(env);
1160
- UNUSED(thiz);
1161
- auto llama = context_map[(long) context_ptr];
1162
- llama->removeLoraAdapters();
1163
- }
1164
-
1165
- JNIEXPORT jobject JNICALL
1166
- Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
1167
- JNIEnv *env, jobject thiz, jlong context_ptr) {
1168
- UNUSED(thiz);
1169
- auto llama = context_map[(long) context_ptr];
1170
- auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
1171
- auto result = createWritableArray(env);
1172
- for (common_adapter_lora_info &la : loaded_lora_adapters) {
1173
- auto map = createWriteableMap(env);
1174
- putString(env, map, "path", la.path.c_str());
1175
- putDouble(env, map, "scaled", la.scale);
1176
- pushMap(env, result, map);
1177
- }
1178
- return result;
1179
- }
1180
-
1181
- JNIEXPORT void JNICALL
1182
- Java_com_rnllama_LlamaContext_freeContext(
1183
- JNIEnv *env, jobject thiz, jlong context_ptr) {
1184
- UNUSED(env);
1185
- UNUSED(thiz);
1186
- auto llama = context_map[(long) context_ptr];
1187
- context_map.erase((long) llama->ctx);
1188
- delete llama;
1189
- }
1190
-
1191
- struct log_callback_context {
1192
- JavaVM *jvm;
1193
- jobject callback;
1194
- };
1195
-
1196
- static void rnllama_log_callback_to_j(lm_ggml_log_level level, const char * text, void * data) {
1197
- auto level_c = "";
1198
- if (level == LM_GGML_LOG_LEVEL_ERROR) {
1199
- __android_log_print(ANDROID_LOG_ERROR, TAG, "%s", text);
1200
- level_c = "error";
1201
- } else if (level == LM_GGML_LOG_LEVEL_INFO) {
1202
- __android_log_print(ANDROID_LOG_INFO, TAG, "%s", text);
1203
- level_c = "info";
1204
- } else if (level == LM_GGML_LOG_LEVEL_WARN) {
1205
- __android_log_print(ANDROID_LOG_WARN, TAG, "%s", text);
1206
- level_c = "warn";
1207
- } else {
1208
- __android_log_print(ANDROID_LOG_DEFAULT, TAG, "%s", text);
1209
- }
1210
-
1211
- log_callback_context *cb_ctx = (log_callback_context *) data;
1212
-
1213
- JNIEnv *env;
1214
- bool need_detach = false;
1215
- int getEnvResult = cb_ctx->jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
1216
-
1217
- if (getEnvResult == JNI_EDETACHED) {
1218
- if (cb_ctx->jvm->AttachCurrentThread(&env, nullptr) == JNI_OK) {
1219
- need_detach = true;
1220
- } else {
1221
- return;
1222
- }
1223
- } else if (getEnvResult != JNI_OK) {
1224
- return;
1225
- }
1226
-
1227
- jobject callback = cb_ctx->callback;
1228
- jclass cb_class = env->GetObjectClass(callback);
1229
- jmethodID emitNativeLog = env->GetMethodID(cb_class, "emitNativeLog", "(Ljava/lang/String;Ljava/lang/String;)V");
1230
-
1231
- jstring level_str = env->NewStringUTF(level_c);
1232
- jstring text_str = env->NewStringUTF(text);
1233
- env->CallVoidMethod(callback, emitNativeLog, level_str, text_str);
1234
- env->DeleteLocalRef(level_str);
1235
- env->DeleteLocalRef(text_str);
1236
-
1237
- if (need_detach) {
1238
- cb_ctx->jvm->DetachCurrentThread();
1239
- }
1240
- }
1241
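llama.cpp may log from worker threads the JVM has never seen, which is why the callback above probes GetEnv before making any JNI call. The attach/detach pattern, condensed into a standalone sketch (with_jni_env is illustrative, not part of this file):

// Illustrative: run fn with a valid JNIEnv on an arbitrary native thread.
static void with_jni_env(JavaVM *jvm, void (*fn)(JNIEnv *)) {
    JNIEnv *env = nullptr;
    const int rc = jvm->GetEnv((void **) &env, JNI_VERSION_1_6);
    bool attached = false;
    if (rc == JNI_EDETACHED) {
        if (jvm->AttachCurrentThread(&env, nullptr) != JNI_OK) return;
        attached = true; // only detach threads we attached ourselves
    } else if (rc != JNI_OK) {
        return;
    }
    fn(env);
    if (attached) jvm->DetachCurrentThread();
}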
-
1242
- JNIEXPORT void JNICALL
1243
- Java_com_rnllama_LlamaContext_setupLog(JNIEnv *env, jobject thiz, jobject logCallback) {
1244
- UNUSED(thiz);
1245
-
1246
- log_callback_context *cb_ctx = new log_callback_context;
1247
-
1248
- JavaVM *jvm;
1249
- env->GetJavaVM(&jvm);
1250
- cb_ctx->jvm = jvm;
1251
- cb_ctx->callback = env->NewGlobalRef(logCallback);
1252
-
1253
- llama_log_set(rnllama_log_callback_to_j, cb_ctx);
1254
- }
1255
-
1256
- JNIEXPORT void JNICALL
1257
- Java_com_rnllama_LlamaContext_unsetLog(JNIEnv *env, jobject thiz) {
1258
- UNUSED(env);
1259
- UNUSED(thiz);
1260
- llama_log_set(rnllama_log_callback_default, NULL);
1261
- }
1262
-
1263
- } // extern "C"
1
+ #include <jni.h>
2
+ // #include <android/asset_manager.h>
3
+ // #include <android/asset_manager_jni.h>
4
+ #include <android/log.h>
5
+ #include <cstdlib>
6
+ #include <ctime>
8
+ #include <sys/sysinfo.h>
9
+ #include <string>
10
+ #include <thread>
11
+ #include <unordered_map>
12
+ #include "json-schema-to-grammar.h"
13
+ #include "llama.h"
14
+ #include "chat.h"
15
+ #include "llama-impl.h"
16
+ #include "ggml.h"
17
+ #include "rn-llama.h"
18
+ #include "jni-utils.h"
19
+ #define UNUSED(x) (void)(x)
20
+ #define TAG "RNLLAMA_ANDROID_JNI"
21
+
22
+ #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
23
+ #define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
24
+ #define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
25
+ static inline int min(int a, int b) {
26
+ return (a < b) ? a : b;
27
+ }
28
+
29
+ static void rnllama_log_callback_default(lm_ggml_log_level level, const char * text, void * data) {
30
+ UNUSED(data);
+ // text arrives fully formatted; log via "%s" so a literal '%' in model
+ // output cannot be re-parsed as a format specifier
+ if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, "%s", text);
31
+ else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, "%s", text);
32
+ else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, "%s", text);
33
+ else __android_log_print(ANDROID_LOG_DEFAULT, TAG, "%s", text);
34
+ }
35
+
36
+ extern "C" {
37
+
38
+ // Method to create WritableMap
39
+ static inline jobject createWriteableMap(JNIEnv *env) {
40
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
41
+ jmethodID init = env->GetStaticMethodID(mapClass, "createMap", "()Lcom/facebook/react/bridge/WritableMap;");
42
+ jobject map = env->CallStaticObjectMethod(mapClass, init);
43
+ return map;
44
+ }
45
+
46
+ // Method to put string into WritableMap
47
+ static inline void putString(JNIEnv *env, jobject map, const char *key, const char *value) {
48
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
49
+ jmethodID putStringMethod = env->GetMethodID(mapClass, "putString", "(Ljava/lang/String;Ljava/lang/String;)V");
50
+
51
+ jstring jKey = env->NewStringUTF(key);
52
+ jstring jValue = env->NewStringUTF(value);
53
+
54
+ env->CallVoidMethod(map, putStringMethod, jKey, jValue);
55
+ }
56
+
57
+ // Method to put int into WritableMap
58
+ static inline void putInt(JNIEnv *env, jobject map, const char *key, int value) {
59
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
60
+ jmethodID putIntMethod = env->GetMethodID(mapClass, "putInt", "(Ljava/lang/String;I)V");
61
+
62
+ jstring jKey = env->NewStringUTF(key);
63
+
64
+ env->CallVoidMethod(map, putIntMethod, jKey, value);
65
+ }
66
+
67
+ // Method to put double into WritableMap
68
+ static inline void putDouble(JNIEnv *env, jobject map, const char *key, double value) {
69
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
70
+ jmethodID putDoubleMethod = env->GetMethodID(mapClass, "putDouble", "(Ljava/lang/String;D)V");
71
+
72
+ jstring jKey = env->NewStringUTF(key);
73
+
74
+ env->CallVoidMethod(map, putDoubleMethod, jKey, value);
75
+ }
76
+
77
+ // Method to put boolean into WritableMap
78
+ static inline void putBoolean(JNIEnv *env, jobject map, const char *key, bool value) {
79
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
80
+ jmethodID putBooleanMethod = env->GetMethodID(mapClass, "putBoolean", "(Ljava/lang/String;Z)V");
81
+
82
+ jstring jKey = env->NewStringUTF(key);
83
+
84
+ env->CallVoidMethod(map, putBooleanMethod, jKey, value);
85
+ }
86
+
87
+ // Method to put a nested WritableMap into a WritableMap
88
+ static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
89
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
90
+ jmethodID putMapMethod = env->GetMethodID(mapClass, "putMap", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableMap;)V");
91
+
92
+ jstring jKey = env->NewStringUTF(key);
93
+
94
+ env->CallVoidMethod(map, putMapMethod, jKey, value);
95
+ }
96
+
97
+ // Method to create WritableArray
98
+ static inline jobject createWritableArray(JNIEnv *env) {
99
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
100
+ jmethodID init = env->GetStaticMethodID(mapClass, "createArray", "()Lcom/facebook/react/bridge/WritableArray;");
101
+ jobject map = env->CallStaticObjectMethod(mapClass, init);
102
+ return map;
103
+ }
104
+
105
+ // Method to push int into WritableArray
106
+ static inline void pushInt(JNIEnv *env, jobject arr, int value) {
107
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
108
+ jmethodID pushIntMethod = env->GetMethodID(mapClass, "pushInt", "(I)V");
109
+
110
+ env->CallVoidMethod(arr, pushIntMethod, value);
111
+ }
112
+
113
+ // Method to push double into WritableArray
114
+ static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
115
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
116
+ jmethodID pushDoubleMethod = env->GetMethodID(mapClass, "pushDouble", "(D)V");
117
+
118
+ env->CallVoidMethod(arr, pushDoubleMethod, value);
119
+ }
120
+
121
+ // Method to push string into WritableArray
122
+ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
123
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
124
+ jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
125
+
126
+ jstring jValue = env->NewStringUTF(value);
127
+ env->CallVoidMethod(arr, pushStringMethod, jValue);
128
+ }
129
+
130
+ // Method to push WritableMap into WritableArray
131
+ static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
132
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
133
+ jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
134
+
135
+ env->CallVoidMethod(arr, pushMapMethod, value);
136
+ }
137
+
138
+ // Method to put WritableArray into WritableMap
139
+ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
140
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
141
+ jmethodID putArrayMethod = env->GetMethodID(mapClass, "putArray", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableArray;)V");
142
+
143
+ jstring jKey = env->NewStringUTF(key);
144
+
145
+ env->CallVoidMethod(map, putArrayMethod, jKey, value);
146
+ }
147
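These helpers are the whole bridge surface the rest of the file builds on; composed, they produce the nested WritableMap/WritableArray payloads that cross to JS. A sketch (buildExamplePayload is illustrative, not part of the file):

static jobject buildExamplePayload(JNIEnv *env) {
    jobject payload = createWriteableMap(env);
    putString(env, payload, "model", "example.gguf");
    putInt(env, payload, "n_ctx", 2048);

    jobject stops = createWritableArray(env);
    jobject stop = createWriteableMap(env);
    putString(env, stop, "word", "</s>");
    pushMap(env, stops, stop);
    putArray(env, payload, "stops", stops);

    return payload; // local reference to a com.facebook.react.bridge.WritableMap
}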
+
148
+ JNIEXPORT jobject JNICALL
149
+ Java_com_rnllama_LlamaContext_modelInfo(
150
+ JNIEnv *env,
151
+ jobject thiz,
152
+ jstring model_path_str,
153
+ jobjectArray skip
154
+ ) {
155
+ UNUSED(thiz);
156
+
157
+ const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
158
+
159
+ std::vector<std::string> skip_vec;
160
+ int skip_len = env->GetArrayLength(skip);
161
+ for (int i = 0; i < skip_len; i++) {
162
+ jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
163
+ const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
164
+ skip_vec.push_back(skip_chars);
165
+ env->ReleaseStringUTFChars(skip_str, skip_chars);
166
+ }
167
+
168
+ struct lm_gguf_init_params params = {
169
+ /*.no_alloc = */ false,
170
+ /*.ctx = */ NULL,
171
+ };
172
+ struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);
173
+
174
+ if (!ctx) {
175
+ LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
176
+ return nullptr;
177
+ }
178
+
179
+ auto info = createWriteableMap(env);
180
+ putInt(env, info, "version", lm_gguf_get_version(ctx));
181
+ putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
182
+ putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
183
+ {
184
+ const int n_kv = lm_gguf_get_n_kv(ctx);
185
+
186
+ for (int i = 0; i < n_kv; ++i) {
187
+ const char * key = lm_gguf_get_key(ctx, i);
188
+
189
+ bool skipped = false;
190
+ if (skip_len > 0) {
191
+ for (int j = 0; j < skip_len; j++) {
192
+ if (skip_vec[j] == key) {
193
+ skipped = true;
194
+ break;
195
+ }
196
+ }
197
+ }
198
+
199
+ if (skipped) {
200
+ continue;
201
+ }
202
+
203
+ const std::string value = lm_gguf_kv_to_str(ctx, i);
204
+ putString(env, info, key, value.c_str());
205
+ }
206
+ }
207
+
208
+ env->ReleaseStringUTFChars(model_path_str, model_path_chars);
209
+ lm_gguf_free(ctx);
210
+
211
+ return reinterpret_cast<jobject>(info);
212
+ }
213
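modelInfo reads GGUF metadata rather than the full model, so it is cheap enough to call before deciding to load. The same enumeration, stripped of the JNI layer, as a sketch (dumpGgufMetadata is illustrative, not part of the file):

// Illustrative: dump every metadata key/value of a GGUF file to the log.
static void dumpGgufMetadata(const char *path) {
    struct lm_gguf_init_params params = {
        /*.no_alloc = */ false,
        /*.ctx      = */ NULL,
    };
    struct lm_gguf_context *ctx = lm_gguf_init_from_file(path, params);
    if (!ctx) {
        LOGI("failed to load '%s'", path);
        return;
    }
    const int n_kv = lm_gguf_get_n_kv(ctx);
    for (int i = 0; i < n_kv; ++i) {
        const std::string value = lm_gguf_kv_to_str(ctx, i);
        LOGI("%s = %s", lm_gguf_get_key(ctx, i), value.c_str());
    }
    lm_gguf_free(ctx);
}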
+
214
+ struct callback_context {
215
+ JNIEnv *env;
216
+ rnllama::llama_rn_context *llama;
217
+ jobject callback;
218
+ };
219
+
220
+ std::unordered_map<long, rnllama::llama_rn_context *> context_map;
221
+
222
+ struct CallbackContext {
223
+ JNIEnv * env;
224
+ jobject thiz;
225
+ jmethodID sendProgressMethod;
226
+ unsigned current;
227
+ };
228
+
229
+ JNIEXPORT jlong JNICALL
230
+ Java_com_rnllama_LlamaContext_initContext(
231
+ JNIEnv *env,
232
+ jobject thiz,
233
+ jstring model_path_str,
234
+ jstring chat_template,
235
+ jstring reasoning_format,
236
+ jboolean embedding,
237
+ jint embd_normalize,
238
+ jint n_ctx,
239
+ jint n_batch,
240
+ jint n_ubatch,
241
+ jint n_threads,
242
+ jint n_gpu_layers, // TODO: Support this
243
+ jboolean flash_attn,
244
+ jstring cache_type_k,
245
+ jstring cache_type_v,
246
+ jboolean use_mlock,
247
+ jboolean use_mmap,
248
+ jboolean vocab_only,
249
+ jstring lora_str,
250
+ jfloat lora_scaled,
251
+ jobject lora_list,
252
+ jfloat rope_freq_base,
253
+ jfloat rope_freq_scale,
254
+ jint pooling_type,
255
+ jobject load_progress_callback
256
+ ) {
257
+ UNUSED(thiz);
258
+
259
+ common_params defaultParams;
260
+
261
+ defaultParams.vocab_only = vocab_only;
262
+ if(vocab_only) {
263
+ defaultParams.warmup = false;
264
+ }
265
+
266
+ const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
267
+ defaultParams.model = { model_path_chars };
268
+
269
+ const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
270
+ defaultParams.chat_template = chat_template_chars;
271
+
272
+ const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
273
+ if (strcmp(reasoning_format_chars, "deepseek") == 0) {
274
+ defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
275
+ } else {
276
+ defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
277
+ }
278
+
279
+ defaultParams.n_ctx = n_ctx;
280
+ defaultParams.n_batch = n_batch;
281
+ defaultParams.n_ubatch = n_ubatch;
282
+
283
+ if (pooling_type != -1) {
284
+ defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
285
+ }
286
+
287
+ defaultParams.embedding = embedding;
288
+ if (embd_normalize != -1) {
289
+ defaultParams.embd_normalize = embd_normalize;
290
+ }
291
+ if (embedding) {
292
+ // For non-causal models, batch size must be equal to ubatch size
293
+ defaultParams.n_ubatch = defaultParams.n_batch;
294
+ }
295
+
296
+ int max_threads = std::thread::hardware_concurrency();
297
+ // Default to 2 threads on 4-core devices, otherwise min(4, cores)
298
+ int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
299
+ defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
300
+
301
+ // defaultParams.n_gpu_layers = n_gpu_layers;
302
+ defaultParams.flash_attn = flash_attn;
303
+
304
+ const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
305
+ const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
306
+ defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
307
+ defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);
308
+
309
+ defaultParams.use_mlock = use_mlock;
310
+ defaultParams.use_mmap = use_mmap;
311
+
312
+ defaultParams.rope_freq_base = rope_freq_base;
313
+ defaultParams.rope_freq_scale = rope_freq_scale;
314
+
315
+ auto llama = new rnllama::llama_rn_context();
316
+ llama->is_load_interrupted = false;
317
+ llama->loading_progress = 0;
318
+
319
+ if (load_progress_callback != nullptr) {
320
+ defaultParams.progress_callback = [](float progress, void * user_data) {
321
+ callback_context *cb_ctx = (callback_context *)user_data;
322
+ JNIEnv *env = cb_ctx->env;
323
+ auto llama = cb_ctx->llama;
324
+ jobject callback = cb_ctx->callback;
325
+ int percentage = (int) (100 * progress);
326
+ if (percentage > llama->loading_progress) {
327
+ llama->loading_progress = percentage;
328
+ jclass callback_class = env->GetObjectClass(callback);
329
+ jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
330
+ env->CallVoidMethod(callback, onLoadProgress, percentage);
331
+ }
332
+ return !llama->is_load_interrupted;
333
+ };
334
+
335
+ callback_context *cb_ctx = new callback_context;
336
+ cb_ctx->env = env;
337
+ cb_ctx->llama = llama;
338
+ cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
339
+ defaultParams.progress_callback_user_data = cb_ctx;
340
+ }
341
+
342
+ bool is_model_loaded = llama->loadModel(defaultParams);
343
+
344
+ env->ReleaseStringUTFChars(model_path_str, model_path_chars);
345
+ env->ReleaseStringUTFChars(chat_template, chat_template_chars);
346
+ env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
347
+ env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
348
+ env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
349
+
350
+ LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
351
+ if (is_model_loaded) {
352
+ if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
353
+ LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
354
+ llama_free(llama->ctx);
355
+ return -1;
356
+ }
357
+ context_map[(long) llama->ctx] = llama;
358
+ } else {
359
+ llama_free(llama->ctx);
+ return -1; // avoid falling through to LoRA application with an invalid context
360
+ }
361
+
362
+ std::vector<common_adapter_lora_info> lora;
363
+ const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
364
+ if (lora_chars != nullptr && lora_chars[0] != '\0') {
365
+ common_adapter_lora_info la;
366
+ la.path = lora_chars;
367
+ la.scale = lora_scaled;
368
+ lora.push_back(la);
369
+ }
370
+
371
+ if (lora_list != nullptr) {
372
+ // lora_adapters: ReadableArray<ReadableMap>
373
+ int lora_list_size = readablearray::size(env, lora_list);
374
+ for (int i = 0; i < lora_list_size; i++) {
375
+ jobject lora_adapter = readablearray::getMap(env, lora_list, i);
376
+ jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
377
+ if (path != nullptr) {
378
+ const char *path_chars = env->GetStringUTFChars(path, nullptr);
379
+ common_adapter_lora_info la;
380
+ la.path = path_chars;
381
+ la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
382
+ lora.push_back(la);
383
+ env->ReleaseStringUTFChars(path, path_chars);
384
+ }
385
+ }
386
+ }
387
+ env->ReleaseStringUTFChars(lora_str, lora_chars);
388
+ int result = llama->applyLoraAdapters(lora);
389
+ if (result != 0) {
390
+ LOGI("[RNLlama] Failed to apply lora adapters");
391
+ llama_free(llama->ctx);
392
+ return -1;
393
+ }
394
+
395
+ return reinterpret_cast<jlong>(llama->ctx);
396
+ }
397
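The default thread count chosen above also governs doCompletion later in the file. Isolated, the heuristic is (resolve_n_threads is an illustrative name):

// 2 threads on a 4-core device, otherwise min(4, cores); a positive
// caller-supplied value always wins.
static int resolve_n_threads(int requested) {
    const int max_threads = (int) std::thread::hardware_concurrency();
    const int fallback = max_threads == 4 ? 2 : min(4, max_threads);
    return requested > 0 ? requested : fallback;
}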
+
398
+
399
+ JNIEXPORT void JNICALL
400
+ Java_com_rnllama_LlamaContext_interruptLoad(
401
+ JNIEnv *env,
402
+ jobject thiz,
403
+ jlong context_ptr
404
+ ) {
405
+ UNUSED(thiz);
406
+ auto llama = context_map[(long) context_ptr];
407
+ if (llama) {
408
+ llama->is_load_interrupted = true;
409
+ }
410
+ }
411
+
412
+ JNIEXPORT jobject JNICALL
413
+ Java_com_rnllama_LlamaContext_loadModelDetails(
414
+ JNIEnv *env,
415
+ jobject thiz,
416
+ jlong context_ptr
417
+ ) {
418
+ UNUSED(thiz);
419
+ auto llama = context_map[(long) context_ptr];
420
+
421
+ int count = llama_model_meta_count(llama->model);
422
+ auto meta = createWriteableMap(env);
423
+ for (int i = 0; i < count; i++) {
424
+ char key[256];
425
+ llama_model_meta_key_by_index(llama->model, i, key, sizeof(key));
426
+ char val[4096];
427
+ llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val));
428
+
429
+ putString(env, meta, key, val);
430
+ }
431
+
432
+ auto result = createWriteableMap(env);
433
+
434
+ char desc[1024];
435
+ llama_model_desc(llama->model, desc, sizeof(desc));
436
+
437
+ putString(env, result, "desc", desc);
438
+ putDouble(env, result, "size", llama_model_size(llama->model));
439
+ putDouble(env, result, "nEmbd", llama_model_n_embd(llama->model));
440
+ putDouble(env, result, "nParams", llama_model_n_params(llama->model));
441
+ auto chat_templates = createWriteableMap(env);
442
+ putBoolean(env, chat_templates, "llamaChat", llama->validateModelChatTemplate(false, nullptr));
443
+
444
+ auto minja = createWriteableMap(env);
445
+ putBoolean(env, minja, "default", llama->validateModelChatTemplate(true, nullptr));
446
+
447
+ auto default_caps = createWriteableMap(env);
448
+
449
+ auto default_tmpl = llama->templates.get()->template_default.get();
450
+ auto default_tmpl_caps = default_tmpl->original_caps();
451
+ putBoolean(env, default_caps, "tools", default_tmpl_caps.supports_tools);
452
+ putBoolean(env, default_caps, "toolCalls", default_tmpl_caps.supports_tool_calls);
453
+ putBoolean(env, default_caps, "parallelToolCalls", default_tmpl_caps.supports_parallel_tool_calls);
454
+ putBoolean(env, default_caps, "toolResponses", default_tmpl_caps.supports_tool_responses);
455
+ putBoolean(env, default_caps, "systemRole", default_tmpl_caps.supports_system_role);
456
+ putBoolean(env, default_caps, "toolCallId", default_tmpl_caps.supports_tool_call_id);
457
+ putMap(env, minja, "defaultCaps", default_caps);
458
+
459
+ putBoolean(env, minja, "toolUse", llama->validateModelChatTemplate(true, "tool_use"));
460
+ auto tool_use_tmpl = llama->templates.get()->template_tool_use.get();
461
+ if (tool_use_tmpl != nullptr) {
462
+ auto tool_use_caps = createWriteableMap(env);
463
+ auto tool_use_tmpl_caps = tool_use_tmpl->original_caps();
464
+ putBoolean(env, tool_use_caps, "tools", tool_use_tmpl_caps.supports_tools);
465
+ putBoolean(env, tool_use_caps, "toolCalls", tool_use_tmpl_caps.supports_tool_calls);
466
+ putBoolean(env, tool_use_caps, "parallelToolCalls", tool_use_tmpl_caps.supports_parallel_tool_calls);
467
+ putBoolean(env, tool_use_caps, "systemRole", tool_use_tmpl_caps.supports_system_role);
468
+ putBoolean(env, tool_use_caps, "toolResponses", tool_use_tmpl_caps.supports_tool_responses);
469
+ putBoolean(env, tool_use_caps, "toolCallId", tool_use_tmpl_caps.supports_tool_call_id);
470
+ putMap(env, minja, "toolUseCaps", tool_use_caps);
471
+ }
472
+
473
+ putMap(env, chat_templates, "minja", minja);
474
+ putMap(env, result, "metadata", meta);
475
+ putMap(env, result, "chatTemplates", chat_templates);
476
+
477
+ // deprecated
478
+ putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate(false, nullptr));
479
+
480
+ return reinterpret_cast<jobject>(result);
481
+ }
482
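For reference, the chatTemplates map assembled above has the following shape (field values depend on the templates embedded in the model):

// {
//   "llamaChat": bool,              // legacy llama.cpp template check
//   "minja": {
//     "default": bool,
//     "defaultCaps": { "tools", "toolCalls", "parallelToolCalls",
//                      "toolResponses", "systemRole", "toolCallId" },
//     "toolUse": bool,
//     "toolUseCaps": { ... }        // present only when a tool_use template exists
//   }
// }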
+
483
+ JNIEXPORT jobject JNICALL
484
+ Java_com_rnllama_LlamaContext_getFormattedChatWithJinja(
485
+ JNIEnv *env,
486
+ jobject thiz,
487
+ jlong context_ptr,
488
+ jstring messages,
489
+ jstring chat_template,
490
+ jstring json_schema,
491
+ jstring tools,
492
+ jboolean parallel_tool_calls,
493
+ jstring tool_choice
494
+ ) {
495
+ UNUSED(thiz);
496
+ auto llama = context_map[(long) context_ptr];
497
+
498
+ const char *messages_chars = env->GetStringUTFChars(messages, nullptr);
499
+ const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
500
+ const char *json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
501
+ const char *tools_chars = env->GetStringUTFChars(tools, nullptr);
502
+ const char *tool_choice_chars = env->GetStringUTFChars(tool_choice, nullptr);
503
+
504
+ auto result = createWriteableMap(env);
505
+ try {
506
+ auto formatted = llama->getFormattedChatWithJinja(
507
+ messages_chars,
508
+ tmpl_chars,
509
+ json_schema_chars,
510
+ tools_chars,
511
+ parallel_tool_calls,
512
+ tool_choice_chars
513
+ );
514
+ putString(env, result, "prompt", formatted.prompt.c_str());
515
+ putInt(env, result, "chat_format", static_cast<int>(formatted.format));
516
+ putString(env, result, "grammar", formatted.grammar.c_str());
517
+ putBoolean(env, result, "grammar_lazy", formatted.grammar_lazy);
518
+ auto grammar_triggers = createWritableArray(env);
519
+ for (const auto &trigger : formatted.grammar_triggers) {
520
+ auto trigger_map = createWriteableMap(env);
521
+ putInt(env, trigger_map, "type", trigger.type);
522
+ putString(env, trigger_map, "value", trigger.value.c_str());
523
+ putInt(env, trigger_map, "token", trigger.token);
524
+ pushMap(env, grammar_triggers, trigger_map);
525
+ }
526
+ putArray(env, result, "grammar_triggers", grammar_triggers);
527
+ auto preserved_tokens = createWritableArray(env);
528
+ for (const auto &token : formatted.preserved_tokens) {
529
+ pushString(env, preserved_tokens, token.c_str());
530
+ }
531
+ putArray(env, result, "preserved_tokens", preserved_tokens);
532
+ auto additional_stops = createWritableArray(env);
533
+ for (const auto &stop : formatted.additional_stops) {
534
+ pushString(env, additional_stops, stop.c_str());
535
+ }
536
+ putArray(env, result, "additional_stops", additional_stops);
537
+ } catch (const std::runtime_error &e) {
538
+ LOGI("[RNLlama] Error: %s", e.what());
539
+ putString(env, result, "_error", e.what());
540
+ }
541
+ env->ReleaseStringUTFChars(tools, tools_chars);
542
+ env->ReleaseStringUTFChars(messages, messages_chars);
543
+ env->ReleaseStringUTFChars(chat_template, tmpl_chars);
544
+ env->ReleaseStringUTFChars(json_schema, json_schema_chars);
545
+ env->ReleaseStringUTFChars(tool_choice, tool_choice_chars);
546
+ return reinterpret_cast<jobject>(result);
547
+ }
548
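The map returned here is shaped to feed doCompletion below; judging by the parameter names, the correspondence is:

// getFormattedChatWithJinja result  ->  doCompletion argument
//   "prompt"                        ->  prompt
//   "chat_format"                   ->  chat_format
//   "grammar", "grammar_lazy"       ->  grammar, grammar_lazy
//   "grammar_triggers"              ->  grammar_triggers
//   "preserved_tokens"              ->  preserved_tokens
//   "additional_stops"              ->  merged into stop by the caller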
+
549
+ JNIEXPORT jstring JNICALL
550
+ Java_com_rnllama_LlamaContext_getFormattedChat(
551
+ JNIEnv *env,
552
+ jobject thiz,
553
+ jlong context_ptr,
554
+ jstring messages,
555
+ jstring chat_template
556
+ ) {
557
+ UNUSED(thiz);
558
+ auto llama = context_map[(long) context_ptr];
559
+
560
+ const char *messages_chars = env->GetStringUTFChars(messages, nullptr);
561
+ const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
562
+
563
+ std::string formatted_chat = llama->getFormattedChat(messages_chars, tmpl_chars);
564
+
565
+ env->ReleaseStringUTFChars(messages, messages_chars);
566
+ env->ReleaseStringUTFChars(chat_template, tmpl_chars);
567
+
568
+ return env->NewStringUTF(formatted_chat.c_str());
569
+ }
570
+
571
+ JNIEXPORT jobject JNICALL
572
+ Java_com_rnllama_LlamaContext_loadSession(
573
+ JNIEnv *env,
574
+ jobject thiz,
575
+ jlong context_ptr,
576
+ jstring path
577
+ ) {
578
+ UNUSED(thiz);
579
+ auto llama = context_map[(long) context_ptr];
580
+ const char *path_chars = env->GetStringUTFChars(path, nullptr);
581
+
582
+ auto result = createWriteableMap(env);
583
+ size_t n_token_count_out = 0;
584
+ llama->embd.resize(llama->params.n_ctx);
585
+ if (!llama_state_load_file(llama->ctx, path_chars, llama->embd.data(), llama->embd.size(), &n_token_count_out)) {
586
+ env->ReleaseStringUTFChars(path, path_chars);
587
+
588
+ putString(env, result, "error", "Failed to load session");
589
+ return reinterpret_cast<jobject>(result);
590
+ }
591
+ llama->embd.resize(n_token_count_out);
592
+ env->ReleaseStringUTFChars(path, path_chars);
593
+
594
+ const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
595
+ putInt(env, result, "tokens_loaded", n_token_count_out);
596
+ putString(env, result, "prompt", text.c_str());
597
+ return reinterpret_cast<jobject>(result);
598
+ }
599
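A round-trip sketch of the two session calls, assuming a valid llama_context *ctx; the buffer length passed to the loader is the vector's size, and the saved length is the restored token count (roundtrip_session and the file names are illustrative):

// Illustrative: restore a session file, then write a copy of it.
static bool roundtrip_session(llama_context *ctx, int n_ctx) {
    std::vector<llama_token> tokens(n_ctx);
    size_t n_loaded = 0;
    if (!llama_state_load_file(ctx, "session.bin", tokens.data(), tokens.size(), &n_loaded)) {
        return false;
    }
    tokens.resize(n_loaded); // keep only the tokens actually restored
    return llama_state_save_file(ctx, "session-copy.bin", tokens.data(), tokens.size());
}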
+
600
+ JNIEXPORT jint JNICALL
601
+ Java_com_rnllama_LlamaContext_saveSession(
602
+ JNIEnv *env,
603
+ jobject thiz,
604
+ jlong context_ptr,
605
+ jstring path,
606
+ jint size
607
+ ) {
608
+ UNUSED(thiz);
609
+ auto llama = context_map[(long) context_ptr];
610
+
611
+ const char *path_chars = env->GetStringUTFChars(path, nullptr);
612
+
613
+ std::vector<llama_token> session_tokens = llama->embd;
614
+ int default_size = session_tokens.size();
615
+ int save_size = size > 0 && size <= default_size ? size : default_size;
616
+ if (!llama_state_save_file(llama->ctx, path_chars, session_tokens.data(), save_size)) {
617
+ env->ReleaseStringUTFChars(path, path_chars);
618
+ return -1;
619
+ }
620
+
621
+ env->ReleaseStringUTFChars(path, path_chars);
622
+ return session_tokens.size();
623
+ }
624
+
625
+ static inline jobject tokenProbsToMap(
626
+ JNIEnv *env,
627
+ rnllama::llama_rn_context *llama,
628
+ std::vector<rnllama::completion_token_output> probs
629
+ ) {
630
+ auto result = createWritableArray(env);
631
+ for (const auto &prob : probs) {
632
+ auto probsForToken = createWritableArray(env);
633
+ for (const auto &p : prob.probs) {
634
+ std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
635
+ auto probResult = createWriteableMap(env);
636
+ putString(env, probResult, "tok_str", tokStr.c_str());
637
+ putDouble(env, probResult, "prob", p.prob);
638
+ pushMap(env, probsForToken, probResult);
639
+ }
640
+ std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
641
+ auto tokenResult = createWriteableMap(env);
642
+ putString(env, tokenResult, "content", tokStr.c_str());
643
+ putArray(env, tokenResult, "probs", probsForToken);
644
+ pushMap(env, result, tokenResult);
645
+ }
646
+ return result;
647
+ }
648
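Serialized to JSON on the JS side, the array built here takes this form (values illustrative):

// [
//   { "content": "Hello",
//     "probs": [ { "tok_str": "Hello", "prob": 0.91 },
//                { "tok_str": "Hi",    "prob": 0.06 } ] },
//   ...
// ]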
+
649
+ JNIEXPORT jobject JNICALL
650
+ Java_com_rnllama_LlamaContext_doCompletion(
651
+ JNIEnv *env,
652
+ jobject thiz,
653
+ jlong context_ptr,
654
+ jstring prompt,
655
+ jint chat_format,
656
+ jstring grammar,
657
+ jstring json_schema,
658
+ jboolean grammar_lazy,
659
+ jobject grammar_triggers,
660
+ jobject preserved_tokens,
661
+ jfloat temperature,
662
+ jint n_threads,
663
+ jint n_predict,
664
+ jint n_probs,
665
+ jint penalty_last_n,
666
+ jfloat penalty_repeat,
667
+ jfloat penalty_freq,
668
+ jfloat penalty_present,
669
+ jfloat mirostat,
670
+ jfloat mirostat_tau,
671
+ jfloat mirostat_eta,
672
+ jint top_k,
673
+ jfloat top_p,
674
+ jfloat min_p,
675
+ jfloat xtc_threshold,
676
+ jfloat xtc_probability,
677
+ jfloat typical_p,
678
+ jint seed,
679
+ jobjectArray stop,
680
+ jboolean ignore_eos,
681
+ jobjectArray logit_bias,
682
+ jfloat dry_multiplier,
683
+ jfloat dry_base,
684
+ jint dry_allowed_length,
685
+ jint dry_penalty_last_n,
686
+ jfloat top_n_sigma,
687
+ jobjectArray dry_sequence_breakers,
688
+ jobject partial_completion_callback
689
+ ) {
690
+ UNUSED(thiz);
691
+ auto llama = context_map[(long) context_ptr];
692
+
693
+ llama->rewind();
694
+
695
+ //llama_reset_timings(llama->ctx);
696
+
697
+ auto prompt_chars = env->GetStringUTFChars(prompt, nullptr);
698
+ llama->params.prompt = prompt_chars;
699
+ llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;
700
+
701
+ int max_threads = std::thread::hardware_concurrency();
702
+ // Default to 2 threads on 4-core devices, otherwise min(4, cores)
703
+ int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
704
+ llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
705
+
706
+ llama->params.n_predict = n_predict;
707
+ llama->params.sampling.ignore_eos = ignore_eos;
708
+
709
+ auto & sparams = llama->params.sampling;
710
+ sparams.temp = temperature;
711
+ sparams.penalty_last_n = penalty_last_n;
712
+ sparams.penalty_repeat = penalty_repeat;
713
+ sparams.penalty_freq = penalty_freq;
714
+ sparams.penalty_present = penalty_present;
715
+ sparams.mirostat = mirostat;
716
+ sparams.mirostat_tau = mirostat_tau;
717
+ sparams.mirostat_eta = mirostat_eta;
718
+ sparams.top_k = top_k;
719
+ sparams.top_p = top_p;
720
+ sparams.min_p = min_p;
721
+ sparams.typ_p = typical_p;
722
+ sparams.n_probs = n_probs;
723
+ sparams.xtc_threshold = xtc_threshold;
724
+ sparams.xtc_probability = xtc_probability;
725
+ sparams.dry_multiplier = dry_multiplier;
726
+ sparams.dry_base = dry_base;
727
+ sparams.dry_allowed_length = dry_allowed_length;
728
+ sparams.dry_penalty_last_n = dry_penalty_last_n;
729
+ sparams.top_n_sigma = top_n_sigma;
730
+
731
+ // grammar
732
+ auto grammar_chars = env->GetStringUTFChars(grammar, nullptr);
733
+ if (grammar_chars && grammar_chars[0] != '\0') {
734
+ sparams.grammar = grammar_chars;
735
+ }
736
+ sparams.grammar_lazy = grammar_lazy;
737
+
738
+ if (preserved_tokens != nullptr) {
739
+ int preserved_tokens_size = readablearray::size(env, preserved_tokens);
740
+ for (int i = 0; i < preserved_tokens_size; i++) {
741
+ jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
742
+ auto ids = common_tokenize(llama->ctx, env->GetStringUTFChars(preserved_token, nullptr), /* add_special= */ false, /* parse_special= */ true);
743
+ if (ids.size() == 1) {
744
+ sparams.preserved_tokens.insert(ids[0]);
745
+ } else {
746
+ LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", env->GetStringUTFChars(preserved_token, nullptr));
747
+ }
748
+ }
749
+ }
750
+
751
+ if (grammar_triggers != nullptr) {
752
+ int grammar_triggers_size = readablearray::size(env, grammar_triggers);
753
+ for (int i = 0; i < grammar_triggers_size; i++) {
754
+ auto trigger_map = readablearray::getMap(env, grammar_triggers, i);
755
+ const auto type = static_cast<common_grammar_trigger_type>(readablemap::getInt(env, trigger_map, "type", 0));
756
+ jstring trigger_word = readablemap::getString(env, trigger_map, "value", nullptr);
757
+ auto word = env->GetStringUTFChars(trigger_word, nullptr);
758
+
759
+ if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
760
+ auto ids = common_tokenize(llama->ctx, word, /* add_special= */ false, /* parse_special= */ true);
761
+ if (ids.size() == 1) {
762
+ auto token = ids[0];
763
+ if (std::find(sparams.preserved_tokens.begin(), sparams.preserved_tokens.end(), (llama_token) token) == sparams.preserved_tokens.end()) {
764
+ throw std::runtime_error("Grammar trigger word should be marked as preserved token");
765
+ }
766
+ common_grammar_trigger trigger;
767
+ trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
768
+ trigger.value = word;
769
+ trigger.token = token;
770
+ sparams.grammar_triggers.push_back(std::move(trigger));
771
+ } else {
772
+ sparams.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
773
+ }
774
+ } else {
775
+ common_grammar_trigger trigger;
776
+ trigger.type = type;
777
+ trigger.value = word;
778
+ if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
779
+ const auto token = (llama_token) readablemap::getInt(env, trigger_map, "token", 0);
780
+ trigger.token = token;
781
+ }
782
+ sparams.grammar_triggers.push_back(std::move(trigger));
783
+ }
784
+ }
785
+ }
786
+
787
+ auto json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
788
+ if ((!grammar_chars || grammar_chars[0] == '\0') && json_schema_chars && json_schema_chars[0] != '\0') {
789
+ auto schema = json::parse(json_schema_chars);
790
+ sparams.grammar = json_schema_to_grammar(schema);
791
+ }
792
+ env->ReleaseStringUTFChars(json_schema, json_schema_chars);
793
+
794
+
795
+ const llama_model * model = llama_get_model(llama->ctx);
796
+ const llama_vocab * vocab = llama_model_get_vocab(model);
797
+
798
+ sparams.logit_bias.clear();
799
+ if (ignore_eos) {
800
+ sparams.logit_bias[llama_vocab_eos(vocab)].bias = -INFINITY;
801
+ }
802
+
803
+ // dry break seq
804
+
805
+ jint size = env->GetArrayLength(dry_sequence_breakers);
806
+ std::vector<std::string> dry_sequence_breakers_vector;
807
+
808
+ for (jint i = 0; i < size; i++) {
809
+ jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
810
+ const char *nativeString = env->GetStringUTFChars(javaString, 0);
811
+ dry_sequence_breakers_vector.push_back(std::string(nativeString));
812
+ env->ReleaseStringUTFChars(javaString, nativeString);
813
+ env->DeleteLocalRef(javaString);
814
+ }
815
+
816
+ sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
817
+
818
+ // logit bias
819
+ const int n_vocab = llama_vocab_n_tokens(vocab);
820
+ jsize logit_bias_len = env->GetArrayLength(logit_bias);
821
+
822
+ for (jsize i = 0; i < logit_bias_len; i++) {
823
+ jdoubleArray el = (jdoubleArray) env->GetObjectArrayElement(logit_bias, i);
824
+ if (el && env->GetArrayLength(el) == 2) {
825
+ jdouble* doubleArray = env->GetDoubleArrayElements(el, 0);
826
+
827
+ llama_token tok = static_cast<llama_token>(doubleArray[0]);
828
+ if (tok >= 0 && tok < n_vocab) {
829
+ if (doubleArray[1] != 0) { // If the second element is not false (0)
830
+ sparams.logit_bias[tok].bias = doubleArray[1];
831
+ } else {
832
+ sparams.logit_bias[tok].bias = -INFINITY;
833
+ }
834
+ }
835
+
836
+ env->ReleaseDoubleArrayElements(el, doubleArray, 0);
837
+ }
838
+ env->DeleteLocalRef(el);
839
+ }
840
+
841
+ llama->params.antiprompt.clear();
842
+ int stop_len = env->GetArrayLength(stop);
843
+ for (int i = 0; i < stop_len; i++) {
844
+ jstring stop_str = (jstring) env->GetObjectArrayElement(stop, i);
845
+ const char *stop_chars = env->GetStringUTFChars(stop_str, nullptr);
846
+ llama->params.antiprompt.push_back(stop_chars);
847
+ env->ReleaseStringUTFChars(stop_str, stop_chars);
848
+ }
+
+    if (!llama->initSampling()) {
+        // release the UTF strings on this early-return path too; they are
+        // otherwise only released after the generation loop below
+        env->ReleaseStringUTFChars(grammar, grammar_chars);
+        env->ReleaseStringUTFChars(prompt, prompt_chars);
+        auto result = createWriteableMap(env);
+        putString(env, result, "error", "Failed to initialize sampling");
+        return reinterpret_cast<jobject>(result);
+    }
+    llama->beginCompletion();
+    llama->loadPrompt();
+
+    size_t sent_count = 0;
+    size_t sent_token_probs_index = 0;
+
+    while (llama->has_next_token && !llama->is_interrupted) {
+        const rnllama::completion_token_output token_with_probs = llama->doCompletion();
+        if (token_with_probs.tok == -1 || llama->incomplete) {
+            continue;
+        }
+        const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);
+
+        size_t pos = std::min(sent_count, llama->generated_text.size());
+
+        const std::string str_test = llama->generated_text.substr(pos);
+        bool is_stop_full = false;
+        size_t stop_pos =
+            llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
+        if (stop_pos != std::string::npos) {
+            is_stop_full = true;
+            llama->generated_text.erase(
+                llama->generated_text.begin() + pos + stop_pos,
+                llama->generated_text.end());
+            pos = std::min(sent_count, llama->generated_text.size());
+        } else {
+            is_stop_full = false;
+            stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
+                rnllama::STOP_PARTIAL);
+        }
+
+        if (
+            stop_pos == std::string::npos ||
+            // Send the rest of the text if we are at the end of the generation
+            (!llama->has_next_token && !is_stop_full && stop_pos > 0)
+        ) {
+            const std::string to_send = llama->generated_text.substr(pos, std::string::npos);
+
+            sent_count += to_send.size();
+
+            std::vector<rnllama::completion_token_output> probs_output = {};
+
+            auto tokenResult = createWriteableMap(env);
+            putString(env, tokenResult, "token", to_send.c_str());
+
+            if (llama->params.sampling.n_probs > 0) {
+                const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
+                size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
+                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
+                if (probs_pos < probs_stop_pos) {
+                    probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
+                }
+                sent_token_probs_index = probs_stop_pos;
+
+                putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
+            }
+
+            jclass cb_class = env->GetObjectClass(partial_completion_callback);
+            jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
+            env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
+        }
+    }
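The STOP_FULL/STOP_PARTIAL split above exists because a stop word can arrive across several tokens: when only a prefix of a stop word has been generated, that tail must be withheld from streaming until it either completes the stop word or diverges. A self-contained sketch of the partial-match test, with a hypothetical function name (the real check lives in findStoppingStrings):

    // Sketch only: position from which text must be held back, or npos.
    #include <algorithm>
    #include <string>
    #include <vector>

    size_t partial_stop_pos(const std::string &text,
                            const std::vector<std::string> &stop_words) {
        size_t best = std::string::npos;
        for (const auto &word : stop_words) {
            if (word.empty()) continue;
            // try every suffix of `text` that could still grow into `word`
            for (size_t len = std::min(word.size() - 1, text.size()); len > 0; len--) {
                if (text.compare(text.size() - len, len, word, 0, len) == 0) {
                    size_t pos = text.size() - len;
                    if (best == std::string::npos || pos < best) best = pos;
                    break;
                }
            }
        }
        return best; // npos when nothing needs to be withheld
    }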
+
+    env->ReleaseStringUTFChars(grammar, grammar_chars);
+    env->ReleaseStringUTFChars(prompt, prompt_chars);
+    llama_perf_context_print(llama->ctx);
+    llama->is_predicting = false;
+
+    auto toolCalls = createWritableArray(env);
+    std::string reasoningContent = "";
+    std::string content;
+    auto toolCallsSize = 0;
+    if (!llama->is_interrupted) {
+        try {
+            common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
+            if (!message.reasoning_content.empty()) {
+                reasoningContent = message.reasoning_content;
+            }
+            content = message.content;
+            for (const auto &tc : message.tool_calls) {
+                auto toolCall = createWriteableMap(env);
+                putString(env, toolCall, "type", "function");
+                auto functionMap = createWriteableMap(env);
+                putString(env, functionMap, "name", tc.name.c_str());
+                putString(env, functionMap, "arguments", tc.arguments.c_str());
+                putMap(env, toolCall, "function", functionMap);
+                if (!tc.id.empty()) {
+                    putString(env, toolCall, "id", tc.id.c_str());
+                }
+                pushMap(env, toolCalls, toolCall);
+                toolCallsSize++;
+            }
+        } catch (const std::exception &e) {
+            // LOGI("Error parsing tool calls: %s", e.what());
+        }
+    }
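For reference, the nested map built per tool call above has the following shape, shown here as plain C++ data with hypothetical field values ("id" is only set when the parser supplies one):

    // Sketch only: the same structure as plain aggregates (hypothetical types).
    #include <string>

    struct ToolCallFunction { std::string name; std::string arguments; };
    struct ToolCall { std::string type; ToolCallFunction function; std::string id; };

    // e.g. {"type":"function",
    //       "function":{"name":"get_weather","arguments":"{\"city\":\"Paris\"}"},
    //       "id":"call_0"}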
+
+    auto result = createWriteableMap(env);
+    putString(env, result, "text", llama->generated_text.c_str());
+    if (!content.empty()) {
+        putString(env, result, "content", content.c_str());
+    }
+    if (!reasoningContent.empty()) {
+        putString(env, result, "reasoning_content", reasoningContent.c_str());
+    }
+    if (toolCallsSize > 0) {
+        putArray(env, result, "tool_calls", toolCalls);
+    }
+    putArray(env, result, "completion_probabilities", tokenProbsToMap(env, llama, llama->generated_token_probs));
+    putInt(env, result, "tokens_predicted", llama->num_tokens_predicted);
+    putInt(env, result, "tokens_evaluated", llama->num_prompt_tokens);
+    putInt(env, result, "truncated", llama->truncated);
+    putInt(env, result, "stopped_eos", llama->stopped_eos);
+    putInt(env, result, "stopped_word", llama->stopped_word);
+    putInt(env, result, "stopped_limit", llama->stopped_limit);
+    putString(env, result, "stopping_word", llama->stopping_word.c_str());
+    putInt(env, result, "tokens_cached", llama->n_past);
+
+    const auto timings_token = llama_perf_context(llama->ctx);
+
+    auto timingsResult = createWriteableMap(env);
+    putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
+    putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
+    putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
+    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
+    putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
+    putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
+    putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
+    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);
+
+    putMap(env, result, "timings", timingsResult);
+
+    return reinterpret_cast<jobject>(result);
+}
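The per-second figures invert the millisecond totals. A worked example with hypothetical numbers:

    // 512 prompt tokens evaluated in 800 ms:
    //   prompt_per_token_ms    = 800 / 512        ≈ 1.56 ms/token (stored via putInt, so truncated to 1)
    //   prompt_per_second      = 1e3 / 800 * 512  = 640 tokens/s
    // 128 tokens generated in 3200 ms:
    //   predicted_per_token_ms = 3200 / 128        = 25 ms/token
    //   predicted_per_second   = 1e3 / 3200 * 128  = 40 tokens/s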
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_stopCompletion(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    llama->is_interrupted = true;
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_rnllama_LlamaContext_isPredicting(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    return llama->is_predicting;
+}
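stopCompletion is cooperative: it only sets is_interrupted, which the generation loop above re-checks once per token. A minimal sketch of the same pattern with an atomic flag (names are hypothetical; the field in this file is a plain bool):

    // Sketch only: cooperative cancellation, checked at token granularity.
    #include <atomic>

    std::atomic<bool> interrupted{false};

    void generate_loop() {
        while (!interrupted.load(std::memory_order_relaxed) /* && has_next_token */) {
            // produce one token, then re-check the flag
        }
    }

    void request_stop() { interrupted.store(true, std::memory_order_relaxed); }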
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_tokenize(
+    JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    const char *text_chars = env->GetStringUTFChars(text, nullptr);
+
+    const std::vector<llama_token> toks = common_tokenize(
+        llama->ctx,
+        text_chars,
+        false
+    );
+
+    jobject result = createWritableArray(env);
+    for (const auto &tok : toks) {
+        pushInt(env, result, tok);
+    }
+
+    env->ReleaseStringUTFChars(text, text_chars);
+    return result;
+}
+
+JNIEXPORT jstring JNICALL
+Java_com_rnllama_LlamaContext_detokenize(
+    JNIEnv *env, jobject thiz, jlong context_ptr, jintArray tokens) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    jsize tokens_len = env->GetArrayLength(tokens);
+    jint *tokens_ptr = env->GetIntArrayElements(tokens, 0);
+    std::vector<llama_token> toks;
+    for (int i = 0; i < tokens_len; i++) {
+        toks.push_back(tokens_ptr[i]);
+    }
+
+    auto text = rnllama::tokens_to_str(llama->ctx, toks.cbegin(), toks.cend());
+
+    env->ReleaseIntArrayElements(tokens, tokens_ptr, 0);
+
+    return env->NewStringUTF(text.c_str());
+}
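tokenize and detokenize are inverses up to tokenizer normalization. A sketch round-tripping through the same helpers the two bindings call (assumes a valid llama_context obtained elsewhere in this file; the wrapper function is hypothetical):

    // Sketch only: text -> tokens -> text.
    std::string round_trip(llama_context *ctx, const std::string &input) {
        std::vector<llama_token> toks = common_tokenize(ctx, input, false);
        return rnllama::tokens_to_str(ctx, toks.cbegin(), toks.cend());
        // result should match `input` up to tokenizer normalization
    }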
+
+JNIEXPORT jboolean JNICALL
+Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    return llama->params.embedding;
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_embedding(
+    JNIEnv *env, jobject thiz,
+    jlong context_ptr,
+    jstring text,
+    jint embd_normalize
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    common_params embdParams;
+    embdParams.embedding = true;
+    embdParams.embd_normalize = llama->params.embd_normalize;
+    if (embd_normalize != -1) {
+        embdParams.embd_normalize = embd_normalize;
+    }
+
+    const char *text_chars = env->GetStringUTFChars(text, nullptr);
+
+    llama->rewind();
+
+    llama_perf_context_reset(llama->ctx);
+
+    llama->params.prompt = text_chars;
+
+    llama->params.n_predict = 0;
+
+    auto result = createWriteableMap(env);
+    if (!llama->initSampling()) {
+        env->ReleaseStringUTFChars(text, text_chars); // avoid leaking on the error path
+        putString(env, result, "error", "Failed to initialize sampling");
+        return reinterpret_cast<jobject>(result);
+    }
+
+    llama->beginCompletion();
+    llama->loadPrompt();
+    llama->doCompletion();
+
+    std::vector<float> embedding = llama->getEmbedding(embdParams);
+
+    auto embeddings = createWritableArray(env);
+    for (const auto &val : embedding) {
+        pushDouble(env, embeddings, (double) val);
+    }
+    putArray(env, result, "embedding", embeddings);
+
+    auto promptTokens = createWritableArray(env);
+    for (const auto &tok : llama->embd) {
+        pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
+    }
+    putArray(env, result, "prompt_tokens", promptTokens);
+
+    env->ReleaseStringUTFChars(text, text_chars);
+    return result;
+}
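A typical consumer of the "embedding" array returned above compares two texts. A sketch of cosine similarity over the returned vectors (helper is hypothetical, not part of this API; if embd_normalize already produced unit vectors, the dot product alone suffices):

    // Sketch only: cosine similarity between two embedding vectors.
    #include <cmath>
    #include <vector>

    float cosine_similarity(const std::vector<float> &a, const std::vector<float> &b) {
        float dot = 0.f, na = 0.f, nb = 0.f;
        for (size_t i = 0; i < a.size() && i < b.size(); i++) {
            dot += a[i] * b[i];
            na  += a[i] * a[i];
            nb  += b[i] * b[i];
        }
        return (na > 0.f && nb > 0.f) ? dot / (std::sqrt(na) * std::sqrt(nb)) : 0.f;
    }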
+
+JNIEXPORT jstring JNICALL
+Java_com_rnllama_LlamaContext_bench(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr,
+    jint pp,
+    jint tg,
+    jint pl,
+    jint nr
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    std::string result = llama->bench(pp, tg, pl, nr);
+    return env->NewStringUTF(result.c_str());
+}
+
+JNIEXPORT jint JNICALL
+Java_com_rnllama_LlamaContext_applyLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    // lora_adapters: ReadableArray<ReadableMap>
+    std::vector<common_adapter_lora_info> lora_adapters;
+    int lora_adapters_size = readablearray::size(env, loraAdapters);
+    for (int i = 0; i < lora_adapters_size; i++) {
+        jobject lora_adapter = readablearray::getMap(env, loraAdapters, i);
+        jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
+        if (path != nullptr) {
+            const char *path_chars = env->GetStringUTFChars(path, nullptr);
+            float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
+            common_adapter_lora_info la;
+            la.path = path_chars; // std::string copies the bytes here
+            la.scale = scaled;
+            // release only after la.path owns its copy; releasing first would
+            // leave path_chars dangling when it is read
+            env->ReleaseStringUTFChars(path, path_chars);
+            lora_adapters.push_back(la);
+        }
+    }
+    return llama->applyLoraAdapters(lora_adapters);
+}
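The release above was moved after la.path takes its copy; the original ordering released the UTF-8 buffer before reading it. The same pattern as a reusable helper (hypothetical name, standard JNI calls only):

    // Hypothetical helper: own a copy of the UTF-8 bytes before releasing them.
    #include <jni.h>
    #include <string>

    static std::string jstring_to_string(JNIEnv *env, jstring s) {
        if (s == nullptr) return "";
        const char *chars = env->GetStringUTFChars(s, nullptr);
        if (chars == nullptr) return "";       // OOM: a pending exception is set
        std::string out(chars);
        env->ReleaseStringUTFChars(s, chars);  // safe: `out` owns its own buffer
        return out;
    }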
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_removeLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    llama->removeLoraAdapters();
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
+    auto result = createWritableArray(env);
+    for (common_adapter_lora_info &la : loaded_lora_adapters) {
+        auto map = createWriteableMap(env);
+        putString(env, map, "path", la.path.c_str());
+        putDouble(env, map, "scaled", la.scale);
+        pushMap(env, result, map);
+    }
+    return result;
+}
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_freeContext(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    context_map.erase((long) llama->ctx);
+    delete llama;
+}
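All of these bindings share one handle-map pattern: Java holds an opaque jlong, and native code maps it back to the owning object. A self-contained sketch of that pattern with hypothetical names (the real map key here is the context pointer):

    // Sketch only: opaque-handle registry behind a jlong.
    #include <unordered_map>

    struct Context { /* owns model, kv cache, ... */ };
    static std::unordered_map<long, Context *> g_contexts;

    long create_handle() {
        auto *c = new Context();
        g_contexts[(long) c] = c;  // the pointer value doubles as the Java-side jlong
        return (long) c;
    }

    void destroy_handle(long handle) {
        auto it = g_contexts.find(handle);
        if (it == g_contexts.end()) return; // stale or unknown handle: nothing to do
        delete it->second;
        g_contexts.erase(it);
    }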
+
+struct log_callback_context {
+    JavaVM *jvm;
+    jobject callback;
+};
+
+static void rnllama_log_callback_to_j(lm_ggml_log_level level, const char * text, void * data) {
+    auto level_c = "";
+    // pass "%s" as the format string so a '%' inside the log text cannot be
+    // misread as a format specifier
+    if (level == LM_GGML_LOG_LEVEL_ERROR) {
+        __android_log_print(ANDROID_LOG_ERROR, TAG, "%s", text);
+        level_c = "error";
+    } else if (level == LM_GGML_LOG_LEVEL_INFO) {
+        __android_log_print(ANDROID_LOG_INFO, TAG, "%s", text);
+        level_c = "info";
+    } else if (level == LM_GGML_LOG_LEVEL_WARN) {
+        __android_log_print(ANDROID_LOG_WARN, TAG, "%s", text);
+        level_c = "warn";
+    } else {
+        __android_log_print(ANDROID_LOG_DEFAULT, TAG, "%s", text);
+    }
+
+    log_callback_context *cb_ctx = (log_callback_context *) data;
+
+    JNIEnv *env;
+    bool need_detach = false;
+    int getEnvResult = cb_ctx->jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
+
+    if (getEnvResult == JNI_EDETACHED) {
+        if (cb_ctx->jvm->AttachCurrentThread(&env, nullptr) == JNI_OK) {
+            need_detach = true;
+        } else {
+            return;
+        }
+    } else if (getEnvResult != JNI_OK) {
+        return;
+    }
+
+    jobject callback = cb_ctx->callback;
+    jclass cb_class = env->GetObjectClass(callback);
+    jmethodID emitNativeLog = env->GetMethodID(cb_class, "emitNativeLog", "(Ljava/lang/String;Ljava/lang/String;)V");
+
+    jstring level_str = env->NewStringUTF(level_c);
+    jstring text_str = env->NewStringUTF(text);
+    env->CallVoidMethod(callback, emitNativeLog, level_str, text_str);
+    env->DeleteLocalRef(level_str);
+    env->DeleteLocalRef(text_str);
+
+    if (need_detach) {
+        cb_ctx->jvm->DetachCurrentThread();
+    }
+}
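The GetEnv/AttachCurrentThread/DetachCurrentThread dance above is needed because ggml may log from threads the JVM has never seen. A hedged RAII sketch of the same pattern (hypothetical type; callers must still check env for null when GetEnv fails outright):

    // Sketch only: scoped attach/detach for callbacks on native threads.
    #include <jni.h>

    struct ScopedJniEnv {
        JavaVM *jvm;
        JNIEnv *env = nullptr;
        bool attached = false;

        explicit ScopedJniEnv(JavaVM *vm) : jvm(vm) {
            if (jvm->GetEnv((void **) &env, JNI_VERSION_1_6) == JNI_EDETACHED &&
                jvm->AttachCurrentThread(&env, nullptr) == JNI_OK) {
                attached = true;  // this thread was attached by us, so detach later
            }
        }
        ~ScopedJniEnv() { if (attached) jvm->DetachCurrentThread(); }
    };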
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_setupLog(JNIEnv *env, jobject thiz, jobject logCallback) {
+    UNUSED(thiz);
+
+    log_callback_context *cb_ctx = new log_callback_context;
+
+    JavaVM *jvm;
+    env->GetJavaVM(&jvm);
+    cb_ctx->jvm = jvm;
+    cb_ctx->callback = env->NewGlobalRef(logCallback);
+
+    llama_log_set(rnllama_log_callback_to_j, cb_ctx);
+}
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_unsetLog(JNIEnv *env, jobject thiz) {
+    UNUSED(env);
+    UNUSED(thiz);
+    llama_log_set(rnllama_log_callback_default, NULL);
+}
+
+} // extern "C"