cui-llama.rn 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +20 -0
- package/README.md +330 -0
- package/android/build.gradle +107 -0
- package/android/gradle.properties +5 -0
- package/android/src/main/AndroidManifest.xml +4 -0
- package/android/src/main/CMakeLists.txt +69 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +353 -0
- package/android/src/main/java/com/rnllama/RNLlama.java +446 -0
- package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -0
- package/android/src/main/jni.cpp +635 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +94 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +95 -0
- package/cpp/README.md +4 -0
- package/cpp/common.cpp +3237 -0
- package/cpp/common.h +467 -0
- package/cpp/ggml-aarch64.c +2193 -0
- package/cpp/ggml-aarch64.h +39 -0
- package/cpp/ggml-alloc.c +1041 -0
- package/cpp/ggml-alloc.h +76 -0
- package/cpp/ggml-backend-impl.h +153 -0
- package/cpp/ggml-backend.c +2225 -0
- package/cpp/ggml-backend.h +236 -0
- package/cpp/ggml-common.h +1829 -0
- package/cpp/ggml-impl.h +655 -0
- package/cpp/ggml-metal.h +65 -0
- package/cpp/ggml-metal.m +3273 -0
- package/cpp/ggml-quants.c +15022 -0
- package/cpp/ggml-quants.h +132 -0
- package/cpp/ggml.c +22034 -0
- package/cpp/ggml.h +2444 -0
- package/cpp/grammar-parser.cpp +536 -0
- package/cpp/grammar-parser.h +29 -0
- package/cpp/json-schema-to-grammar.cpp +1045 -0
- package/cpp/json-schema-to-grammar.h +8 -0
- package/cpp/json.hpp +24766 -0
- package/cpp/llama.cpp +21789 -0
- package/cpp/llama.h +1201 -0
- package/cpp/log.h +737 -0
- package/cpp/rn-llama.hpp +630 -0
- package/cpp/sampling.cpp +460 -0
- package/cpp/sampling.h +160 -0
- package/cpp/sgemm.cpp +1027 -0
- package/cpp/sgemm.h +14 -0
- package/cpp/unicode-data.cpp +7032 -0
- package/cpp/unicode-data.h +20 -0
- package/cpp/unicode.cpp +812 -0
- package/cpp/unicode.h +64 -0
- package/ios/RNLlama.h +11 -0
- package/ios/RNLlama.mm +302 -0
- package/ios/RNLlama.xcodeproj/project.pbxproj +278 -0
- package/ios/RNLlamaContext.h +39 -0
- package/ios/RNLlamaContext.mm +426 -0
- package/jest/mock.js +169 -0
- package/lib/commonjs/NativeRNLlama.js +10 -0
- package/lib/commonjs/NativeRNLlama.js.map +1 -0
- package/lib/commonjs/grammar.js +574 -0
- package/lib/commonjs/grammar.js.map +1 -0
- package/lib/commonjs/index.js +151 -0
- package/lib/commonjs/index.js.map +1 -0
- package/lib/module/NativeRNLlama.js +3 -0
- package/lib/module/NativeRNLlama.js.map +1 -0
- package/lib/module/grammar.js +566 -0
- package/lib/module/grammar.js.map +1 -0
- package/lib/module/index.js +129 -0
- package/lib/module/index.js.map +1 -0
- package/lib/typescript/NativeRNLlama.d.ts +107 -0
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -0
- package/lib/typescript/grammar.d.ts +38 -0
- package/lib/typescript/grammar.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +46 -0
- package/lib/typescript/index.d.ts.map +1 -0
- package/llama-rn.podspec +56 -0
- package/package.json +230 -0
- package/src/NativeRNLlama.ts +132 -0
- package/src/grammar.ts +849 -0
- package/src/index.ts +182 -0
@@ -0,0 +1,635 @@
|
|
1
|
+
#include <jni.h>
|
2
|
+
// #include <android/asset_manager.h>
|
3
|
+
// #include <android/asset_manager_jni.h>
|
4
|
+
#include <android/log.h>
|
5
|
+
#include <cstdlib>
|
6
|
+
#include <sys/sysinfo.h>
|
7
|
+
#include <string>
|
8
|
+
#include <thread>
|
9
|
+
#include <unordered_map>
|
10
|
+
#include "llama.h"
|
11
|
+
#include "rn-llama.hpp"
|
12
|
+
#include "ggml.h"
|
13
|
+
|
14
|
+
#define UNUSED(x) (void)(x)
|
15
|
+
#define TAG "RNLLAMA_ANDROID_JNI"
|
16
|
+
|
17
|
+
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
18
|
+
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
19
|
+
|
20
|
+
// Returns the smaller of two ints (local helper; avoids pulling in <algorithm>).
static inline int min(int a, int b) {
    if (a < b) {
        return a;
    }
    return b;
}
|
23
|
+
|
24
|
+
extern "C" {
|
25
|
+
|
26
|
+
// Method to create WritableMap
|
27
|
+
static inline jobject createWriteableMap(JNIEnv *env) {
|
28
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
|
29
|
+
jmethodID init = env->GetStaticMethodID(mapClass, "createMap", "()Lcom/facebook/react/bridge/WritableMap;");
|
30
|
+
jobject map = env->CallStaticObjectMethod(mapClass, init);
|
31
|
+
return map;
|
32
|
+
}
|
33
|
+
|
34
|
+
// Method to put string into WritableMap
|
35
|
+
static inline void putString(JNIEnv *env, jobject map, const char *key, const char *value) {
|
36
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
37
|
+
jmethodID putStringMethod = env->GetMethodID(mapClass, "putString", "(Ljava/lang/String;Ljava/lang/String;)V");
|
38
|
+
|
39
|
+
jstring jKey = env->NewStringUTF(key);
|
40
|
+
jstring jValue = env->NewStringUTF(value);
|
41
|
+
|
42
|
+
env->CallVoidMethod(map, putStringMethod, jKey, jValue);
|
43
|
+
}
|
44
|
+
|
45
|
+
// Method to put int into WritableMap
|
46
|
+
static inline void putInt(JNIEnv *env, jobject map, const char *key, int value) {
|
47
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
48
|
+
jmethodID putIntMethod = env->GetMethodID(mapClass, "putInt", "(Ljava/lang/String;I)V");
|
49
|
+
|
50
|
+
jstring jKey = env->NewStringUTF(key);
|
51
|
+
|
52
|
+
env->CallVoidMethod(map, putIntMethod, jKey, value);
|
53
|
+
}
|
54
|
+
|
55
|
+
// Method to put double into WritableMap
|
56
|
+
static inline void putDouble(JNIEnv *env, jobject map, const char *key, double value) {
|
57
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
58
|
+
jmethodID putDoubleMethod = env->GetMethodID(mapClass, "putDouble", "(Ljava/lang/String;D)V");
|
59
|
+
|
60
|
+
jstring jKey = env->NewStringUTF(key);
|
61
|
+
|
62
|
+
env->CallVoidMethod(map, putDoubleMethod, jKey, value);
|
63
|
+
}
|
64
|
+
|
65
|
+
// Method to put WriteableMap into WritableMap
|
66
|
+
static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
|
67
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
68
|
+
jmethodID putMapMethod = env->GetMethodID(mapClass, "putMap", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableMap;)V");
|
69
|
+
|
70
|
+
jstring jKey = env->NewStringUTF(key);
|
71
|
+
|
72
|
+
env->CallVoidMethod(map, putMapMethod, jKey, value);
|
73
|
+
}
|
74
|
+
|
75
|
+
// Method to create WritableArray
|
76
|
+
static inline jobject createWritableArray(JNIEnv *env) {
|
77
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/Arguments");
|
78
|
+
jmethodID init = env->GetStaticMethodID(mapClass, "createArray", "()Lcom/facebook/react/bridge/WritableArray;");
|
79
|
+
jobject map = env->CallStaticObjectMethod(mapClass, init);
|
80
|
+
return map;
|
81
|
+
}
|
82
|
+
|
83
|
+
// Method to push int into WritableArray
|
84
|
+
static inline void pushInt(JNIEnv *env, jobject arr, int value) {
|
85
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
86
|
+
jmethodID pushIntMethod = env->GetMethodID(mapClass, "pushInt", "(I)V");
|
87
|
+
|
88
|
+
env->CallVoidMethod(arr, pushIntMethod, value);
|
89
|
+
}
|
90
|
+
|
91
|
+
// Method to push double into WritableArray
|
92
|
+
static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
|
93
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
94
|
+
jmethodID pushDoubleMethod = env->GetMethodID(mapClass, "pushDouble", "(D)V");
|
95
|
+
|
96
|
+
env->CallVoidMethod(arr, pushDoubleMethod, value);
|
97
|
+
}
|
98
|
+
|
99
|
+
// Method to push WritableMap into WritableArray
|
100
|
+
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
|
101
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
102
|
+
jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/WritableMap;)V");
|
103
|
+
|
104
|
+
env->CallVoidMethod(arr, pushMapMethod, value);
|
105
|
+
}
|
106
|
+
|
107
|
+
// Method to put WritableArray into WritableMap
|
108
|
+
static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
|
109
|
+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
|
110
|
+
jmethodID putArrayMethod = env->GetMethodID(mapClass, "putArray", "(Ljava/lang/String;Lcom/facebook/react/bridge/ReadableArray;)V");
|
111
|
+
|
112
|
+
jstring jKey = env->NewStringUTF(key);
|
113
|
+
|
114
|
+
env->CallVoidMethod(map, putArrayMethod, jKey, value);
|
115
|
+
}
|
116
|
+
|
117
|
+
|
118
|
+
// Registry of live contexts, keyed by the llama_context pointer value that is
// handed to Java as the jlong handle (see initContext/freeContext).
// NOTE(review): access is not synchronized; this assumes all JNI entry points
// run on a single thread - confirm against the Java-side threading model.
std::unordered_map<long, rnllama::llama_rn_context *> context_map;
|
119
|
+
|
120
|
+
// Creates and loads a llama model/context from the given file and registers
// the wrapper in context_map. Returns the llama_context pointer as the jlong
// handle used by all other JNI entry points.
JNIEXPORT jlong JNICALL
Java_com_rnllama_LlamaContext_initContext(
    JNIEnv *env,
    jobject thiz,
    jstring model_path_str,
    jboolean embedding,
    jint n_ctx,
    jint n_batch,
    jint n_threads,
    jint n_gpu_layers, // TODO: Support this
    jboolean use_mlock,
    jboolean use_mmap,
    jstring lora_str,
    jfloat lora_scaled,
    jstring lora_base_str,
    jfloat rope_freq_base,
    jfloat rope_freq_scale
) {
    UNUSED(thiz);

    gpt_params defaultParams;

    // The JNI buffer is copied into the std::string fields below, so releasing
    // the UTF chars at the end of this function is safe.
    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
    defaultParams.model = model_path_chars;

    defaultParams.embedding = embedding;

    defaultParams.n_ctx = n_ctx;
    defaultParams.n_batch = n_batch;

    int max_threads = std::thread::hardware_concurrency();
    // Use 2 threads by default on 4-core devices, 4 threads on more cores
    int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
    defaultParams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

    defaultParams.n_gpu_layers = n_gpu_layers;

    defaultParams.use_mlock = use_mlock;
    defaultParams.use_mmap = use_mmap;

    const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
    const char *lora_base_chars = env->GetStringUTFChars(lora_base_str, nullptr);
    if (lora_chars != nullptr && lora_chars[0] != '\0') {
        defaultParams.lora_adapter.push_back({lora_chars, lora_scaled});
        defaultParams.lora_base = lora_base_chars;
        // mmap is disabled whenever a LoRA adapter is applied.
        defaultParams.use_mmap = false;
    }

    defaultParams.rope_freq_base = rope_freq_base;
    defaultParams.rope_freq_scale = rope_freq_scale;

    auto llama = new rnllama::llama_rn_context();
    bool is_model_loaded = llama->loadModel(defaultParams);

    LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
    if (is_model_loaded) {
        context_map[(long) llama->ctx] = llama;
    } else {
        // NOTE(review): on failure the ctx is freed here, yet the function
        // still returns llama->ctx below, handing Java a dangling handle; the
        // `llama` wrapper object is also never deleted (leak). Confirm how the
        // Java side detects a failed load.
        llama_free(llama->ctx);
    }

    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
    env->ReleaseStringUTFChars(lora_str, lora_chars);
    env->ReleaseStringUTFChars(lora_base_str, lora_base_chars);

    return reinterpret_cast<jlong>(llama->ctx);
}
|
187
|
+
|
188
|
+
// Collects model information - description string, file size, parameter count
// and all GGUF metadata key/value pairs - into a WritableMap for the Java layer.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadModelDetails(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    // Gather every metadata entry into its own map.
    auto meta = createWriteableMap(env);
    const int n_meta = llama_model_meta_count(llama->model);
    for (int i = 0; i < n_meta; i++) {
        char key[256];
        char val[256];
        llama_model_meta_key_by_index(llama->model, i, key, sizeof(key));
        llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val));
        putString(env, meta, key, val);
    }

    char desc[1024];
    llama_model_desc(llama->model, desc, sizeof(desc));

    auto details = createWriteableMap(env);
    putString(env, details, "desc", desc);
    putDouble(env, details, "size", llama_model_size(llama->model));
    putDouble(env, details, "nParams", llama_model_n_params(llama->model));
    putMap(env, details, "metadata", meta);

    return reinterpret_cast<jobject>(details);
}
|
219
|
+
|
220
|
+
// Restores a saved token/session state from `path` into the context's embd
// buffer. Returns a WritableMap with { tokens_loaded, prompt } on success or
// { error } on failure.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadSession(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring path
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];
    const char *path_chars = env->GetStringUTFChars(path, nullptr);

    auto result = createWriteableMap(env);
    size_t n_token_count_out = 0;
    // Make room for up to a full context worth of tokens.
    llama->embd.resize(llama->params.n_ctx);
    // NOTE(review): capacity() may exceed the n_ctx we just resized to;
    // size() would be the tighter bound on how many tokens the file may
    // deliver - confirm intent.
    if (!llama_state_load_file(llama->ctx, path_chars, llama->embd.data(), llama->embd.capacity(), &n_token_count_out)) {
        env->ReleaseStringUTFChars(path, path_chars);

        putString(env, result, "error", "Failed to load session");
        return reinterpret_cast<jobject>(result);
    }
    // Shrink to the number of tokens actually restored.
    llama->embd.resize(n_token_count_out);
    env->ReleaseStringUTFChars(path, path_chars);

    // Re-render the restored tokens as text so JS can display the prompt.
    const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
    putInt(env, result, "tokens_loaded", n_token_count_out);
    putString(env, result, "prompt", text.c_str());
    return reinterpret_cast<jobject>(result);
}
|
248
|
+
|
249
|
+
// Saves up to `size` tokens of the current session state to `path`.
// Returns the number of tokens actually written, or -1 on failure.
JNIEXPORT jint JNICALL
Java_com_rnllama_LlamaContext_saveSession(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring path,
    jint size
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    const char *path_chars = env->GetStringUTFChars(path, nullptr);

    std::vector<llama_token> session_tokens = llama->embd;
    int default_size = session_tokens.size();
    // A non-positive or out-of-range request saves the whole session.
    int save_size = size > 0 && size <= default_size ? size : default_size;
    if (!llama_state_save_file(llama->ctx, path_chars, session_tokens.data(), save_size)) {
        env->ReleaseStringUTFChars(path, path_chars);
        return -1;
    }

    env->ReleaseStringUTFChars(path, path_chars);
    // Fix: report how many tokens were written (save_size), not the full
    // session length - the two differ when the caller requested a prefix.
    return save_size;
}
|
273
|
+
|
274
|
+
// Converts completion token probability records into a WritableArray shaped
// [{ content, probs: [{ tok_str, prob }, ...] }, ...] for the JS bridge.
static inline jobject tokenProbsToMap(
    JNIEnv *env,
    rnllama::llama_rn_context *llama,
    std::vector<rnllama::completion_token_output> probs
) {
    auto result = createWritableArray(env);
    for (const auto &prob : probs) {
        auto probsForToken = createWritableArray(env);
        for (const auto &p : prob.probs) {
            std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
            auto probResult = createWriteableMap(env);
            putString(env, probResult, "tok_str", tokStr.c_str());
            putDouble(env, probResult, "prob", p.prob);
            pushMap(env, probsForToken, probResult);
            // pushMap stores the value on the Java side; drop our local ref so
            // long generations do not exhaust the JNI local reference table.
            env->DeleteLocalRef(probResult);
        }
        std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
        auto tokenResult = createWriteableMap(env);
        putString(env, tokenResult, "content", tokStr.c_str());
        putArray(env, tokenResult, "probs", probsForToken);
        pushMap(env, result, tokenResult);
        // Same: release per-token local refs once pushed into the result.
        env->DeleteLocalRef(probsForToken);
        env->DeleteLocalRef(tokenResult);
    }
    return result;
}
|
297
|
+
|
298
|
+
// Runs a full completion with the given sampling parameters, streaming partial
// text chunks (and optionally token probabilities) to
// `partial_completion_callback`, and returns a final WritableMap with the
// generated text, stop/truncation flags and timing statistics.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_doCompletion(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jstring prompt,
    jstring grammar,
    jfloat temperature,
    jint n_threads,
    jint n_predict,
    jint n_probs,
    jint penalty_last_n,
    jfloat penalty_repeat,
    jfloat penalty_freq,
    jfloat penalty_present,
    jfloat mirostat,
    jfloat mirostat_tau,
    jfloat mirostat_eta,
    jboolean penalize_nl,
    jint top_k,
    jfloat top_p,
    jfloat min_p,
    jfloat tfs_z,
    jfloat typical_p,
    jint seed,
    jobjectArray stop,
    jboolean ignore_eos,
    jobjectArray logit_bias,
    jobject partial_completion_callback
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    // Clear any state left over from a previous completion.
    llama->rewind();

    llama_reset_timings(llama->ctx);

    // NOTE(review): the buffers returned by GetStringUTFChars here and for
    // `grammar` below are copied into std::string but never released with
    // ReleaseStringUTFChars - that leaks per call; confirm and fix.
    llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
    llama->params.seed = seed;

    int max_threads = std::thread::hardware_concurrency();
    // Use 2 threads by default on 4-core devices, 4 threads on more cores
    int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
    llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads;

    llama->params.n_predict = n_predict;
    llama->params.ignore_eos = ignore_eos;

    // Copy all sampling knobs into the context's sampling params.
    auto & sparams = llama->params.sparams;
    sparams.temp = temperature;
    sparams.penalty_last_n = penalty_last_n;
    sparams.penalty_repeat = penalty_repeat;
    sparams.penalty_freq = penalty_freq;
    sparams.penalty_present = penalty_present;
    sparams.mirostat = mirostat;
    sparams.mirostat_tau = mirostat_tau;
    sparams.mirostat_eta = mirostat_eta;
    sparams.penalize_nl = penalize_nl;
    sparams.top_k = top_k;
    sparams.top_p = top_p;
    sparams.min_p = min_p;
    sparams.tfs_z = tfs_z;
    sparams.typical_p = typical_p;
    sparams.n_probs = n_probs;
    sparams.grammar = env->GetStringUTFChars(grammar, nullptr);

    sparams.logit_bias.clear();
    // ignore_eos is implemented by forcing the EOS token's logit to -inf
    // (in addition to params.ignore_eos above).
    if (ignore_eos) {
        sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
    }

    // logit_bias arrives as an array of [token_id, bias] pairs; a bias of 0
    // means "ban this token" (mapped to -inf).
    const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
    jsize logit_bias_len = env->GetArrayLength(logit_bias);

    for (jsize i = 0; i < logit_bias_len; i++) {
        jdoubleArray el = (jdoubleArray) env->GetObjectArrayElement(logit_bias, i);
        if (el && env->GetArrayLength(el) == 2) {
            jdouble* doubleArray = env->GetDoubleArrayElements(el, 0);

            llama_token tok = static_cast<llama_token>(doubleArray[0]);
            if (tok >= 0 && tok < n_vocab) {
                if (doubleArray[1] != 0) { // If the second element is not false (0)
                    sparams.logit_bias[tok] = doubleArray[1];
                } else {
                    sparams.logit_bias[tok] = -INFINITY;
                }
            }

            env->ReleaseDoubleArrayElements(el, doubleArray, 0);
        }
        env->DeleteLocalRef(el);
    }

    // Stop words: each entry is copied into antiprompt, so the JNI chars can
    // be released immediately.
    llama->params.antiprompt.clear();
    int stop_len = env->GetArrayLength(stop);
    for (int i = 0; i < stop_len; i++) {
        jstring stop_str = (jstring) env->GetObjectArrayElement(stop, i);
        const char *stop_chars = env->GetStringUTFChars(stop_str, nullptr);
        llama->params.antiprompt.push_back(stop_chars);
        env->ReleaseStringUTFChars(stop_str, stop_chars);
    }

    if (!llama->initSampling()) {
        auto result = createWriteableMap(env);
        putString(env, result, "error", "Failed to initialize sampling");
        return reinterpret_cast<jobject>(result);
    }
    llama->beginCompletion();
    llama->loadPrompt();

    // Streaming bookkeeping: how much generated text / how many token-prob
    // records have already been delivered to the callback.
    size_t sent_count = 0;
    size_t sent_token_probs_index = 0;

    while (llama->has_next_token && !llama->is_interrupted) {
        const rnllama::completion_token_output token_with_probs = llama->doCompletion();
        // Skip placeholder tokens and incomplete multibyte UTF-8 sequences.
        if (token_with_probs.tok == -1 || llama->multibyte_pending > 0) {
            continue;
        }
        const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);

        size_t pos = std::min(sent_count, llama->generated_text.size());

        // Check the not-yet-sent tail of the generated text for stop words.
        const std::string str_test = llama->generated_text.substr(pos);
        bool is_stop_full = false;
        size_t stop_pos =
            llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
        if (stop_pos != std::string::npos) {
            // A full stop word matched: truncate everything from the match on.
            is_stop_full = true;
            llama->generated_text.erase(
                llama->generated_text.begin() + pos + stop_pos,
                llama->generated_text.end());
            pos = std::min(sent_count, llama->generated_text.size());
        } else {
            // Only a partial stop-word prefix may be present; hold back the
            // text until it resolves.
            is_stop_full = false;
            stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
                rnllama::STOP_PARTIAL);
        }

        if (
            stop_pos == std::string::npos ||
            // Send rest of the text if we are at the end of the generation
            (!llama->has_next_token && !is_stop_full && stop_pos > 0)
        ) {
            const std::string to_send = llama->generated_text.substr(pos, std::string::npos);

            sent_count += to_send.size();

            std::vector<rnllama::completion_token_output> probs_output = {};

            auto tokenResult = createWriteableMap(env);
            putString(env, tokenResult, "token", to_send.c_str());

            if (llama->params.sparams.n_probs > 0) {
                // Slice out the probability records corresponding to the
                // tokens contained in this chunk.
                const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
                size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
                if (probs_pos < probs_stop_pos) {
                    probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
                }
                sent_token_probs_index = probs_stop_pos;

                putArray(env, tokenResult, "completion_probabilities", tokenProbsToMap(env, llama, probs_output));
            }

            // NOTE(review): the class/method lookup could be hoisted out of
            // the loop; it is repeated for every streamed chunk.
            jclass cb_class = env->GetObjectClass(partial_completion_callback);
            jmethodID onPartialCompletion = env->GetMethodID(cb_class, "onPartialCompletion", "(Lcom/facebook/react/bridge/WritableMap;)V");
            env->CallVoidMethod(partial_completion_callback, onPartialCompletion, tokenResult);
        }
    }

    llama_print_timings(llama->ctx);
    llama->is_predicting = false;

    // Final result: full text plus generation statistics.
    auto result = createWriteableMap(env);
    putString(env, result, "text", llama->generated_text.c_str());
    putArray(env, result, "completion_probabilities", tokenProbsToMap(env, llama, llama->generated_token_probs));
    putInt(env, result, "tokens_predicted", llama->num_tokens_predicted);
    putInt(env, result, "tokens_evaluated", llama->num_prompt_tokens);
    putInt(env, result, "truncated", llama->truncated);
    putInt(env, result, "stopped_eos", llama->stopped_eos);
    putInt(env, result, "stopped_word", llama->stopped_word);
    putInt(env, result, "stopped_limit", llama->stopped_limit);
    putString(env, result, "stopping_word", llama->stopping_word.c_str());
    putInt(env, result, "tokens_cached", llama->n_past);

    // NOTE(review): the per-token divisions below divide by n_p_eval/n_eval
    // without guarding against zero - confirm llama_get_timings never reports
    // zero counts here.
    const auto timings = llama_get_timings(llama->ctx);
    auto timingsResult = createWriteableMap(env);
    putInt(env, timingsResult, "prompt_n", timings.n_p_eval);
    putInt(env, timingsResult, "prompt_ms", timings.t_p_eval_ms);
    putInt(env, timingsResult, "prompt_per_token_ms", timings.t_p_eval_ms / timings.n_p_eval);
    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
    putInt(env, timingsResult, "predicted_n", timings.n_eval);
    putInt(env, timingsResult, "predicted_ms", timings.t_eval_ms);
    putInt(env, timingsResult, "predicted_per_token_ms", timings.t_eval_ms / timings.n_eval);
    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings.t_eval_ms * timings.n_eval);

    putMap(env, result, "timings", timingsResult);

    return reinterpret_cast<jobject>(result);
}
|
498
|
+
|
499
|
+
// Asks the completion loop in doCompletion() to stop at its next iteration.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_stopCompletion(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(env);
    UNUSED(thiz);
    auto *ctx = context_map[(long) context_ptr];
    ctx->is_interrupted = true;
}
|
507
|
+
|
508
|
+
// Reports whether a completion is currently running for this context.
JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isPredicting(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(env);
    UNUSED(thiz);
    auto *ctx = context_map[(long) context_ptr];
    return ctx->is_predicting;
}
|
516
|
+
|
517
|
+
// Tokenizes `text` (no BOS prepended) and returns the ids as a WritableArray.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_tokenize(
    JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    const char *text_chars = env->GetStringUTFChars(text, nullptr);
    const std::vector<llama_token> token_ids = llama_tokenize(
        llama->ctx,
        text_chars,
        false
    );
    // token_ids owns a copy of everything we need, so the JNI buffer can go.
    env->ReleaseStringUTFChars(text, text_chars);

    jobject out = createWritableArray(env);
    for (size_t i = 0; i < token_ids.size(); i++) {
        pushInt(env, out, token_ids[i]);
    }
    return out;
}
|
539
|
+
|
540
|
+
// Converts an array of token ids back into a UTF-8 string.
JNIEXPORT jstring JNICALL
Java_com_rnllama_LlamaContext_detokenize(
    JNIEnv *env, jobject thiz, jlong context_ptr, jintArray tokens) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    jsize n_tokens = env->GetArrayLength(tokens);
    jint *token_data = env->GetIntArrayElements(tokens, 0);
    // Range-construct the llama_token vector in one shot instead of pushing
    // element by element.
    std::vector<llama_token> token_ids(token_data, token_data + n_tokens);
    env->ReleaseIntArrayElements(tokens, token_data, 0);

    auto text = rnllama::tokens_to_str(llama->ctx, token_ids.cbegin(), token_ids.cend());
    return env->NewStringUTF(text.c_str());
}
|
559
|
+
|
560
|
+
// Reports whether this context was created with embedding mode enabled.
JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(env);
    UNUSED(thiz);
    auto *ctx = context_map[(long) context_ptr];
    return ctx->params.embedding;
}
|
568
|
+
|
569
|
+
// Computes the embedding vector for `text` and returns it as a WritableArray
// of doubles. The context must have been created with embedding=true for
// getEmbedding() to yield meaningful data.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_embedding(
    JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    const char *text_chars = env->GetStringUTFChars(text, nullptr);

    llama->rewind();

    llama_reset_timings(llama->ctx);

    llama->params.prompt = text_chars;

    // n_predict = 0: evaluate only the prompt; the single doCompletion() step
    // drives the evaluation that produces the embedding.
    llama->params.n_predict = 0;
    llama->beginCompletion();
    llama->loadPrompt();
    llama->doCompletion();

    std::vector<float> embedding = llama->getEmbedding();

    jobject result = createWritableArray(env);

    for (const auto &val : embedding) {
        pushDouble(env, result, (double) val);
    }

    env->ReleaseStringUTFChars(text, text_chars);
    // NOTE(review): unlike doCompletion(), is_predicting is not reset to false
    // after beginCompletion() here - confirm whether isPredicting should
    // report false once the embedding has been produced.
    return result;
}
|
599
|
+
|
600
|
+
// Runs the context's built-in benchmark with the given pp/tg/pl/nr settings
// and returns its textual report.
JNIEXPORT jstring JNICALL
Java_com_rnllama_LlamaContext_bench(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jint pp,
    jint tg,
    jint pl,
    jint nr
) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];
    const std::string report = llama->bench(pp, tg, pl, nr);
    return env->NewStringUTF(report.c_str());
}
|
615
|
+
|
616
|
+
// Releases the native resources (model, context, sampling state) behind a
// context handle and removes it from the registry.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_freeContext(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(env);
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];
    if (llama->model) {
        llama_free_model(llama->model);
    }
    if (llama->ctx) {
        llama_free(llama->ctx);
    }
    if (llama->ctx_sampling != nullptr)
    {
        llama_sampling_free(llama->ctx_sampling);
    }
    // NOTE(review): only the pointer *value* of llama->ctx is used below after
    // llama_free, but erasing by (long) context_ptr would be cleaner. The
    // `llama` wrapper object itself is never deleted - confirm whether that
    // leak is intentional or handled elsewhere.
    context_map.erase((long) llama->ctx);
}
|
634
|
+
|
635
|
+
} // extern "C"
|