cui-llama.rn 1.3.6 → 1.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. package/README.md +22 -1
  2. package/android/src/main/CMakeLists.txt +25 -26
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
  5. package/android/src/main/jni-utils.h +94 -0
  6. package/android/src/main/jni.cpp +133 -63
  7. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
  8. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
  9. package/cpp/common.cpp +2085 -1982
  10. package/cpp/common.h +696 -664
  11. package/cpp/ggml-alloc.c +1042 -1037
  12. package/cpp/ggml-backend-impl.h +255 -256
  13. package/cpp/ggml-backend-reg.cpp +582 -582
  14. package/cpp/ggml-backend.cpp +2002 -2002
  15. package/cpp/ggml-backend.h +354 -352
  16. package/cpp/ggml-common.h +1853 -1853
  17. package/cpp/ggml-cpp.h +39 -39
  18. package/cpp/ggml-cpu-aarch64.cpp +4247 -4247
  19. package/cpp/ggml-cpu-aarch64.h +8 -8
  20. package/cpp/ggml-cpu-impl.h +386 -386
  21. package/cpp/ggml-cpu-quants.c +10920 -10839
  22. package/cpp/ggml-cpu-traits.cpp +36 -36
  23. package/cpp/ggml-cpu-traits.h +38 -38
  24. package/cpp/ggml-cpu.c +14391 -14122
  25. package/cpp/ggml-cpu.cpp +635 -627
  26. package/cpp/ggml-cpu.h +135 -135
  27. package/cpp/ggml-impl.h +567 -567
  28. package/cpp/ggml-metal-impl.h +288 -0
  29. package/cpp/ggml-metal.m +4884 -4884
  30. package/cpp/ggml-opt.cpp +854 -0
  31. package/cpp/ggml-opt.h +216 -0
  32. package/cpp/ggml-quants.c +5238 -5238
  33. package/cpp/ggml-threading.h +14 -14
  34. package/cpp/ggml.c +6514 -6448
  35. package/cpp/ggml.h +2194 -2163
  36. package/cpp/gguf.cpp +1329 -1325
  37. package/cpp/gguf.h +202 -202
  38. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  39. package/cpp/json-schema-to-grammar.h +8 -8
  40. package/cpp/json.hpp +24766 -24766
  41. package/cpp/llama-adapter.cpp +347 -346
  42. package/cpp/llama-adapter.h +74 -73
  43. package/cpp/llama-arch.cpp +1487 -1434
  44. package/cpp/llama-arch.h +400 -395
  45. package/cpp/llama-batch.cpp +368 -368
  46. package/cpp/llama-batch.h +88 -88
  47. package/cpp/llama-chat.cpp +578 -567
  48. package/cpp/llama-chat.h +52 -51
  49. package/cpp/llama-context.cpp +1775 -1771
  50. package/cpp/llama-context.h +128 -128
  51. package/cpp/llama-cparams.cpp +1 -1
  52. package/cpp/llama-cparams.h +37 -37
  53. package/cpp/llama-cpp.h +30 -30
  54. package/cpp/llama-grammar.cpp +1139 -1139
  55. package/cpp/llama-grammar.h +143 -143
  56. package/cpp/llama-hparams.cpp +71 -71
  57. package/cpp/llama-hparams.h +139 -140
  58. package/cpp/llama-impl.cpp +167 -167
  59. package/cpp/llama-impl.h +61 -61
  60. package/cpp/llama-kv-cache.cpp +718 -718
  61. package/cpp/llama-kv-cache.h +218 -218
  62. package/cpp/llama-mmap.cpp +590 -589
  63. package/cpp/llama-mmap.h +67 -67
  64. package/cpp/llama-model-loader.cpp +1124 -1011
  65. package/cpp/llama-model-loader.h +167 -158
  66. package/cpp/llama-model.cpp +3997 -2202
  67. package/cpp/llama-model.h +370 -391
  68. package/cpp/llama-sampling.cpp +2408 -2406
  69. package/cpp/llama-sampling.h +32 -48
  70. package/cpp/llama-vocab.cpp +3247 -1982
  71. package/cpp/llama-vocab.h +125 -182
  72. package/cpp/llama.cpp +10077 -12544
  73. package/cpp/llama.h +1323 -1285
  74. package/cpp/log.cpp +401 -401
  75. package/cpp/log.h +121 -121
  76. package/cpp/rn-llama.hpp +123 -116
  77. package/cpp/sampling.cpp +505 -500
  78. package/cpp/sgemm.cpp +2597 -2597
  79. package/cpp/sgemm.h +14 -14
  80. package/cpp/speculative.cpp +277 -274
  81. package/cpp/speculative.h +28 -28
  82. package/cpp/unicode.cpp +2 -3
  83. package/ios/RNLlama.mm +47 -0
  84. package/ios/RNLlamaContext.h +3 -1
  85. package/ios/RNLlamaContext.mm +71 -14
  86. package/jest/mock.js +15 -3
  87. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  88. package/lib/commonjs/index.js +33 -37
  89. package/lib/commonjs/index.js.map +1 -1
  90. package/lib/module/NativeRNLlama.js.map +1 -1
  91. package/lib/module/index.js +31 -35
  92. package/lib/module/index.js.map +1 -1
  93. package/lib/typescript/NativeRNLlama.d.ts +26 -6
  94. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  95. package/lib/typescript/index.d.ts +21 -36
  96. package/lib/typescript/index.d.ts.map +1 -1
  97. package/llama-rn.podspec +4 -18
  98. package/package.json +2 -3
  99. package/src/NativeRNLlama.ts +32 -13
  100. package/src/index.ts +52 -47
  101. package/cpp/llama.cpp.rej +0 -23
package/android/src/main/jni.cpp

@@ -14,13 +14,14 @@
  #include "llama-context.h"
  #include "gguf.h"
  #include "rn-llama.hpp"
+ #include "jni-utils.h"
 
  #define UNUSED(x) (void)(x)
  #define TAG "RNLLAMA_ANDROID_JNI"
 
  #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
  #define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
-
+ #define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
  static inline int min(int a, int b) {
  return (a < b) ? a : b;
  }
@@ -129,7 +130,7 @@ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
  // Method to push WritableMap into WritableArray
  static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
  jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
- jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/WritableMap;)V");
+ jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
 
  env->CallVoidMethod(arr, pushMapMethod, value);
  }
@@ -199,7 +200,7 @@ Java_com_rnllama_LlamaContext_modelInfo(
  continue;
  }
 
- const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
+ const std::string value = lm_gguf_kv_to_str(ctx, i);
  putString(env, info, key, value.c_str());
  }
  }
@@ -234,16 +235,18 @@ Java_com_rnllama_LlamaContext_initContext(
  jint embd_normalize,
  jint n_ctx,
  jint n_batch,
+ jint n_ubatch,
  jint n_threads,
  jint n_gpu_layers, // TODO: Support this
  jboolean flash_attn,
- jint cache_type_k,
- jint cache_type_v,
+ jstring cache_type_k,
+ jstring cache_type_v,
  jboolean use_mlock,
  jboolean use_mmap,
  jboolean vocab_only,
  jstring lora_str,
  jfloat lora_scaled,
+ jobject lora_list,
  jfloat rope_freq_base,
  jfloat rope_freq_scale,
  jint pooling_type,
@@ -263,6 +266,7 @@ Java_com_rnllama_LlamaContext_initContext(
 
  defaultParams.n_ctx = n_ctx;
  defaultParams.n_batch = n_batch;
+ defaultParams.n_ubatch = n_ubatch;
 
  if (pooling_type != -1) {
  defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
@@ -285,19 +289,14 @@ Java_com_rnllama_LlamaContext_initContext(
  // defaultParams.n_gpu_layers = n_gpu_layers;
  defaultParams.flash_attn = flash_attn;
 
- // const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
- // const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
- defaultParams.cache_type_k = (lm_ggml_type) cache_type_k;
- defaultParams.cache_type_v = (lm_ggml_type) cache_type_v;
+ const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
+ const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
+ defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
+ defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);
 
  defaultParams.use_mlock = use_mlock;
  defaultParams.use_mmap = use_mmap;
 
- const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
- if (lora_chars != nullptr && lora_chars[0] != '\0') {
- defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
- }
-
  defaultParams.rope_freq_base = rope_freq_base;
  defaultParams.rope_freq_scale = rope_freq_scale;
 
@@ -331,23 +330,55 @@ Java_com_rnllama_LlamaContext_initContext(
  bool is_model_loaded = llama->loadModel(defaultParams);
 
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
- env->ReleaseStringUTFChars(lora_str, lora_chars);
- // env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
- // env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
+ env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
+ env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
 
  LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
  if (is_model_loaded) {
- if (embedding && llama_model_has_encoder(llama->model.get()) && llama_model_has_decoder(llama->model.get())) {
- LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
- llama_free(llama->ctx.get());
- return -1;
- }
- context_map[(long) llama->ctx.get()] = llama;
+ if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
+ LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
+ llama_free(llama->ctx);
+ return -1;
+ }
+ context_map[(long) llama->ctx] = llama;
  } else {
- llama_free(llama->ctx.get());
+ llama_free(llama->ctx);
+ }
+
+ std::vector<common_adapter_lora_info> lora;
+ const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
+ if (lora_chars != nullptr && lora_chars[0] != '\0') {
+ common_adapter_lora_info la;
+ la.path = lora_chars;
+ la.scale = lora_scaled;
+ lora.push_back(la);
  }
 
- return reinterpret_cast<jlong>(llama->ctx.get());
+ if (lora_list != nullptr) {
+ // lora_adapters: ReadableArray<ReadableMap>
+ int lora_list_size = readablearray::size(env, lora_list);
+ for (int i = 0; i < lora_list_size; i++) {
+ jobject lora_adapter = readablearray::getMap(env, lora_list, i);
+ jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
+ if (path != nullptr) {
+ const char *path_chars = env->GetStringUTFChars(path, nullptr);
+ common_adapter_lora_info la;
+ la.path = path_chars;
+ la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
+ lora.push_back(la);
+ env->ReleaseStringUTFChars(path, path_chars);
+ }
+ }
+ }
+ env->ReleaseStringUTFChars(lora_str, lora_chars);
+ int result = llama->applyLoraAdapters(lora);
+ if (result != 0) {
+ LOGI("[RNLlama] Failed to apply lora adapters");
+ llama_free(llama->ctx);
+ return -1;
+ }
+
+ return reinterpret_cast<jlong>(llama->ctx);
  }
 
 
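Note on the initContext hunks above: the native entry point now takes n_ubatch, string-valued cache_type_k / cache_type_v (parsed via rnllama::kv_cache_type_from_str), and a lora_list of { path, scaled } maps instead of the single lora path/scale pair. Below is a minimal TypeScript sketch of how these options might be passed from the JS side, assuming the package keeps llama.rn's initLlama entry point and forwards these option names unchanged; the JS-side names are not shown in this excerpt and should be checked against package/src/index.ts.

    // Hypothetical usage sketch — option names mirror the JNI parameters above;
    // confirm the exact JS API in package/src/index.ts before relying on it.
    import { initLlama } from 'cui-llama.rn'

    export async function loadModelWithLora() {
      const context = await initLlama({
        model: '/path/to/model.gguf',
        n_ctx: 4096,
        n_batch: 512,
        n_ubatch: 512,          // new: micro-batch size forwarded to llama.cpp
        cache_type_k: 'q8_0',   // KV cache types are now strings, not enum ints
        cache_type_v: 'q8_0',
        lora_list: [            // replaces the single lora path + lora_scaled pair
          { path: '/path/to/adapter.gguf', scaled: 1.0 },
        ],
      })
      return context
    }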
@@ -373,13 +404,13 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
  UNUSED(thiz);
  auto llama = context_map[(long) context_ptr];
 
- int count = llama_model_meta_count(llama->model.get());
+ int count = llama_model_meta_count(llama->model);
  auto meta = createWriteableMap(env);
  for (int i = 0; i < count; i++) {
  char key[256];
- llama_model_meta_key_by_index(llama->model.get(), i, key, sizeof(key));
- char val[2048];
- llama_model_meta_val_str_by_index(llama->model.get(), i, val, sizeof(val));
+ llama_model_meta_key_by_index(llama->model, i, key, sizeof(key));
+ char val[4096];
+ llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val));
 
  putString(env, meta, key, val);
  }
@@ -387,10 +418,10 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
  auto result = createWriteableMap(env);
 
  char desc[1024];
- llama_model_desc(llama->model.get(), desc, sizeof(desc));
+ llama_model_desc(llama->model, desc, sizeof(desc));
  putString(env, result, "desc", desc);
- putDouble(env, result, "size", llama_model_size(llama->model.get()));
- putDouble(env, result, "nParams", llama_model_n_params(llama->model.get()));
+ putDouble(env, result, "size", llama_model_size(llama->model));
+ putDouble(env, result, "nParams", llama_model_n_params(llama->model));
  putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate());
  putMap(env, result, "metadata", meta);
 
@@ -432,7 +463,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
  }
 
  const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
- std::string formatted_chat = common_chat_apply_template(llama->model.get(), tmpl_chars, chat, true);
+ std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);
 
  return env->NewStringUTF(formatted_chat.c_str());
  }
@@ -451,7 +482,7 @@ Java_com_rnllama_LlamaContext_loadSession(
  auto result = createWriteableMap(env);
  size_t n_token_count_out = 0;
  llama->embd.resize(llama->params.n_ctx);
- if (!llama_state_load_file(llama->ctx.get(), path_chars, llama->embd.data(), llama->embd.capacity(), &n_token_count_out)) {
+ if (!llama_state_load_file(llama->ctx, path_chars, llama->embd.data(), llama->embd.capacity(), &n_token_count_out)) {
  env->ReleaseStringUTFChars(path, path_chars);
 
  putString(env, result, "error", "Failed to load session");
@@ -460,7 +491,7 @@ Java_com_rnllama_LlamaContext_loadSession(
  llama->embd.resize(n_token_count_out);
  env->ReleaseStringUTFChars(path, path_chars);
 
- const std::string text = rnllama::tokens_to_str(llama->ctx.get(), llama->embd.cbegin(), llama->embd.cend());
+ const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
  putInt(env, result, "tokens_loaded", n_token_count_out);
  putString(env, result, "prompt", text.c_str());
  return reinterpret_cast<jobject>(result);
@@ -482,7 +513,7 @@ Java_com_rnllama_LlamaContext_saveSession(
  std::vector<llama_token> session_tokens = llama->embd;
  int default_size = session_tokens.size();
  int save_size = size > 0 && size <= default_size ? size : default_size;
- if (!llama_state_save_file(llama->ctx.get(), path_chars, session_tokens.data(), save_size)) {
+ if (!llama_state_save_file(llama->ctx, path_chars, session_tokens.data(), save_size)) {
  env->ReleaseStringUTFChars(path, path_chars);
  return -1;
  }
@@ -500,13 +531,13 @@ static inline jobject tokenProbsToMap(
  for (const auto &prob : probs) {
  auto probsForToken = createWritableArray(env);
  for (const auto &p : prob.probs) {
- std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx.get(), p.tok);
+ std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
  auto probResult = createWriteableMap(env);
  putString(env, probResult, "tok_str", tokStr.c_str());
  putDouble(env, probResult, "prob", p.prob);
  pushMap(env, probsForToken, probResult);
  }
- std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx.get(), prob.tok);
+ std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
  auto tokenResult = createWriteableMap(env);
  putString(env, tokenResult, "content", tokStr.c_str());
  putArray(env, tokenResult, "probs", probsForToken);
@@ -533,7 +564,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
  jfloat mirostat,
  jfloat mirostat_tau,
  jfloat mirostat_eta,
- jboolean penalize_nl,
  jint top_k,
  jfloat top_p,
  jfloat min_p,
@@ -546,7 +576,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  jobjectArray logit_bias,
  jfloat dry_multiplier,
  jfloat dry_base,
- jint dry_allowed_length,
+ jint dry_allowed_length,
  jint dry_penalty_last_n,
  jobjectArray dry_sequence_breakers,
  jobject partial_completion_callback
@@ -556,7 +586,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
  llama->rewind();
 
- //llama_reset_timings(llama->ctx.get());
+ //llama_reset_timings(llama->ctx);
 
  llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
  llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;
@@ -578,7 +608,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
  sparams.mirostat = mirostat;
  sparams.mirostat_tau = mirostat_tau;
  sparams.mirostat_eta = mirostat_eta;
- // sparams.penalize_nl = penalize_nl;
  sparams.top_k = top_k;
  sparams.top_p = top_p;
  sparams.min_p = min_p;
@@ -594,7 +623,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
  sparams.logit_bias.clear();
  if (ignore_eos) {
- sparams.logit_bias[llama_token_eos(llama->model.get())].bias = -INFINITY;
+ sparams.logit_bias[llama_vocab_eos(llama_model_get_vocab(llama->model))].bias = -INFINITY;
  }
 
  // dry break seq
@@ -613,7 +642,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
 
  // logit bias
- const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx.get()));
+ const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama->model));
  jsize logit_bias_len = env->GetArrayLength(logit_bias);
 
  for (jsize i = 0; i < logit_bias_len; i++) {
@@ -660,7 +689,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  if (token_with_probs.tok == -1 || llama->incomplete) {
  continue;
  }
- const std::string token_text = common_token_to_piece(llama->ctx.get(), token_with_probs.tok);
+ const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);
 
  size_t pos = std::min(sent_count, llama->generated_text.size());
 
@@ -695,7 +724,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  putString(env, tokenResult, "token", to_send.c_str());
 
  if (llama->params.sampling.n_probs > 0) {
- const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx.get(), to_send, false);
+ const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
  size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
  size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
  if (probs_pos < probs_stop_pos) {
@@ -712,7 +741,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  }
  }
 
- llama_perf_context_print(llama->ctx.get());
+ llama_perf_context_print(llama->ctx);
  llama->is_predicting = false;
 
  auto result = createWriteableMap(env);
@@ -727,7 +756,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  putString(env, result, "stopping_word", llama->stopping_word.c_str());
  putInt(env, result, "tokens_cached", llama->n_past);
 
- const auto timings_token = llama_perf_context(llama -> ctx.get());
+ const auto timings_token = llama_perf_context(llama -> ctx);
 
  auto timingsResult = createWriteableMap(env);
  putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
@@ -771,7 +800,7 @@ Java_com_rnllama_LlamaContext_tokenize(
  const char *text_chars = env->GetStringUTFChars(text, nullptr);
 
  const std::vector<llama_token> toks = common_tokenize(
- llama->ctx.get(),
+ llama->ctx,
  text_chars,
  false
  );
@@ -798,7 +827,7 @@ Java_com_rnllama_LlamaContext_detokenize(
  toks.push_back(tokens_ptr[i]);
  }
 
- auto text = rnllama::tokens_to_str(llama->ctx.get(), toks.cbegin(), toks.cend());
+ auto text = rnllama::tokens_to_str(llama->ctx, toks.cbegin(), toks.cend());
 
  env->ReleaseIntArrayElements(tokens, tokens_ptr, 0);
 
@@ -835,7 +864,7 @@ Java_com_rnllama_LlamaContext_embedding(
 
  llama->rewind();
 
- llama_perf_context_reset(llama->ctx.get());
+ llama_perf_context_reset(llama->ctx);
 
  llama->params.prompt = text_chars;
 
@@ -861,7 +890,7 @@ Java_com_rnllama_LlamaContext_embedding(
 
  auto promptTokens = createWritableArray(env);
  for (const auto &tok : llama->embd) {
- pushString(env, promptTokens, common_token_to_piece(llama->ctx.get(), tok).c_str());
+ pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
  }
  putArray(env, result, "prompt_tokens", promptTokens);
 
@@ -885,23 +914,64 @@ Java_com_rnllama_LlamaContext_bench(
  return env->NewStringUTF(result.c_str());
  }
 
+ JNIEXPORT jint JNICALL
+ Java_com_rnllama_LlamaContext_applyLoraAdapters(
+ JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
+ UNUSED(thiz);
+ auto llama = context_map[(long) context_ptr];
+
+ // lora_adapters: ReadableArray<ReadableMap>
+ std::vector<common_adapter_lora_info> lora_adapters;
+ int lora_adapters_size = readablearray::size(env, loraAdapters);
+ for (int i = 0; i < lora_adapters_size; i++) {
+ jobject lora_adapter = readablearray::getMap(env, loraAdapters, i);
+ jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
+ if (path != nullptr) {
+ const char *path_chars = env->GetStringUTFChars(path, nullptr);
+ env->ReleaseStringUTFChars(path, path_chars);
+ float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
+ common_adapter_lora_info la;
+ la.path = path_chars;
+ la.scale = scaled;
+ lora_adapters.push_back(la);
+ }
+ }
+ return llama->applyLoraAdapters(lora_adapters);
+ }
+
+ JNIEXPORT void JNICALL
+ Java_com_rnllama_LlamaContext_removeLoraAdapters(
+ JNIEnv *env, jobject thiz, jlong context_ptr) {
+ UNUSED(env);
+ UNUSED(thiz);
+ auto llama = context_map[(long) context_ptr];
+ llama->removeLoraAdapters();
+ }
+
+ JNIEXPORT jobject JNICALL
+ Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
+ JNIEnv *env, jobject thiz, jlong context_ptr) {
+ UNUSED(thiz);
+ auto llama = context_map[(long) context_ptr];
+ auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
+ auto result = createWritableArray(env);
+ for (common_adapter_lora_info &la : loaded_lora_adapters) {
+ auto map = createWriteableMap(env);
+ putString(env, map, "path", la.path.c_str());
+ putDouble(env, map, "scaled", la.scale);
+ pushMap(env, result, map);
+ }
+ return result;
+ }
+
  JNIEXPORT void JNICALL
  Java_com_rnllama_LlamaContext_freeContext(
  JNIEnv *env, jobject thiz, jlong context_ptr) {
  UNUSED(env);
  UNUSED(thiz);
  auto llama = context_map[(long) context_ptr];
- if (llama->model.get()) {
- llama_model_free(llama->model.get());
- }
- if (llama->ctx.get()) {
- llama_free(llama->ctx.get());
- }
- /*if (llama->ctx.get()-> != nullptr)
- {
- common_sampler_free(llama->ctx.get() -> _sampling);
- }*/
- context_map.erase((long) llama->ctx.get());
+ context_map.erase((long) llama->ctx);
+ delete llama;
  }
 
 
package/android/src/newarch/java/com/rnllama/RNLlamaModule.java

@@ -103,6 +103,21 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
  rnllama.bench(id, pp, tg, pl, nr, promise);
  }
 
+ @ReactMethod
+ public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
+ rnllama.applyLoraAdapters(id, loraAdapters, promise);
+ }
+
+ @ReactMethod
+ public void removeLoraAdapters(double id, final Promise promise) {
+ rnllama.removeLoraAdapters(id, promise);
+ }
+
+ @ReactMethod
+ public void getLoadedLoraAdapters(double id, final Promise promise) {
+ rnllama.getLoadedLoraAdapters(id, promise);
+ }
+
  @ReactMethod
  public void releaseContext(double id, Promise promise) {
  rnllama.releaseContext(id, promise);
package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java

@@ -104,6 +104,21 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
  rnllama.bench(id, pp, tg, pl, nr, promise);
  }
 
+ @ReactMethod
+ public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
+ rnllama.applyLoraAdapters(id, loraAdapters, promise);
+ }
+
+ @ReactMethod
+ public void removeLoraAdapters(double id, final Promise promise) {
+ rnllama.removeLoraAdapters(id, promise);
+ }
+
+ @ReactMethod
+ public void getLoadedLoraAdapters(double id, final Promise promise) {
+ rnllama.getLoadedLoraAdapters(id, promise);
+ }
+
  @ReactMethod
  public void releaseContext(double id, Promise promise) {
  rnllama.releaseContext(id, promise);
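
The three bridge methods added to both module variants above (applyLoraAdapters, removeLoraAdapters, getLoadedLoraAdapters) give JS callers runtime control over LoRA adapters. Below is a hedged TypeScript sketch of how a consumer might use them, assuming the LlamaContext wrapper in package/src/index.ts exposes them under the same names; that file changes in this release but is not shown in this excerpt.

    // Hypothetical usage sketch — method names come from the native bridge methods
    // in this diff; verify the actual wrapper API in package/src/index.ts.
    import type { LlamaContext } from 'cui-llama.rn'

    export async function swapAdapter(context: LlamaContext, adapterPath: string) {
      // Drop whatever adapters are currently applied to this context.
      await context.removeLoraAdapters()

      // Apply a new adapter; entries mirror the { path, scaled } maps read on the native side.
      await context.applyLoraAdapters([{ path: adapterPath, scaled: 1.0 }])

      // Returns the currently loaded adapters as { path, scaled } objects.
      return context.getLoadedLoraAdapters()
    }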