cui-llama.rn 1.2.6 → 1.3.0
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- package/README.md +3 -2
- package/android/src/main/CMakeLists.txt +20 -5
- package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
- package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
- package/android/src/main/jni.cpp +222 -34
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/cpp/common.cpp +1682 -2114
- package/cpp/common.h +600 -613
- package/cpp/ggml-aarch64.c +129 -3478
- package/cpp/ggml-aarch64.h +19 -39
- package/cpp/ggml-alloc.c +1040 -1040
- package/cpp/ggml-alloc.h +76 -76
- package/cpp/ggml-backend-impl.h +216 -216
- package/cpp/ggml-backend-reg.cpp +195 -0
- package/cpp/ggml-backend.cpp +1997 -2661
- package/cpp/ggml-backend.h +328 -314
- package/cpp/ggml-common.h +1853 -1853
- package/cpp/ggml-cpp.h +38 -38
- package/cpp/ggml-cpu-aarch64.c +3560 -0
- package/cpp/ggml-cpu-aarch64.h +30 -0
- package/cpp/ggml-cpu-impl.h +371 -614
- package/cpp/ggml-cpu-quants.c +10822 -0
- package/cpp/ggml-cpu-quants.h +63 -0
- package/cpp/ggml-cpu.c +13975 -13720
- package/cpp/ggml-cpu.cpp +663 -0
- package/cpp/ggml-cpu.h +177 -150
- package/cpp/ggml-impl.h +550 -296
- package/cpp/ggml-metal.h +66 -66
- package/cpp/ggml-metal.m +4294 -3933
- package/cpp/ggml-quants.c +5247 -15739
- package/cpp/ggml-quants.h +100 -147
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +12 -0
- package/cpp/ggml.c +8180 -8390
- package/cpp/ggml.h +2411 -2441
- package/cpp/llama-grammar.cpp +1138 -1138
- package/cpp/llama-grammar.h +144 -144
- package/cpp/llama-impl.h +181 -181
- package/cpp/llama-sampling.cpp +2348 -2345
- package/cpp/llama-sampling.h +48 -48
- package/cpp/llama-vocab.cpp +1984 -1984
- package/cpp/llama-vocab.h +170 -170
- package/cpp/llama.cpp +22132 -22046
- package/cpp/llama.h +1253 -1255
- package/cpp/log.cpp +401 -401
- package/cpp/log.h +121 -121
- package/cpp/rn-llama.hpp +83 -19
- package/cpp/sampling.cpp +466 -466
- package/cpp/sgemm.cpp +1884 -1276
- package/ios/RNLlama.mm +43 -20
- package/ios/RNLlamaContext.h +9 -3
- package/ios/RNLlamaContext.mm +133 -33
- package/jest/mock.js +0 -1
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +52 -15
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +51 -15
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +29 -5
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +12 -5
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +41 -6
- package/src/index.ts +82 -27
- package/cpp/json-schema-to-grammar.cpp +0 -1045
- package/cpp/json-schema-to-grammar.h +0 -8
- package/cpp/json.hpp +0 -24766
package/android/src/main/jni.cpp CHANGED
@@ -4,13 +4,15 @@
 #include <android/log.h>
 #include <cstdlib>
 #include <ctime>
+#include <ctime>
 #include <sys/sysinfo.h>
 #include <string>
 #include <thread>
 #include <unordered_map>
 #include "llama.h"
-#include "
+#include "llama-impl.h"
 #include "ggml.h"
+#include "rn-llama.hpp"
 
 #define UNUSED(x) (void)(x)
 #define TAG "RNLLAMA_ANDROID_JNI"
@@ -22,6 +24,13 @@ static inline int min(int a, int b) {
     return (a < b) ? a : b;
 }
 
+static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
+    if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+}
+
 extern "C" {
 
 // Method to create WritableMap
@@ -107,6 +116,15 @@ static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
     env->CallVoidMethod(arr, pushDoubleMethod, value);
 }
 
+// Method to push string into WritableArray
+static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
+    jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
+
+    jstring jValue = env->NewStringUTF(value);
+    env->CallVoidMethod(arr, pushStringMethod, jValue);
+}
+
 // Method to push WritableMap into WritableArray
 static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
@@ -125,6 +143,77 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
     env->CallVoidMethod(map, putArrayMethod, jKey, value);
 }
 
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_modelInfo(
+    JNIEnv *env,
+    jobject thiz,
+    jstring model_path_str,
+    jobjectArray skip
+) {
+    UNUSED(thiz);
+
+    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
+
+    std::vector<std::string> skip_vec;
+    int skip_len = env->GetArrayLength(skip);
+    for (int i = 0; i < skip_len; i++) {
+        jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
+        const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
+        skip_vec.push_back(skip_chars);
+        env->ReleaseStringUTFChars(skip_str, skip_chars);
+    }
+
+    struct lm_gguf_init_params params = {
+        /*.no_alloc = */ false,
+        /*.ctx      = */ NULL,
+    };
+    struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);
+
+    if (!ctx) {
+        LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
+        return nullptr;
+    }
+
+    auto info = createWriteableMap(env);
+    putInt(env, info, "version", lm_gguf_get_version(ctx));
+    putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
+    putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
+    {
+        const int n_kv = lm_gguf_get_n_kv(ctx);
+
+        for (int i = 0; i < n_kv; ++i) {
+            const char * key = lm_gguf_get_key(ctx, i);
+
+            bool skipped = false;
+            if (skip_len > 0) {
+                for (int j = 0; j < skip_len; j++) {
+                    if (skip_vec[j] == key) {
+                        skipped = true;
+                        break;
+                    }
+                }
+            }
+
+            if (skipped) {
+                continue;
+            }
+
+            const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
+            putString(env, info, key, value.c_str());
+        }
+    }
+
+    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+    lm_gguf_free(ctx);
+
+    return reinterpret_cast<jobject>(info);
+}
+
+struct callback_context {
+    JNIEnv *env;
+    rnllama::llama_rn_context *llama;
+    jobject callback;
+};
 
 std::unordered_map<long, rnllama::llama_rn_context *> context_map;
 
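The `modelInfo` binding added above reads GGUF metadata (version, alignment, data offset, and every key/value pair not listed in `skip`) without loading the model into a context. A minimal JS-side sketch; the export name and call shape are assumptions, since the TypeScript wrapper (`package/src/index.ts`) is not shown in this excerpt:

```ts
// Hypothetical wrapper over the modelInfo JNI binding above (name assumed).
import { loadLlamaModelInfo } from 'cui-llama.rn'

async function inspectModel(path: string): Promise<void> {
  const info = await loadLlamaModelInfo(path)
  // version/alignment/data_offset are always set; the remaining GGUF
  // key/value pairs come back flattened, minus any keys skipped natively.
  console.log(info['general.architecture'])
}
```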
@@ -141,10 +230,14 @@ Java_com_rnllama_LlamaContext_initContext(
     jobject thiz,
     jstring model_path_str,
     jboolean embedding,
+    jint embd_normalize,
     jint n_ctx,
     jint n_batch,
     jint n_threads,
     jint n_gpu_layers, // TODO: Support this
+    jboolean flash_attn,
+    jstring cache_type_k,
+    jstring cache_type_v,
     jboolean use_mlock,
     jboolean use_mmap,
     jboolean vocab_only,
@@ -152,7 +245,8 @@ Java_com_rnllama_LlamaContext_initContext(
     jfloat lora_scaled,
     jfloat rope_freq_base,
     jfloat rope_freq_scale,
-
+    jint pooling_type,
+    jobject load_progress_callback
 ) {
     UNUSED(thiz);
 
@@ -166,64 +260,109 @@ Java_com_rnllama_LlamaContext_initContext(
     const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
     defaultParams.model = model_path_chars;
 
-    defaultParams.embedding = embedding;
-
     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
 
+    if (pooling_type != -1) {
+        defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
+    }
+
+    defaultParams.embedding = embedding;
+    if (embd_normalize != -1) {
+        defaultParams.embd_normalize = embd_normalize;
+    }
+    if (embedding) {
+        // For non-causal models, batch size must be equal to ubatch size
+        defaultParams.n_ubatch = defaultParams.n_batch;
+    }
+
     int max_threads = std::thread::hardware_concurrency();
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
     defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
     defaultParams.n_gpu_layers = n_gpu_layers;
-
+    defaultParams.flash_attn = flash_attn;
+
+    const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
+    const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
+    defaultParams.cache_type_k = cache_type_k_chars;
+    defaultParams.cache_type_v = cache_type_v_chars;
+
     defaultParams.use_mlock = use_mlock;
     defaultParams.use_mmap = use_mmap;
 
     const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
     if (lora_chars != nullptr && lora_chars[0] != '\0') {
         defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
-        defaultParams.use_mmap = false;
     }
 
     defaultParams.rope_freq_base = rope_freq_base;
     defaultParams.rope_freq_scale = rope_freq_scale;
 
-    // progress callback when loading
-    jclass llamaContextClass = env->GetObjectClass(javaLlamaContext);
-    jmethodID sendProgressMethod = env->GetMethodID(llamaContextClass, "emitModelProgressUpdate", "(I)V");
-
-    CallbackContext callbackctx = {env, javaLlamaContext, sendProgressMethod, 0};
-
-    defaultParams.progress_callback_user_data = &callbackctx;
-    defaultParams.progress_callback = [](float progress, void * ctx) {
-        unsigned percentage = (unsigned) (100 * progress);
-        CallbackContext * cbctx = static_cast<CallbackContext*>(ctx);
-        // reduce call frequency by only calling method when value changes
-        if (percentage <= cbctx->current) return true;
-        cbctx->current = percentage;
-        cbctx->env->CallVoidMethod(cbctx->thiz, cbctx->sendProgressMethod, percentage);
-        return true;
-    };
-
-
     auto llama = new rnllama::llama_rn_context();
+    llama->is_load_interrupted = false;
+    llama->loading_progress = 0;
+
+    if (load_progress_callback != nullptr) {
+        defaultParams.progress_callback = [](float progress, void * user_data) {
+            callback_context *cb_ctx = (callback_context *)user_data;
+            JNIEnv *env = cb_ctx->env;
+            auto llama = cb_ctx->llama;
+            jobject callback = cb_ctx->callback;
+            int percentage = (int) (100 * progress);
+            if (percentage > llama->loading_progress) {
+                llama->loading_progress = percentage;
+                jclass callback_class = env->GetObjectClass(callback);
+                jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
+                env->CallVoidMethod(callback, onLoadProgress, percentage);
+            }
+            return !llama->is_load_interrupted;
+        };
+
+        callback_context *cb_ctx = new callback_context;
+        cb_ctx->env = env;
+        cb_ctx->llama = llama;
+        cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
+        defaultParams.progress_callback_user_data = cb_ctx;
+    }
+
     bool is_model_loaded = llama->loadModel(defaultParams);
 
+    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+    env->ReleaseStringUTFChars(lora_str, lora_chars);
+    env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
+    env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
+
     LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
     if (is_model_loaded) {
+        if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
+            LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
+            llama_free(llama->ctx);
+            return -1;
+        }
         context_map[(long) llama->ctx] = llama;
     } else {
         llama_free(llama->ctx);
     }
 
-    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
-    env->ReleaseStringUTFChars(lora_str, lora_chars);
-
     return reinterpret_cast<jlong>(llama->ctx);
 }
 
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_interruptLoad(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    if (llama) {
+        llama->is_load_interrupted = true;
+    }
+}
+
 JNIEXPORT jobject JNICALL
 Java_com_rnllama_LlamaContext_loadModelDetails(
     JNIEnv *env,
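Taken together, the `initContext` changes expose flash attention, quantized KV-cache types, pooling and embedding-normalization overrides, and an interruptible load-progress callback: `interruptLoad` sets `is_load_interrupted`, the progress lambda returns its negation, and llama.cpp aborts the load when the callback returns false. A hedged JS-side sketch, with field names mirrored from the JNI parameters above (the real TS surface lives in `package/src/index.ts`, not shown here):

```ts
import { initLlama } from 'cui-llama.rn'

async function loadModel() {
  // Field names assumed to map 1:1 onto the new JNI parameters above.
  const context = await initLlama(
    {
      model: '/models/model.gguf',
      n_ctx: 4096,
      n_batch: 512,
      flash_attn: true,     // new in 1.3.0
      cache_type_k: 'q8_0', // new: quantized KV cache (keys)
      cache_type_v: 'q8_0', // new: quantized KV cache (values)
    },
    (progress: number) => {
      // Fired at most once per percentage point, per the native throttle above.
      console.log(`loading: ${progress}%`)
    },
  )
  return context
}
```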
@@ -397,13 +536,18 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jint top_k,
     jfloat top_p,
     jfloat min_p,
-    jfloat
-    jfloat
+    jfloat xtc_threshold,
+    jfloat xtc_probability,
     jfloat typical_p,
     jint seed,
     jobjectArray stop,
     jboolean ignore_eos,
     jobjectArray logit_bias,
+    jfloat dry_multiplier,
+    jfloat dry_base,
+    jint dry_allowed_length,
+    jint dry_penalty_last_n,
+    jobjectArray dry_sequence_breakers,
     jobject partial_completion_callback
 ) {
     UNUSED(thiz);
@@ -440,14 +584,34 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.typ_p = typical_p;
     sparams.n_probs = n_probs;
     sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
-    sparams.xtc_threshold =
-    sparams.xtc_probability =
+    sparams.xtc_threshold = xtc_threshold;
+    sparams.xtc_probability = xtc_probability;
+    sparams.dry_multiplier = dry_multiplier;
+    sparams.dry_base = dry_base;
+    sparams.dry_allowed_length = dry_allowed_length;
+    sparams.dry_penalty_last_n = dry_penalty_last_n;
 
     sparams.logit_bias.clear();
     if (ignore_eos) {
         sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
     }
 
+    // dry break seq
+
+    jint size = env->GetArrayLength(dry_sequence_breakers);
+    std::vector<std::string> dry_sequence_breakers_vector;
+
+    for (jint i = 0; i < size; i++) {
+        jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
+        const char *nativeString = env->GetStringUTFChars(javaString, 0);
+        dry_sequence_breakers_vector.push_back(std::string(nativeString));
+        env->ReleaseStringUTFChars(javaString, nativeString);
+        env->DeleteLocalRef(javaString);
+    }
+
+    sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
+
+    // logit bias
     const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
     jsize logit_bias_len = env->GetArrayLength(logit_bias);
 
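`doCompletion` now forwards XTC and DRY sampler settings into `sparams`. A hedged completion sketch using the new fields, assuming the context object wraps `doCompletion` as `completion()`; names mirror the JNI parameters and the values are illustrative, not recommended defaults:

```ts
const result = await context.completion({
  prompt: 'Once upon a time',
  n_predict: 128,
  // XTC: with probability xtc_probability, drop every candidate above
  // xtc_threshold except the least likely of them.
  xtc_threshold: 0.1,
  xtc_probability: 0.5,
  // DRY: penalize repeated sequences; breakers reset the repetition match.
  dry_multiplier: 0.8,
  dry_base: 1.75,
  dry_allowed_length: 2,
  dry_penalty_last_n: -1,
  dry_sequence_breakers: ['\n', ':', '"', '*'],
})
console.log(result.text)
```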
@@ -651,16 +815,27 @@ Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
 
 JNIEXPORT jobject JNICALL
 Java_com_rnllama_LlamaContext_embedding(
-    JNIEnv *env, jobject thiz,
+    JNIEnv *env, jobject thiz,
+    jlong context_ptr,
+    jstring text,
+    jint embd_normalize
+) {
     UNUSED(thiz);
     auto llama = context_map[(long) context_ptr];
 
+    common_params embdParams;
+    embdParams.embedding = true;
+    embdParams.embd_normalize = llama->params.embd_normalize;
+    if (embd_normalize != -1) {
+        embdParams.embd_normalize = embd_normalize;
+    }
+
     const char *text_chars = env->GetStringUTFChars(text, nullptr);
 
     llama->rewind();
 
     llama_perf_context_reset(llama->ctx);
-
+
     llama->params.prompt = text_chars;
 
     llama->params.n_predict = 0;
@@ -675,7 +850,7 @@ Java_com_rnllama_LlamaContext_embedding(
     llama->loadPrompt();
     llama->doCompletion();
 
-    std::vector<float> embedding = llama->getEmbedding();
+    std::vector<float> embedding = llama->getEmbedding(embdParams);
 
     auto embeddings = createWritableArray(env);
     for (const auto &val : embedding) {
@@ -683,6 +858,12 @@ Java_com_rnllama_LlamaContext_embedding(
     }
     putArray(env, result, "embedding", embeddings);
 
+    auto promptTokens = createWritableArray(env);
+    for (const auto &tok : llama->embd) {
+        pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
+    }
+    putArray(env, result, "prompt_tokens", promptTokens);
+
     env->ReleaseStringUTFChars(text, text_chars);
     return result;
 }
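The embedding result now carries the tokenized prompt alongside the vector, and normalization can be overridden per call (`-1` keeps whatever the context was created with). A hedged sketch of the corresponding JS call:

```ts
const { embedding, prompt_tokens } = await context.embedding('Hello, world!', {
  embd_normalize: 2, // Euclidean (L2) normalization; -1 keeps the context default
})
console.log(embedding.length, prompt_tokens)
```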
@@ -722,4 +903,11 @@ Java_com_rnllama_LlamaContext_freeContext(
     context_map.erase((long) llama->ctx);
 }
 
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
+    UNUSED(env);
+    UNUSED(thiz);
+    llama_log_set(log_callback, NULL);
+}
+
 } // extern "C"
package/android/src/newarch/java/com/rnllama/RNLlamaModule.java CHANGED

@@ -39,8 +39,13 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
   }
 
   @ReactMethod
-  public void
-  rnllama.
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    rnllama.modelInfo(model, skip, promise);
+  }
+
+  @ReactMethod
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    rnllama.initContext(id, params, promise);
   }
 
   @ReactMethod
@@ -89,8 +94,8 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
   }
 
   @ReactMethod
-  public void embedding(double id, final String text, final Promise promise) {
-    rnllama.embedding(id, text, promise);
+  public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
+    rnllama.embedding(id, text, params, promise);
   }
 
   @ReactMethod
package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java CHANGED

@@ -40,8 +40,13 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
   }
 
   @ReactMethod
-  public void
-  rnllama.
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    rnllama.modelInfo(model, skip, promise);
+  }
+
+  @ReactMethod
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    rnllama.initContext(id, params, promise);
   }
 
   @ReactMethod
@@ -90,8 +95,8 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
   }
 
   @ReactMethod
-  public void embedding(double id, final String text, final Promise promise) {
-    rnllama.embedding(id, text, promise);
+  public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
+    rnllama.embedding(id, text, params, promise);
   }
 
   @ReactMethod