cui-llama.rn 1.3.6 → 1.4.0
This diff compares the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
- package/README.md +22 -1
- package/android/src/main/CMakeLists.txt +25 -26
- package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
- package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
- package/android/src/main/jni-utils.h +94 -0
- package/android/src/main/jni.cpp +132 -62
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
- package/cpp/common.cpp +1982 -1982
- package/cpp/common.h +665 -664
- package/cpp/ggml-cpu.c +14122 -14122
- package/cpp/ggml-cpu.cpp +627 -627
- package/cpp/ggml-metal-impl.h +288 -0
- package/cpp/ggml-opt.cpp +854 -0
- package/cpp/ggml-opt.h +216 -0
- package/cpp/llama-mmap.cpp +589 -589
- package/cpp/llama.cpp +12547 -12544
- package/cpp/rn-llama.hpp +117 -116
- package/cpp/sgemm.h +14 -14
- package/ios/RNLlama.mm +47 -0
- package/ios/RNLlamaContext.h +3 -1
- package/ios/RNLlamaContext.mm +71 -14
- package/jest/mock.js +15 -3
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +33 -37
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +31 -35
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +26 -6
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +21 -36
- package/lib/typescript/index.d.ts.map +1 -1
- package/llama-rn.podspec +4 -18
- package/package.json +2 -3
- package/src/NativeRNLlama.ts +32 -13
- package/src/index.ts +52 -47
- package/cpp/llama.cpp.rej +0 -23
package/android/src/main/jni.cpp
CHANGED
@@ -14,13 +14,14 @@
 #include "llama-context.h"
 #include "gguf.h"
 #include "rn-llama.hpp"
+#include "jni-utils.h"
 
 #define UNUSED(x) (void)(x)
 #define TAG "RNLLAMA_ANDROID_JNI"
 
 #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
 #define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
-
+#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
 static inline int min(int a, int b) {
     return (a < b) ? a : b;
 }
@@ -129,7 +130,7 @@ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
 // Method to push WritableMap into WritableArray
 static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
-    jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/
+    jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
 
     env->CallVoidMethod(arr, pushMapMethod, value);
 }
@@ -199,7 +200,7 @@ Java_com_rnllama_LlamaContext_modelInfo(
             continue;
         }
 
-        const std::string value =
+        const std::string value = lm_gguf_kv_to_str(ctx, i);
         putString(env, info, key, value.c_str());
     }
 }
@@ -234,16 +235,18 @@ Java_com_rnllama_LlamaContext_initContext(
     jint embd_normalize,
     jint n_ctx,
     jint n_batch,
+    jint n_ubatch,
     jint n_threads,
     jint n_gpu_layers, // TODO: Support this
     jboolean flash_attn,
-
-
+    jstring cache_type_k,
+    jstring cache_type_v,
     jboolean use_mlock,
     jboolean use_mmap,
     jboolean vocab_only,
     jstring lora_str,
     jfloat lora_scaled,
+    jobject lora_list,
     jfloat rope_freq_base,
     jfloat rope_freq_scale,
     jint pooling_type,
@@ -263,6 +266,7 @@ Java_com_rnllama_LlamaContext_initContext(
 
     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
+    defaultParams.n_ubatch = n_ubatch;
 
     if (pooling_type != -1) {
         defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
@@ -285,19 +289,14 @@ Java_com_rnllama_LlamaContext_initContext(
     // defaultParams.n_gpu_layers = n_gpu_layers;
     defaultParams.flash_attn = flash_attn;
 
-
-
-    defaultParams.cache_type_k = (
-    defaultParams.cache_type_v = (
+    const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
+    const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
+    defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
+    defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);
 
     defaultParams.use_mlock = use_mlock;
     defaultParams.use_mmap = use_mmap;
 
-    const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
-    if (lora_chars != nullptr && lora_chars[0] != '\0') {
-        defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
-    }
-
     defaultParams.rope_freq_base = rope_freq_base;
     defaultParams.rope_freq_scale = rope_freq_scale;
 
@@ -331,23 +330,55 @@ Java_com_rnllama_LlamaContext_initContext(
     bool is_model_loaded = llama->loadModel(defaultParams);
 
     env->ReleaseStringUTFChars(model_path_str, model_path_chars);
-    env->ReleaseStringUTFChars(
-
-    // env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
+    env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
+    env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
 
     LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
     if (is_model_loaded) {
-
-
-
-
-
-
+        if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
+            LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
+            llama_free(llama->ctx);
+            return -1;
+        }
+        context_map[(long) llama->ctx] = llama;
     } else {
-
+        llama_free(llama->ctx);
+    }
+
+    std::vector<common_lora_adapter_info> lora;
+    const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
+    if (lora_chars != nullptr && lora_chars[0] != '\0') {
+        common_lora_adapter_info la;
+        la.path = lora_chars;
+        la.scale = lora_scaled;
+        lora.push_back(la);
     }
 
-
+    if (lora_list != nullptr) {
+        // lora_adapters: ReadableArray<ReadableMap>
+        int lora_list_size = readablearray::size(env, lora_list);
+        for (int i = 0; i < lora_list_size; i++) {
+            jobject lora_adapter = readablearray::getMap(env, lora_list, i);
+            jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
+            if (path != nullptr) {
+                const char *path_chars = env->GetStringUTFChars(path, nullptr);
+                common_lora_adapter_info la;
+                la.path = path_chars;
+                la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
+                lora.push_back(la);
+                env->ReleaseStringUTFChars(path, path_chars);
+            }
+        }
+    }
+    env->ReleaseStringUTFChars(lora_str, lora_chars);
+    int result = llama->applyLoraAdapters(lora);
+    if (result != 0) {
+        LOGI("[RNLlama] Failed to apply lora adapters");
+        llama_free(llama->ctx);
+        return -1;
+    }
+
+    return reinterpret_cast<jlong>(llama->ctx);
 }
 
 
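The `initContext` hunks above add three pieces of context setup: a micro-batch size (`n_ubatch`), string-typed KV-cache quantization resolved through `rnllama::kv_cache_type_from_str`, and a `lora_list` array so multiple LoRA adapters can be applied at load time alongside the legacy `lora_str`/`lora_scaled` pair. A minimal JS-side sketch, assuming the option names mirror the JNI parameters above (the authoritative typings live in `package/src/NativeRNLlama.ts`, also changed in this release):

```ts
// Hedged usage sketch: option names are assumed to mirror the JNI
// signature (n_ubatch, cache_type_k/v, lora_list) and may differ
// slightly in the released typings.
import { initLlama } from 'cui-llama.rn'

const context = await initLlama({
  model: 'file:///data/models/model.gguf',
  n_ctx: 4096,
  n_batch: 512,
  n_ubatch: 512, // new: micro-batch size, forwarded to defaultParams.n_ubatch
  cache_type_k: 'q8_0', // new: quantized KV-cache key type ('f16' is llama.cpp's default)
  cache_type_v: 'q8_0', // new: quantized KV-cache value type
  lora_list: [
    // new: multiple adapters, each with its own scale
    { path: 'file:///data/models/adapter.gguf', scaled: 1.0 },
  ],
})
```

Note that the legacy single adapter and every `lora_list` entry are funneled into one `std::vector<common_lora_adapter_info>` and applied through `llama->applyLoraAdapters(lora)`, the same code path the new runtime methods use.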
@@ -373,13 +404,13 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
     UNUSED(thiz);
     auto llama = context_map[(long) context_ptr];
 
-    int count = llama_model_meta_count(llama->model
+    int count = llama_model_meta_count(llama->model);
     auto meta = createWriteableMap(env);
     for (int i = 0; i < count; i++) {
         char key[256];
-        llama_model_meta_key_by_index(llama->model
+        llama_model_meta_key_by_index(llama->model, i, key, sizeof(key));
         char val[2048];
-        llama_model_meta_val_str_by_index(llama->model
+        llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val));
 
         putString(env, meta, key, val);
     }
@@ -387,10 +418,10 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
     auto result = createWriteableMap(env);
 
     char desc[1024];
-    llama_model_desc(llama->model
+    llama_model_desc(llama->model, desc, sizeof(desc));
     putString(env, result, "desc", desc);
-    putDouble(env, result, "size", llama_model_size(llama->model
-    putDouble(env, result, "nParams", llama_model_n_params(llama->model
+    putDouble(env, result, "size", llama_model_size(llama->model));
+    putDouble(env, result, "nParams", llama_model_n_params(llama->model));
     putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate());
     putMap(env, result, "metadata", meta);
 
@@ -432,7 +463,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
     }
 
     const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
-    std::string formatted_chat = common_chat_apply_template(llama->model
+    std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);
 
     return env->NewStringUTF(formatted_chat.c_str());
 }
@@ -451,7 +482,7 @@ Java_com_rnllama_LlamaContext_loadSession(
     auto result = createWriteableMap(env);
     size_t n_token_count_out = 0;
     llama->embd.resize(llama->params.n_ctx);
-    if (!llama_state_load_file(llama->ctx
+    if (!llama_state_load_file(llama->ctx, path_chars, llama->embd.data(), llama->embd.capacity(), &n_token_count_out)) {
         env->ReleaseStringUTFChars(path, path_chars);
 
         putString(env, result, "error", "Failed to load session");
@@ -460,7 +491,7 @@ Java_com_rnllama_LlamaContext_loadSession(
     llama->embd.resize(n_token_count_out);
     env->ReleaseStringUTFChars(path, path_chars);
 
-    const std::string text = rnllama::tokens_to_str(llama->ctx
+    const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
     putInt(env, result, "tokens_loaded", n_token_count_out);
     putString(env, result, "prompt", text.c_str());
     return reinterpret_cast<jobject>(result);
@@ -482,7 +513,7 @@ Java_com_rnllama_LlamaContext_saveSession(
     std::vector<llama_token> session_tokens = llama->embd;
     int default_size = session_tokens.size();
     int save_size = size > 0 && size <= default_size ? size : default_size;
-    if (!llama_state_save_file(llama->ctx
+    if (!llama_state_save_file(llama->ctx, path_chars, session_tokens.data(), save_size)) {
         env->ReleaseStringUTFChars(path, path_chars);
         return -1;
     }
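The session hunks above complete the `llama_state_load_file`/`llama_state_save_file` calls with their full argument lists (the removed lines appear truncated in this diff view). For orientation, a hedged sketch of the JS-side round-trip, assuming the wrappers keep their usual llama.rn names and that the result keys match the map built above (`tokens_loaded`, `prompt`):

```ts
// Hedged sketch: saveSession/loadSession signatures are assumptions
// based on the JNI layer above; the size argument caps how many
// cached tokens are written (see save_size in saveSession).
// context: LlamaContext from the initLlama(...) sketch earlier.
const path = 'file:///data/sessions/chat.bin'

await context.saveSession(path, { tokenSize: 1024 })

const session = await context.loadSession(path)
console.log(session.tokens_loaded) // number of cached tokens restored
console.log(session.prompt)        // detokenized cached prompt
```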
@@ -500,13 +531,13 @@ static inline jobject tokenProbsToMap(
     for (const auto &prob : probs) {
         auto probsForToken = createWritableArray(env);
         for (const auto &p : prob.probs) {
-            std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx
+            std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
             auto probResult = createWriteableMap(env);
             putString(env, probResult, "tok_str", tokStr.c_str());
             putDouble(env, probResult, "prob", p.prob);
             pushMap(env, probsForToken, probResult);
         }
-        std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx
+        std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
         auto tokenResult = createWriteableMap(env);
         putString(env, tokenResult, "content", tokStr.c_str());
         putArray(env, tokenResult, "probs", probsForToken);
@@ -533,7 +564,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jfloat mirostat,
     jfloat mirostat_tau,
     jfloat mirostat_eta,
-    jboolean penalize_nl,
     jint top_k,
     jfloat top_p,
     jfloat min_p,
@@ -546,7 +576,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jobjectArray logit_bias,
     jfloat dry_multiplier,
     jfloat dry_base,
-    jint dry_allowed_length,
+    jint dry_allowed_length,
     jint dry_penalty_last_n,
     jobjectArray dry_sequence_breakers,
     jobject partial_completion_callback
@@ -556,7 +586,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
     llama->rewind();
 
-    //llama_reset_timings(llama->ctx
+    //llama_reset_timings(llama->ctx);
 
     llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
     llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;
@@ -578,7 +608,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.mirostat = mirostat;
     sparams.mirostat_tau = mirostat_tau;
     sparams.mirostat_eta = mirostat_eta;
-    // sparams.penalize_nl = penalize_nl;
     sparams.top_k = top_k;
     sparams.top_p = top_p;
     sparams.min_p = min_p;
@@ -594,7 +623,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
     sparams.logit_bias.clear();
     if (ignore_eos) {
-        sparams.logit_bias[llama_token_eos(llama->model
+        sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
     }
 
     // dry break seq
@@ -613,7 +642,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
 
     // logit bias
-    const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx
+    const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
     jsize logit_bias_len = env->GetArrayLength(logit_bias);
 
     for (jsize i = 0; i < logit_bias_len; i++) {
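The `doCompletion` hunks retire the `penalize_nl` flag (its counterpart is gone from the native sampling params) and wire `ignore_eos`, per-token `logit_bias`, and the DRY sampler parameters through to the sampler. A hedged completion sketch, assuming the JS field names mirror the JNI parameters above:

```ts
// Hedged sketch: field names are assumed to mirror the JNI parameters
// (ignore_eos, dry_*); penalize_nl is no longer accepted.
// context: LlamaContext from the initLlama(...) sketch earlier.
const result = await context.completion(
  {
    prompt: 'Q: What is a llama?\nA:',
    n_predict: 64,
    ignore_eos: false, // true biases the EOS token to -INFINITY, as above
    dry_multiplier: 0.8, // > 0 enables the DRY repetition penalty
    dry_base: 1.75,
    dry_allowed_length: 2,
    dry_penalty_last_n: -1, // -1 = scan the whole context
    dry_sequence_breakers: ['\n', ':', '"', '*'],
  },
  (data) => {
    // streamed through the partial_completion_callback path above
    console.log(data.token)
  },
)
console.log(result.text)
```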
@@ -660,7 +689,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
         if (token_with_probs.tok == -1 || llama->incomplete) {
             continue;
         }
-        const std::string token_text = common_token_to_piece(llama->ctx
+        const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);
 
         size_t pos = std::min(sent_count, llama->generated_text.size());
 
@@ -695,7 +724,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
         putString(env, tokenResult, "token", to_send.c_str());
 
         if (llama->params.sampling.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx
+            const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
             size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
             size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
             if (probs_pos < probs_stop_pos) {
@@ -712,7 +741,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
         }
     }
 
-    llama_perf_context_print(llama->ctx
+    llama_perf_context_print(llama->ctx);
     llama->is_predicting = false;
 
    auto result = createWriteableMap(env);
@@ -727,7 +756,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     putString(env, result, "stopping_word", llama->stopping_word.c_str());
     putInt(env, result, "tokens_cached", llama->n_past);
 
-    const auto timings_token = llama_perf_context(llama -> ctx
+    const auto timings_token = llama_perf_context(llama -> ctx);
 
     auto timingsResult = createWriteableMap(env);
     putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
@@ -771,7 +800,7 @@ Java_com_rnllama_LlamaContext_tokenize(
     const char *text_chars = env->GetStringUTFChars(text, nullptr);
 
     const std::vector<llama_token> toks = common_tokenize(
-        llama->ctx
+        llama->ctx,
         text_chars,
         false
     );
@@ -798,7 +827,7 @@ Java_com_rnllama_LlamaContext_detokenize(
         toks.push_back(tokens_ptr[i]);
     }
 
-    auto text = rnllama::tokens_to_str(llama->ctx
+    auto text = rnllama::tokens_to_str(llama->ctx, toks.cbegin(), toks.cend());
 
     env->ReleaseIntArrayElements(tokens, tokens_ptr, 0);
 
@@ -835,7 +864,7 @@ Java_com_rnllama_LlamaContext_embedding(
 
     llama->rewind();
 
-    llama_perf_context_reset(llama->ctx
+    llama_perf_context_reset(llama->ctx);
 
     llama->params.prompt = text_chars;
 
@@ -861,7 +890,7 @@ Java_com_rnllama_LlamaContext_embedding(
 
     auto promptTokens = createWritableArray(env);
     for (const auto &tok : llama->embd) {
-        pushString(env, promptTokens, common_token_to_piece(llama->ctx
+        pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
     }
     putArray(env, result, "prompt_tokens", promptTokens);
 
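The `embedding` hunks reset the perf counters via `llama_perf_context_reset` and return the prompt tokens alongside the vector. A minimal JS-side sketch, assuming the usual `embedding` wrapper over this JNI entry point:

```ts
// Hedged sketch: requires a context initialized with embedding enabled
// (note the encoder-decoder guard added in initContext above).
// context: LlamaContext from the initLlama(...) sketch earlier.
const { embedding } = await context.embedding('The quick brown fox')
console.log(embedding.length) // the model's embedding dimension
```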
@@ -885,23 +914,64 @@ Java_com_rnllama_LlamaContext_bench(
     return env->NewStringUTF(result.c_str());
 }
 
+JNIEXPORT jint JNICALL
+Java_com_rnllama_LlamaContext_applyLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    // lora_adapters: ReadableArray<ReadableMap>
+    std::vector<common_lora_adapter_info> lora_adapters;
+    int lora_adapters_size = readablearray::size(env, loraAdapters);
+    for (int i = 0; i < lora_adapters_size; i++) {
+        jobject lora_adapter = readablearray::getMap(env, loraAdapters, i);
+        jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
+        if (path != nullptr) {
+            const char *path_chars = env->GetStringUTFChars(path, nullptr);
+            env->ReleaseStringUTFChars(path, path_chars);
+            float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
+            common_lora_adapter_info la;
+            la.path = path_chars;
+            la.scale = scaled;
+            lora_adapters.push_back(la);
+        }
+    }
+    return llama->applyLoraAdapters(lora_adapters);
+}
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_removeLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    llama->removeLoraAdapters();
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
+    auto result = createWritableArray(env);
+    for (common_lora_adapter_info &la : loaded_lora_adapters) {
+        auto map = createWriteableMap(env);
+        putString(env, map, "path", la.path.c_str());
+        putDouble(env, map, "scaled", la.scale);
+        pushMap(env, result, map);
+    }
+    return result;
+}
+
 JNIEXPORT void JNICALL
 Java_com_rnllama_LlamaContext_freeContext(
     JNIEnv *env, jobject thiz, jlong context_ptr) {
     UNUSED(env);
     UNUSED(thiz);
     auto llama = context_map[(long) context_ptr];
-
-
-    }
-    if (llama->ctx.get()) {
-        llama_free(llama->ctx.get());
-    }
-    /*if (llama->ctx.get()-> != nullptr)
-    {
-        common_sampler_free(llama->ctx.get() -> _sampling);
-    }*/
-    context_map.erase((long) llama->ctx.get());
+    context_map.erase((long) llama->ctx);
+    delete llama;
 }
 
 JNIEXPORT void JNICALL
package/android/src/newarch/java/com/rnllama/RNLlamaModule.java
CHANGED
@@ -103,6 +103,21 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
     rnllama.bench(id, pp, tg, pl, nr, promise);
   }
 
+  @ReactMethod
+  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
+    rnllama.applyLoraAdapters(id, loraAdapters, promise);
+  }
+
+  @ReactMethod
+  public void removeLoraAdapters(double id, final Promise promise) {
+    rnllama.removeLoraAdapters(id, promise);
+  }
+
+  @ReactMethod
+  public void getLoadedLoraAdapters(double id, final Promise promise) {
+    rnllama.getLoadedLoraAdapters(id, promise);
+  }
+
   @ReactMethod
   public void releaseContext(double id, Promise promise) {
     rnllama.releaseContext(id, promise);
package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
CHANGED
@@ -104,6 +104,21 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
     rnllama.bench(id, pp, tg, pl, nr, promise);
   }
 
+  @ReactMethod
+  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
+    rnllama.applyLoraAdapters(id, loraAdapters, promise);
+  }
+
+  @ReactMethod
+  public void removeLoraAdapters(double id, final Promise promise) {
+    rnllama.removeLoraAdapters(id, promise);
+  }
+
+  @ReactMethod
+  public void getLoadedLoraAdapters(double id, final Promise promise) {
+    rnllama.getLoadedLoraAdapters(id, promise);
+  }
+
   @ReactMethod
   public void releaseContext(double id, Promise promise) {
     rnllama.releaseContext(id, promise);
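Both module variants (new and old architecture) register the same three bridge methods and delegate to `RNLlama`, which in turn reaches the JNI functions added in `jni.cpp` above. A hedged end-to-end sketch of runtime LoRA management, assuming the public context object re-exports these methods under the same names:

```ts
// Hedged sketch: method names follow the native layer
// (applyLoraAdapters / removeLoraAdapters / getLoadedLoraAdapters);
// exact TS signatures are assumptions.
// context: LlamaContext from the initLlama(...) sketch earlier.
await context.applyLoraAdapters([
  { path: 'file:///data/models/adapter.gguf', scaled: 1.0 },
])

const loaded = await context.getLoadedLoraAdapters()
// -> [{ path: 'file:///data/models/adapter.gguf', scaled: 1 }]

await context.removeLoraAdapters() // detach all adapters again
```

Because load-time adapters in `initContext` go through the same `applyLoraAdapters` path, swapping adapters at runtime does not require recreating the context.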