cui-llama.rn 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. package/LICENSE +20 -0
  2. package/README.md +330 -0
  3. package/android/build.gradle +107 -0
  4. package/android/gradle.properties +5 -0
  5. package/android/src/main/AndroidManifest.xml +4 -0
  6. package/android/src/main/CMakeLists.txt +69 -0
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +353 -0
  8. package/android/src/main/java/com/rnllama/RNLlama.java +446 -0
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -0
  10. package/android/src/main/jni.cpp +635 -0
  11. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +94 -0
  12. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +95 -0
  13. package/cpp/README.md +4 -0
  14. package/cpp/common.cpp +3237 -0
  15. package/cpp/common.h +467 -0
  16. package/cpp/ggml-aarch64.c +2193 -0
  17. package/cpp/ggml-aarch64.h +39 -0
  18. package/cpp/ggml-alloc.c +1041 -0
  19. package/cpp/ggml-alloc.h +76 -0
  20. package/cpp/ggml-backend-impl.h +153 -0
  21. package/cpp/ggml-backend.c +2225 -0
  22. package/cpp/ggml-backend.h +236 -0
  23. package/cpp/ggml-common.h +1829 -0
  24. package/cpp/ggml-impl.h +655 -0
  25. package/cpp/ggml-metal.h +65 -0
  26. package/cpp/ggml-metal.m +3273 -0
  27. package/cpp/ggml-quants.c +15022 -0
  28. package/cpp/ggml-quants.h +132 -0
  29. package/cpp/ggml.c +22034 -0
  30. package/cpp/ggml.h +2444 -0
  31. package/cpp/grammar-parser.cpp +536 -0
  32. package/cpp/grammar-parser.h +29 -0
  33. package/cpp/json-schema-to-grammar.cpp +1045 -0
  34. package/cpp/json-schema-to-grammar.h +8 -0
  35. package/cpp/json.hpp +24766 -0
  36. package/cpp/llama.cpp +21789 -0
  37. package/cpp/llama.h +1201 -0
  38. package/cpp/log.h +737 -0
  39. package/cpp/rn-llama.hpp +630 -0
  40. package/cpp/sampling.cpp +460 -0
  41. package/cpp/sampling.h +160 -0
  42. package/cpp/sgemm.cpp +1027 -0
  43. package/cpp/sgemm.h +14 -0
  44. package/cpp/unicode-data.cpp +7032 -0
  45. package/cpp/unicode-data.h +20 -0
  46. package/cpp/unicode.cpp +812 -0
  47. package/cpp/unicode.h +64 -0
  48. package/ios/RNLlama.h +11 -0
  49. package/ios/RNLlama.mm +302 -0
  50. package/ios/RNLlama.xcodeproj/project.pbxproj +278 -0
  51. package/ios/RNLlamaContext.h +39 -0
  52. package/ios/RNLlamaContext.mm +426 -0
  53. package/jest/mock.js +169 -0
  54. package/lib/commonjs/NativeRNLlama.js +10 -0
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -0
  56. package/lib/commonjs/grammar.js +574 -0
  57. package/lib/commonjs/grammar.js.map +1 -0
  58. package/lib/commonjs/index.js +151 -0
  59. package/lib/commonjs/index.js.map +1 -0
  60. package/lib/module/NativeRNLlama.js +3 -0
  61. package/lib/module/NativeRNLlama.js.map +1 -0
  62. package/lib/module/grammar.js +566 -0
  63. package/lib/module/grammar.js.map +1 -0
  64. package/lib/module/index.js +129 -0
  65. package/lib/module/index.js.map +1 -0
  66. package/lib/typescript/NativeRNLlama.d.ts +107 -0
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -0
  68. package/lib/typescript/grammar.d.ts +38 -0
  69. package/lib/typescript/grammar.d.ts.map +1 -0
  70. package/lib/typescript/index.d.ts +46 -0
  71. package/lib/typescript/index.d.ts.map +1 -0
  72. package/llama-rn.podspec +56 -0
  73. package/package.json +230 -0
  74. package/src/NativeRNLlama.ts +132 -0
  75. package/src/grammar.ts +849 -0
  76. package/src/index.ts +182 -0
package/android/src/main/jni.cpp
@@ -0,0 +1,635 @@
+#include <jni.h>
+// #include <android/asset_manager.h>
+// #include <android/asset_manager_jni.h>
+#include <android/log.h>
+#include <cstdlib>
+#include <sys/sysinfo.h>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include "llama.h"
+#include "rn-llama.hpp"
+#include "ggml.h"
+
+#define UNUSED(x) (void)(x)
+#define TAG "RNLLAMA_ANDROID_JNI"
+
+#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
+
+static inline int min(int a, int b) {
+    return (a < b) ? a : b;
+}
+
+extern "C" {
+
+// Method to create WritableMap
+static inline jobject createWriteableMap(JNIEnv *env) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
+    jmethodID init = env->GetStaticMethodID(mapClass, "createMap", "()Lcom/facebook/react/bridge/WritableMap;");
+    jobject map = env->CallStaticObjectMethod(mapClass, init);
+    return map;
+}
+
+// Method to put string into WritableMap
+static inline void putString(JNIEnv *env, jobject map, const char *key, const char *value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
+    jmethodID putStringMethod = env->GetMethodID(mapClass, "putString", "(Ljava/lang/String;Ljava/lang/String;)V");
+
+    jstring jKey = env->NewStringUTF(key);
+    jstring jValue = env->NewStringUTF(value);
+
+    env->CallVoidMethod(map, putStringMethod, jKey, jValue);
+}
+
+// Method to put int into WritableMap
+static inline void putInt(JNIEnv *env, jobject map, const char *key, int value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
+    jmethodID putIntMethod = env->GetMethodID(mapClass, "putInt", "(Ljava/lang/String;I)V");
+
+    jstring jKey = env->NewStringUTF(key);
+
+    env->CallVoidMethod(map, putIntMethod, jKey, value);
+}
+
+// Method to put double into WritableMap
+static inline void putDouble(JNIEnv *env, jobject map, const char *key, double value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
+    jmethodID putDoubleMethod = env->GetMethodID(mapClass, "putDouble", "(Ljava/lang/String;D)V");
+
+    jstring jKey = env->NewStringUTF(key);
+
+    env->CallVoidMethod(map, putDoubleMethod, jKey, value);
+}
+
+// Method to put WriteableMap into WritableMap
+static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
+    jmethodID putMapMethod = env->GetMethodID(mapClass, "putMap", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableMap;)V");
+
+    jstring jKey = env->NewStringUTF(key);
+
+    env->CallVoidMethod(map, putMapMethod, jKey, value);
+}
+
+// Method to create WritableArray
+static inline jobject createWritableArray(JNIEnv *env) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
+    jmethodID init = env->GetStaticMethodID(mapClass, "createArray", "()Lcom/facebook/react/bridge/WritableArray;");
+    jobject map = env->CallStaticObjectMethod(mapClass, init);
+    return map;
+}
+
+// Method to push int into WritableArray
+static inline void pushInt(JNIEnv *env, jobject arr, int value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
+    jmethodID pushIntMethod = env->GetMethodID(mapClass, "pushInt", "(I)V");
+
+    env->CallVoidMethod(arr, pushIntMethod, value);
+}
+
+// Method to push double into WritableArray
+static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
+    jmethodID pushDoubleMethod = env->GetMethodID(mapClass, "pushDouble", "(D)V");
+
+    env->CallVoidMethod(arr, pushDoubleMethod, value);
+}
+
+// Method to push WritableMap into WritableArray
+static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
+    jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/WritableMap;)V");
+
+    env->CallVoidMethod(arr, pushMapMethod, value);
+}
+
+// Method to put WritableArray into WritableMap
+static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
+    jmethodID putArrayMethod = env->GetMethodID(mapClass, "putArray", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableArray;)V");
+
+    jstring jKey = env->NewStringUTF(key);
+
+    env->CallVoidMethod(map, putArrayMethod, jKey, value);
+}
+
+
+std::unordered_map<long, rnllama::llama_rn_context *> context_map;
+
+JNIEXPORT jlong JNICALL
+Java_com_rnllama_LlamaContext_initContext(
+    JNIEnv *env,
+    jobject thiz,
+    jstring model_path_str,
+    jboolean embedding,
+    jint n_ctx,
+    jint n_batch,
+    jint n_threads,
+    jint n_gpu_layers, // TODO: Support this
+    jboolean use_mlock,
+    jboolean use_mmap,
+    jstring lora_str,
+    jfloat lora_scaled,
+    jstring lora_base_str,
+    jfloat rope_freq_base,
+    jfloat rope_freq_scale
+) {
+    UNUSED(thiz);
+
+    gpt_params defaultParams;
+
+    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
+    defaultParams.model = model_path_chars;
+
+    defaultParams.embedding = embedding;
+
+    defaultParams.n_ctx = n_ctx;
+    defaultParams.n_batch = n_batch;
+
+    int max_threads = std::thread::hardware_concurrency();
+    // Use 2 threads by default on 4-core devices, 4 threads on more cores
+    int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
+    defaultParams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
+
+    defaultParams.n_gpu_layers = n_gpu_layers;
+
+    defaultParams.use_mlock = use_mlock;
+    defaultParams.use_mmap = use_mmap;
+
+    const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
+    const char *lora_base_chars = env->GetStringUTFChars(lora_base_str, nullptr);
+    if (lora_chars != nullptr && lora_chars[0] != '\0') {
+        defaultParams.lora_adapter.push_back({lora_chars, lora_scaled});
+        defaultParams.lora_base = lora_base_chars;
+        defaultParams.use_mmap = false;
+    }
+
+    defaultParams.rope_freq_base = rope_freq_base;
+    defaultParams.rope_freq_scale = rope_freq_scale;
+
+    auto llama = new rnllama::llama_rn_context();
+    bool is_model_loaded = llama->loadModel(defaultParams);
+
+    LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
+    if (is_model_loaded) {
+        context_map[(long) llama->ctx] = llama;
+    } else {
+        llama_free(llama->ctx);
+    }
+
+    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+    env->ReleaseStringUTFChars(lora_str, lora_chars);
+    env->ReleaseStringUTFChars(lora_base_str, lora_base_chars);
+
+    return reinterpret_cast<jlong>(llama->ctx);
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_loadModelDetails(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    int count = llama_model_meta_count(llama->model);
+    auto meta = createWriteableMap(env);
+    for (int i = 0; i < count; i++) {
+        char key[256];
+        llama_model_meta_key_by_index(llama->model, i, key, sizeof(key));
+        char val[256];
+        llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val));
+
+        putString(env, meta, key, val);
+    }
+
+    auto result = createWriteableMap(env);
+
+    char desc[1024];
+    llama_model_desc(llama->model, desc, sizeof(desc));
+    putString(env, result, "desc", desc);
+    putDouble(env, result, "size", llama_model_size(llama->model));
+    putDouble(env, result, "nParams", llama_model_n_params(llama->model));
+    putMap(env, result, "metadata", meta);
+
+    return reinterpret_cast<jobject>(result);
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_loadSession(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr,
+    jstring path
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    const char *path_chars = env->GetStringUTFChars(path, nullptr);
+
+    auto result = createWriteableMap(env);
+    size_t n_token_count_out = 0;
+    llama->embd.resize(llama->params.n_ctx);
+    if (!llama_state_load_file(llama->ctx, path_chars, llama->embd.data(), llama->embd.capacity(), &n_token_count_out)) {
+        env->ReleaseStringUTFChars(path, path_chars);
+
+        putString(env, result, "error", "Failed to load session");
+        return reinterpret_cast<jobject>(result);
+    }
+    llama->embd.resize(n_token_count_out);
+    env->ReleaseStringUTFChars(path, path_chars);
+
+    const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
+    putInt(env, result, "tokens_loaded", n_token_count_out);
+    putString(env, result, "prompt", text.c_str());
+    return reinterpret_cast<jobject>(result);
+}
+
+JNIEXPORT jint JNICALL
+Java_com_rnllama_LlamaContext_saveSession(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr,
+    jstring path,
+    jint size
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    const char *path_chars = env->GetStringUTFChars(path, nullptr);
+
+    std::vector<llama_token> session_tokens = llama->embd;
+    int default_size = session_tokens.size();
+    int save_size = size > 0 && size <= default_size ? size : default_size;
+    if (!llama_state_save_file(llama->ctx, path_chars, session_tokens.data(), save_size)) {
+        env->ReleaseStringUTFChars(path, path_chars);
+        return -1;
+    }
+
+    env->ReleaseStringUTFChars(path, path_chars);
+    return session_tokens.size();
+}
+
+static inline jobject tokenProbsToMap(
+    JNIEnv *env,
+    rnllama::llama_rn_context *llama,
+    std::vector<rnllama::completion_token_output> probs
+) {
+    auto result = createWritableArray(env);
+    for (const auto &prob : probs) {
+        auto probsForToken = createWritableArray(env);
+        for (const auto &p : prob.probs) {
+            std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
+            auto probResult = createWriteableMap(env);
+            putString(env, probResult, "tok_str", tokStr.c_str());
+            putDouble(env, probResult, "prob", p.prob);
+            pushMap(env, probsForToken, probResult);
+        }
+        std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
+        auto tokenResult = createWriteableMap(env);
+        putString(env, tokenResult, "content", tokStr.c_str());
+        putArray(env, tokenResult, "probs", probsForToken);
+        pushMap(env, result, tokenResult);
+    }
+    return result;
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_doCompletion(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr,
+    jstring prompt,
+    jstring grammar,
+    jfloat temperature,
+    jint n_threads,
+    jint n_predict,
+    jint n_probs,
+    jint penalty_last_n,
+    jfloat penalty_repeat,
+    jfloat penalty_freq,
+    jfloat penalty_present,
+    jfloat mirostat,
+    jfloat mirostat_tau,
+    jfloat mirostat_eta,
+    jboolean penalize_nl,
+    jint top_k,
+    jfloat top_p,
+    jfloat min_p,
+    jfloat tfs_z,
+    jfloat typical_p,
+    jint seed,
+    jobjectArray stop,
+    jboolean ignore_eos,
+    jobjectArray logit_bias,
+    jobject partial_completion_callback
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    llama->rewind();
+
+    llama_reset_timings(llama->ctx);
+
+    llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
+    llama->params.seed = seed;
+
+    int max_threads = std::thread::hardware_concurrency();
+    // Use 2 threads by default on 4-core devices, 4 threads on more cores
+    int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
+    llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
+
+    llama->params.n_predict = n_predict;
+    llama->params.ignore_eos = ignore_eos;
+
+    auto & sparams = llama->params.sparams;
+    sparams.temp = temperature;
+    sparams.penalty_last_n = penalty_last_n;
+    sparams.penalty_repeat = penalty_repeat;
+    sparams.penalty_freq = penalty_freq;
+    sparams.penalty_present = penalty_present;
+    sparams.mirostat = mirostat;
+    sparams.mirostat_tau = mirostat_tau;
+    sparams.mirostat_eta = mirostat_eta;
+    sparams.penalize_nl = penalize_nl;
+    sparams.top_k = top_k;
+    sparams.top_p = top_p;
+    sparams.min_p = min_p;
+    sparams.tfs_z = tfs_z;
+    sparams.typical_p = typical_p;
+    sparams.n_probs = n_probs;
+    sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
+
+    sparams.logit_bias.clear();
+    if (ignore_eos) {
+        sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
+    }
+
+    const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
+    jsize logit_bias_len = env->GetArrayLength(logit_bias);
+
+    for (jsize i = 0; i < logit_bias_len; i++) {
+        jdoubleArray el = (jdoubleArray) env->GetObjectArrayElement(logit_bias, i);
+        if (el && env->GetArrayLength(el) == 2) {
+            jdouble* doubleArray = env->GetDoubleArrayElements(el, 0);
+
+            llama_token tok = static_cast<llama_token>(doubleArray[0]);
+            if (tok >= 0 && tok < n_vocab) {
+                if (doubleArray[1] != 0) { // If the second element is not false (0)
+                    sparams.logit_bias[tok] = doubleArray[1];
+                } else {
+                    sparams.logit_bias[tok] = -INFINITY;
+                }
+            }
+
+            env->ReleaseDoubleArrayElements(el, doubleArray, 0);
+        }
+        env->DeleteLocalRef(el);
+    }
+
+    llama->params.antiprompt.clear();
+    int stop_len = env->GetArrayLength(stop);
+    for (int i = 0; i < stop_len; i++) {
+        jstring stop_str = (jstring) env->GetObjectArrayElement(stop, i);
+        const char *stop_chars = env->GetStringUTFChars(stop_str, nullptr);
+        llama->params.antiprompt.push_back(stop_chars);
+        env->ReleaseStringUTFChars(stop_str, stop_chars);
+    }
+
+    if (!llama->initSampling()) {
+        auto result = createWriteableMap(env);
+        putString(env, result, "error", "Failed to initialize sampling");
+        return reinterpret_cast<jobject>(result);
+    }
+    llama->beginCompletion();
+    llama->loadPrompt();
+
+    size_t sent_count = 0;
+    size_t sent_token_probs_index = 0;
+
+    while (llama->has_next_token && !llama->is_interrupted) {
+        const rnllama::completion_token_output token_with_probs = llama->doCompletion();
+        if (token_with_probs.tok == -1 || llama->multibyte_pending > 0) {
+            continue;
+        }
+        const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
+
+        size_t pos = std::min(sent_count, llama->generated_text.size());
+
+        const std::string str_test = llama->generated_text.substr(pos);
+        bool is_stop_full = false;
+        size_t stop_pos =
+            llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
+        if (stop_pos != std::string::npos) {
+            is_stop_full = true;
+            llama->generated_text.erase(
+                llama->generated_text.begin() + pos + stop_pos,
+                llama->generated_text.end());
+            pos = std::min(sent_count, llama->generated_text.size());
+        } else {
+            is_stop_full = false;
+            stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
+                rnllama::STOP_PARTIAL);
+        }
+
+        if (
+            stop_pos == std::string::npos ||
+            // Send rest of the text if we are at the end of the generation
+            (!llama->has_next_token && !is_stop_full && stop_pos > 0)
+        ) {
+            const std::string to_send = llama->generated_text.substr(pos, std::string::npos);
+
+            sent_count += to_send.size();
+
+            std::vector<rnllama::completion_token_output> probs_output = {};
+
+            auto tokenResult = createWriteableMap(env);
+            putString(env, tokenResult, "token", to_send.c_str());
+
+            if (llama->params.sparams.n_probs > 0) {
+                const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+                size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
+                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
+                if (probs_pos < probs_stop_pos) {
+                    probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
+                }
+                sent_token_probs_index = probs_stop_pos;
+
+                putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
+            }
+
+            jclass cb_class = env->GetObjectClass(partial_completion_callback);
+            jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
+            env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
+        }
+    }
+
+    llama_print_timings(llama->ctx);
+    llama->is_predicting = false;
+
+    auto result = createWriteableMap(env);
+    putString(env, result, "text", llama->generated_text.c_str());
+    putArray(env, result, "completion_probabilities", tokenProbsToMap(env, llama, llama->generated_token_probs));
+    putInt(env, result, "tokens_predicted", llama->num_tokens_predicted);
+    putInt(env, result, "tokens_evaluated", llama->num_prompt_tokens);
+    putInt(env, result, "truncated", llama->truncated);
+    putInt(env, result, "stopped_eos", llama->stopped_eos);
+    putInt(env, result, "stopped_word", llama->stopped_word);
+    putInt(env, result, "stopped_limit", llama->stopped_limit);
+    putString(env, result, "stopping_word", llama->stopping_word.c_str());
+    putInt(env, result, "tokens_cached", llama->n_past);
+
+    const auto timings = llama_get_timings(llama->ctx);
+    auto timingsResult = createWriteableMap(env);
+    putInt(env, timingsResult, "prompt_n", timings.n_p_eval);
+    putInt(env, timingsResult, "prompt_ms", timings.t_p_eval_ms);
+    putInt(env, timingsResult, "prompt_per_token_ms", timings.t_p_eval_ms / timings.n_p_eval);
+    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
+    putInt(env, timingsResult, "predicted_n", timings.n_eval);
+    putInt(env, timingsResult, "predicted_ms", timings.t_eval_ms);
+    putInt(env, timingsResult, "predicted_per_token_ms", timings.t_eval_ms / timings.n_eval);
+    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings.t_eval_ms * timings.n_eval);
+
+    putMap(env, result, "timings", timingsResult);
+
+    return reinterpret_cast<jobject>(result);
+}
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_stopCompletion(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    llama->is_interrupted = true;
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_rnllama_LlamaContext_isPredicting(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    return llama->is_predicting;
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_tokenize(
+    JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    const char *text_chars = env->GetStringUTFChars(text, nullptr);
+
+    const std::vector<llama_token> toks = llama_tokenize(
+        llama->ctx,
+        text_chars,
+        false
+    );
+
+    jobject result = createWritableArray(env);
+    for (const auto &tok : toks) {
+        pushInt(env, result, tok);
+    }
+
+    env->ReleaseStringUTFChars(text, text_chars);
+    return result;
+}
+
+JNIEXPORT jstring JNICALL
+Java_com_rnllama_LlamaContext_detokenize(
+    JNIEnv *env, jobject thiz, jlong context_ptr, jintArray tokens) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    jsize tokens_len = env->GetArrayLength(tokens);
+    jint *tokens_ptr = env->GetIntArrayElements(tokens, 0);
+    std::vector<llama_token> toks;
+    for (int i = 0; i < tokens_len; i++) {
+        toks.push_back(tokens_ptr[i]);
+    }
+
+    auto text = rnllama::tokens_to_str(llama->ctx, toks.cbegin(), toks.cend());
+
+    env->ReleaseIntArrayElements(tokens, tokens_ptr, 0);
+
+    return env->NewStringUTF(text.c_str());
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    return llama->params.embedding;
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_embedding(
+    JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    const char *text_chars = env->GetStringUTFChars(text, nullptr);
+
+    llama->rewind();
+
+    llama_reset_timings(llama->ctx);
+
+    llama->params.prompt = text_chars;
+
+    llama->params.n_predict = 0;
+    llama->beginCompletion();
+    llama->loadPrompt();
+    llama->doCompletion();
+
+    std::vector<float> embedding = llama->getEmbedding();
+
+    jobject result = createWritableArray(env);
+
+    for (const auto &val : embedding) {
+        pushDouble(env, result, (double) val);
+    }
+
+    env->ReleaseStringUTFChars(text, text_chars);
+    return result;
+}
+
+JNIEXPORT jstring JNICALL
+Java_com_rnllama_LlamaContext_bench(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr,
+    jint pp,
+    jint tg,
+    jint pl,
+    jint nr
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    std::string result = llama->bench(pp, tg, pl, nr);
+    return env->NewStringUTF(result.c_str());
+}
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_freeContext(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    if (llama->model) {
+        llama_free_model(llama->model);
+    }
+    if (llama->ctx) {
+        llama_free(llama->ctx);
+    }
+    if (llama->ctx_sampling != nullptr)
+    {
+        llama_sampling_free(llama->ctx_sampling);
+    }
+    context_map.erase((long) llama->ctx);
+}
+
+} // extern "C"
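
Each exported Java_com_rnllama_LlamaContext_* symbol above binds, by the standard JNI naming convention, to a native method on com.rnllama.LlamaContext (shipped in this package as android/src/main/java/com/rnllama/LlamaContext.java, item 7 in the file list). The sketch below is a reconstruction of the Java-side declarations implied by the C signatures, for reference only: the type mapping (jstring -> String, jboolean -> boolean, jint -> int, jfloat -> float, jlong -> long) is fixed by JNI, but the access modifiers, parameter names, and the callback interface name are assumptions, not a copy of the shipped class. Note that doCompletion decodes logit_bias as an array of two-element [token_id, bias] double arrays, mapping a bias of 0 to negative infinity.

    package com.rnllama;

    import com.facebook.react.bridge.WritableArray;
    import com.facebook.react.bridge.WritableMap;

    // Hypothetical reconstruction of the bindings implied by jni.cpp's exports.
    public class LlamaContext {
        // jni.cpp looks up "onPartialCompletion(WritableMap)" on the callback
        // object it receives; this interface name is an assumption.
        public interface PartialCompletionCallback {
            void onPartialCompletion(WritableMap tokenResult);
        }

        // Declared as instance methods to match the `jobject thiz` parameter
        // in jni.cpp; the shipped class may use different modifiers.
        protected native long initContext(
            String modelPath, boolean embedding, int nCtx, int nBatch, int nThreads,
            int nGpuLayers, boolean useMlock, boolean useMmap,
            String lora, float loraScaled, String loraBase,
            float ropeFreqBase, float ropeFreqScale);
        protected native WritableMap loadModelDetails(long contextPtr);
        protected native WritableMap loadSession(long contextPtr, String path);
        protected native int saveSession(long contextPtr, String path, int size);
        protected native WritableMap doCompletion(
            long contextPtr, String prompt, String grammar, float temperature,
            int nThreads, int nPredict, int nProbs, int penaltyLastN,
            float penaltyRepeat, float penaltyFreq, float penaltyPresent,
            float mirostat, float mirostatTau, float mirostatEta, boolean penalizeNl,
            int topK, float topP, float minP, float tfsZ, float typicalP, int seed,
            String[] stop, boolean ignoreEos, double[][] logitBias,
            PartialCompletionCallback partialCompletionCallback);
        protected native void stopCompletion(long contextPtr);
        protected native boolean isPredicting(long contextPtr);
        protected native WritableArray tokenize(long contextPtr, String text);
        protected native String detokenize(long contextPtr, int[] tokens);
        protected native boolean isEmbeddingEnabled(long contextPtr);
        protected native WritableArray embedding(long contextPtr, String text);
        protected native String bench(long contextPtr, int pp, int tg, int pl, int nr);
        protected native void freeContext(long contextPtr);
    }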