cui-llama.rn 1.3.6 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. package/README.md +22 -1
  2. package/android/src/main/CMakeLists.txt +25 -26
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
  5. package/android/src/main/jni-utils.h +94 -0
  6. package/android/src/main/jni.cpp +132 -62
  7. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
  8. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
  9. package/cpp/common.cpp +1982 -1982
  10. package/cpp/common.h +665 -664
  11. package/cpp/ggml-cpu.c +14122 -14122
  12. package/cpp/ggml-cpu.cpp +627 -627
  13. package/cpp/ggml-metal-impl.h +288 -0
  14. package/cpp/ggml-opt.cpp +854 -0
  15. package/cpp/ggml-opt.h +216 -0
  16. package/cpp/llama-mmap.cpp +589 -589
  17. package/cpp/llama.cpp +12547 -12544
  18. package/cpp/rn-llama.hpp +117 -116
  19. package/cpp/sgemm.h +14 -14
  20. package/ios/RNLlama.mm +47 -0
  21. package/ios/RNLlamaContext.h +3 -1
  22. package/ios/RNLlamaContext.mm +71 -14
  23. package/jest/mock.js +15 -3
  24. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  25. package/lib/commonjs/index.js +33 -37
  26. package/lib/commonjs/index.js.map +1 -1
  27. package/lib/module/NativeRNLlama.js.map +1 -1
  28. package/lib/module/index.js +31 -35
  29. package/lib/module/index.js.map +1 -1
  30. package/lib/typescript/NativeRNLlama.d.ts +26 -6
  31. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  32. package/lib/typescript/index.d.ts +21 -36
  33. package/lib/typescript/index.d.ts.map +1 -1
  34. package/llama-rn.podspec +4 -18
  35. package/package.json +2 -3
  36. package/src/NativeRNLlama.ts +32 -13
  37. package/src/index.ts +52 -47
  38. package/cpp/llama.cpp.rej +0 -23
package/cpp/rn-llama.hpp CHANGED
@@ -5,65 +5,35 @@
  #include <iostream>
  #include "common.h"
  #include "ggml.h"
+ #include "gguf.h"
  #include "llama.h"
  #include "llama-impl.h"
  #include "sampling.h"
- #include "llama-cpp.h"
+ #if defined(__ANDROID__)
+ #include <android/log.h>
+ #endif

  namespace rnllama {

- static std::string lm_gguf_data_to_str(enum lm_gguf_type type, const void * data, int i) {
- switch (type) {
- case LM_GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]);
- case LM_GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]);
- case LM_GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]);
- case LM_GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]);
- case LM_GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]);
- case LM_GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]);
- case LM_GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]);
- case LM_GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]);
- case LM_GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]);
- case LM_GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]);
- case LM_GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false";
- default: return "unknown type: " + std::to_string(type);
- }
- }
-
- static std::string lm_gguf_kv_to_str(const struct lm_gguf_context * ctx_gguf, int i) {
- const enum lm_gguf_type type = lm_gguf_get_kv_type(ctx_gguf, i);
+ const std::vector<lm_ggml_type> kv_cache_types = {
+ LM_GGML_TYPE_F32,
+ LM_GGML_TYPE_F16,
+ LM_GGML_TYPE_BF16,
+ LM_GGML_TYPE_Q8_0,
+ LM_GGML_TYPE_Q4_0,
+ LM_GGML_TYPE_Q4_1,
+ LM_GGML_TYPE_IQ4_NL,
+ LM_GGML_TYPE_Q5_0,
+ LM_GGML_TYPE_Q5_1,
+ };

- switch (type) {
- case LM_GGUF_TYPE_STRING:
- return lm_gguf_get_val_str(ctx_gguf, i);
- case LM_GGUF_TYPE_ARRAY:
- {
- const enum lm_gguf_type arr_type = lm_gguf_get_arr_type(ctx_gguf, i);
- int arr_n = lm_gguf_get_arr_n(ctx_gguf, i);
- const void * data = lm_gguf_get_arr_data(ctx_gguf, i);
- std::stringstream ss;
- ss << "[";
- for (int j = 0; j < arr_n; j++) {
- if (arr_type == LM_GGUF_TYPE_STRING) {
- std::string val = lm_gguf_get_arr_str(ctx_gguf, i, j);
- // escape quotes
- replace_all(val, "\\", "\\\\");
- replace_all(val, "\"", "\\\"");
- ss << '"' << val << '"';
- } else if (arr_type == LM_GGUF_TYPE_ARRAY) {
- ss << "???";
- } else {
- ss << lm_gguf_data_to_str(arr_type, data, j);
- }
- if (j < arr_n - 1) {
- ss << ", ";
- }
- }
- ss << "]";
- return ss.str();
- }
- default:
- return lm_gguf_data_to_str(type, lm_gguf_get_val_data(ctx_gguf, i), 0);
+ static lm_ggml_type kv_cache_type_from_str(const std::string & s) {
+ for (const auto & type : kv_cache_types) {
+ if (lm_ggml_type_name(type) == s) {
+ return type;
+ }
  }
+ throw std::runtime_error("Unsupported cache type: " + s);
  }

  static void llama_batch_clear(llama_batch *batch) {
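The new `kv_cache_types` whitelist and `kv_cache_type_from_str` lookup replace the unused GGUF pretty-printing helpers. A minimal usage sketch (not part of the diff) of how a binding layer might turn cache-type strings into `lm_ggml_type` values; the `cache_type_k`/`cache_type_v` field names on `common_params` are assumptions:

```cpp
// Hypothetical caller, assuming rn-llama.hpp is included and that
// common_params carries cache_type_k / cache_type_v as lm_ggml_type.
#include <string>
#include "rn-llama.hpp"

static void set_kv_cache_types(common_params & params,
                               const std::string & k_type,   // e.g. "q8_0"
                               const std::string & v_type) { // e.g. "f16"
    // kv_cache_type_from_str throws std::runtime_error for any string
    // outside the kv_cache_types list above.
    params.cache_type_k = rnllama::kv_cache_type_from_str(k_type);
    params.cache_type_v = rnllama::kv_cache_type_from_str(v_type);
}
```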
@@ -87,16 +57,32 @@ static void llama_batch_add(llama_batch *batch, llama_token id, llama_pos pos, s
  static void log(const char *level, const char *function, int line,
  const char *format, ...)
  {
- printf("[%s] %s:%d ", level, function, line);
-
  va_list args;
- va_start(args, format);
- vprintf(format, args);
- va_end(args);
-
- printf("\n");
+ #if defined(__ANDROID__)
+ char prefix[256];
+ snprintf(prefix, sizeof(prefix), "%s:%d %s", function, line, format);
+
+ va_start(args, format);
+ android_LogPriority priority;
+ if (strcmp(level, "ERROR") == 0) {
+ priority = ANDROID_LOG_ERROR;
+ } else if (strcmp(level, "WARNING") == 0) {
+ priority = ANDROID_LOG_WARN;
+ } else if (strcmp(level, "INFO") == 0) {
+ priority = ANDROID_LOG_INFO;
+ } else {
+ priority = ANDROID_LOG_DEBUG;
+ }
+ __android_log_vprint(priority, "RNLlama", prefix, args);
+ va_end(args);
+ #else
+ printf("[%s] %s:%d ", level, function, line);
+ va_start(args, format);
+ vprintf(format, args);
+ va_end(args);
+ printf("\n");
+ #endif
  }
-
  static bool rnllama_verbose = false;

  #if RNLLAMA_VERBOSE != 1
@@ -215,11 +201,13 @@ struct llama_rn_context

  common_params params;

- llama_model_ptr model = nullptr;
+ common_init_result llama_init;
+
+ llama_model *model = nullptr;
  float loading_progress = 0;
  bool is_load_interrupted = false;

- llama_context_ptr ctx = nullptr;
+ llama_context *ctx = nullptr;
  common_sampler *ctx_sampling = nullptr;

  int n_ctx;
@@ -231,18 +219,10 @@ struct llama_rn_context
  std::string stopping_word;
  bool incomplete = false;

+ std::vector<common_lora_adapter_info> lora;
+
  ~llama_rn_context()
  {
- if (ctx)
- {
- llama_free(ctx.get());
- ctx = nullptr;
- }
- if (model)
- {
- llama_model_free(model.get());
- model = nullptr;
- }
  if (ctx_sampling != nullptr)
  {
  common_sampler_free(ctx_sampling);
@@ -274,37 +254,33 @@ struct llama_rn_context
  if (ctx_sampling != nullptr) {
  common_sampler_free(ctx_sampling);
  }
- ctx_sampling = common_sampler_init(model.get(), params.sampling);
+ ctx_sampling = common_sampler_init(model, params.sampling);
  return ctx_sampling != nullptr;
  }

  bool loadModel(common_params &params_)
  {
  params = params_;
- common_init_result result = common_init_from_params(params);
- model = std::move(result.model);
- ctx = std::move(result.context);
+ llama_init = common_init_from_params(params);
+ model = llama_init.model.get();
+ ctx = llama_init.context.get();
  if (model == nullptr)
  {
  LOG_ERROR("unable to load model: %s", params_.model.c_str());
  return false;
  }
- LOG_VERBOSE("getting n_ctx");
- n_ctx = llama_n_ctx(ctx.get());
+ n_ctx = llama_n_ctx(ctx);
+
+ // We can uncomment for debugging or after this fix: https://github.com/ggerganov/llama.cpp/pull/11101
+ // LOG_INFO("%s\n", common_params_get_system_info(params).c_str());
+
  return true;
  }

  bool validateModelChatTemplate() const {
- std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
- std::string template_key = "tokenizer.chat_template";
- int32_t res = llama_model_meta_val_str(model.get(), template_key.c_str(), model_template.data(), model_template.size());
- if (res >= 0) {
- llama_chat_message chat[] = {{"user", "test"}};
- std::string tmpl = std::string(model_template.data(), model_template.size());
- int32_t chat_res = llama_chat_apply_template(model.get(), tmpl.c_str(), chat, 1, true, nullptr, 0);
- return chat_res > 0;
- }
- return res > 0;
+ llama_chat_message chat[] = {{"user", "test"}};
+ int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
+ return chat_res > 0;
  }

  void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
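Note the ownership change in `loadModel` above: `llama_init` (a `common_init_result`) now holds the owning smart pointers for the whole lifetime of `llama_rn_context`, while `model` and `ctx` are plain borrowed pointers, which is why the explicit `llama_free`/`llama_model_free` calls were dropped from the destructor earlier in this file. A minimal sketch of that pattern with placeholder types (not llama.cpp API):

```cpp
// Placeholder types; only the lifetime pattern mirrors the diff.
#include <memory>

struct Thing { int value = 0; };

struct Owner {                      // stands in for common_init_result
    std::unique_ptr<Thing> thing;
};

struct Holder {                     // stands in for llama_rn_context
    Owner owner;                    // keeps the resource alive
    Thing * thing = nullptr;        // borrowed view, never freed by Holder

    void load() {
        owner.thing = std::make_unique<Thing>();
        thing = owner.thing.get();  // mirrors model = llama_init.model.get()
    }
    // no manual delete needed: owner's unique_ptr releases the resource
};
```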
@@ -331,7 +307,7 @@ struct llama_rn_context

  void loadPrompt()
  {
- std::vector<llama_token> prompt_tokens = ::common_tokenize(model.get(), params.prompt, true, true);
+ std::vector<llama_token> prompt_tokens = ::common_tokenize(model, params.prompt, true, true);
  num_prompt_tokens = prompt_tokens.size();

  // LOG tokens
@@ -359,7 +335,7 @@ struct llama_rn_context

  // do Context Shift , may be buggy! TODO: Verify functionality
  if(!params.embedding){
- purge_missing_tokens(ctx.get(), embd, prompt_tokens, params.n_predict, params.n_ctx);
+ purge_missing_tokens(ctx, embd, prompt_tokens, params.n_predict, params.n_ctx);
  }

  // push the prompt into the sampling context (do not apply grammar)
@@ -380,7 +356,7 @@ struct llama_rn_context
  }

  // since #3228 we now have to manually manage the KV cache
- llama_kv_cache_seq_rm(ctx.get(), 0, n_past, -1);
+ llama_kv_cache_seq_rm(ctx, 0, n_past, -1);

  LOG_VERBOSE("prompt ingested, n_past: %d, cached: %s, to_eval: %s",
  n_past,
@@ -395,7 +371,7 @@ struct llama_rn_context
  {
  // number of tokens to keep when resetting context
  n_remain = params.n_predict;
- llama_perf_context_reset(ctx.get());
+ llama_perf_context_reset(ctx);
  is_predicting = true;
  }

@@ -411,8 +387,8 @@ struct llama_rn_context
  const int n_left = n_past - params.n_keep - 1;
  const int n_discard = n_left/2;

- llama_kv_cache_seq_rm (ctx.get(), 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
- llama_kv_cache_seq_add(ctx.get(), 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+ llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
+ llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);

  for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
  {
@@ -438,14 +414,14 @@ struct llama_rn_context
  {
  n_eval = params.n_batch;
  }
- if (llama_decode(ctx.get(), llama_batch_get_one(&embd[n_past], n_eval)))
+ if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval)))
  {

  LOG_ERROR("failed to eval, n_eval: %d, n_past: %d, n_threads: %d, embd: %s",
  n_eval,
  n_past,
  params.cpuparams.n_threads,
- tokens_to_str(ctx.get(), embd.cbegin() + n_past, embd.cend()).c_str()
+ tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()).c_str()
  );
  has_next_token = false;
  return result;
@@ -463,23 +439,23 @@ struct llama_rn_context
  if (params.n_predict == 0)
  {
  has_next_token = false;
- result.tok = llama_token_eos(model.get());
+ result.tok = llama_token_eos(model);
  return result;
  }

  {
  // out of user input, sample next token
  std::vector<llama_token_data> candidates;
- candidates.reserve(llama_n_vocab(model.get()));
+ candidates.reserve(llama_n_vocab(model));

- result.tok = common_sampler_sample(ctx_sampling, ctx.get(), -1);
+ result.tok = common_sampler_sample(ctx_sampling, ctx, -1);

  llama_token_data_array cur_p = *common_sampler_get_candidates(ctx_sampling);

  const int32_t n_probs = params.sampling.n_probs;

  // deprecated
- /*if (params.sparams.temp <= 0 && n_probs > 0)
+ /*if (params.sampling.temp <= 0 && n_probs > 0)
  {
  // For llama_sample_token_greedy we need to sort candidates
  llama_sampler_init_softmax();
@@ -503,7 +479,7 @@ struct llama_rn_context
  // decrement remaining sampling budget
  --n_remain;

- if (!embd.empty() && embd.back() == llama_token_eos(model.get()))
+ if (!embd.empty() && embd.back() == llama_token_eos(model))
  {
  // stopping_word = llama_token_to_piece(ctx, embd.back());
  has_next_token = false;
@@ -552,7 +528,7 @@ struct llama_rn_context
  {
  const completion_token_output token_with_probs = nextToken();

- const std::string token_text = token_with_probs.tok == -1 ? "" : common_token_to_piece(ctx.get(), token_with_probs.tok);
+ const std::string token_text = token_with_probs.tok == -1 ? "" : common_token_to_piece(ctx, token_with_probs.tok);
  generated_text += token_text;

  if (params.sampling.n_probs > 0)
@@ -608,7 +584,7 @@ struct llama_rn_context

  std::vector<float> getEmbedding(common_params &embd_params)
  {
- static const int n_embd = llama_n_embd(llama_get_model(ctx.get()));
+ static const int n_embd = llama_n_embd(llama_get_model(ctx));
  if (!embd_params.embedding)
  {
  LOG_WARNING("embedding disabled, embedding: %s", embd_params.embedding);
@@ -616,12 +592,12 @@ struct llama_rn_context
  }
  float *data;

- const enum llama_pooling_type pooling_type = llama_pooling_type(ctx.get());
+ const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
  printf("pooling_type: %d\n", pooling_type);
  if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
- data = llama_get_embeddings(ctx.get());
+ data = llama_get_embeddings(ctx);
  } else {
- data = llama_get_embeddings_seq(ctx.get(), 0);
+ data = llama_get_embeddings_seq(ctx, 0);
  }

  if (!data) {
@@ -649,7 +625,11 @@ struct llama_rn_context
  double tg_std = 0;

  // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
- llama_batch batch = llama_batch_init(512, 0, 1);
+ llama_batch batch = llama_batch_init(
+ std::min(pp, params.n_ubatch), // max n_tokens is limited by n_ubatch
+ 0, // No embeddings
+ 1 // Single sequence
+ );

  for (int i = 0; i < nr; i++)
  {
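The benchmark batch above is now allocated for `std::min(pp, params.n_ubatch)` tokens instead of a fixed 512, per the in-diff comment that the maximum `n_tokens` is limited by `n_ubatch`. A trivial sketch of the sizing with made-up numbers:

```cpp
// Illustrative only; pp and n_ubatch values here are hypothetical.
#include <algorithm>
#include <cstdio>

int main() {
    int pp       = 512;                     // requested prompt-processing size
    int n_ubatch = 128;                     // micro-batch size from params
    int n_alloc  = std::min(pp, n_ubatch);  // what llama_batch_init now receives
    std::printf("batch allocated for %d tokens\n", n_alloc); // prints 128
    return 0;
}
```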
@@ -663,15 +643,15 @@ struct llama_rn_context
  }
  batch.logits[batch.n_tokens - 1] = 1; // true

- llama_kv_cache_clear(ctx.get());
+ llama_kv_cache_clear(ctx);

  const int64_t t_pp_start = llama_time_us();
- if (llama_decode(ctx.get(), batch) != 0)
+ if (llama_decode(ctx, batch) != 0)
  {
  LOG_ERROR("llama_decode() failed during prompt", "");
  }
  const int64_t t_pp_end = llama_time_us();
- llama_kv_cache_clear(ctx.get());
+ llama_kv_cache_clear(ctx);

  if (is_interrupted) break;

@@ -686,7 +666,7 @@ struct llama_rn_context
  llama_batch_add(&batch, 0, i, {j}, true);
  }

- if (llama_decode(ctx.get(), batch) != 0)
+ if (llama_decode(ctx, batch) != 0)
  {
  LOG_ERROR("llama_decode() failed during text generation", "");
  }
@@ -695,7 +675,7 @@ struct llama_rn_context

  const int64_t t_tg_end = llama_time_us();

- llama_kv_cache_clear(ctx.get());
+ llama_kv_cache_clear(ctx);

  const double t_pp = (t_pp_end - t_pp_start) / 1000000.0;
  const double t_tg = (t_tg_end - t_tg_start) / 1000000.0;
@@ -721,14 +701,14 @@ struct llama_rn_context
  tg_std = 0;
  }

- if (is_interrupted) llama_kv_cache_clear(ctx.get());
+ if (is_interrupted) llama_kv_cache_clear(ctx);
  is_predicting = false;

  char model_desc[128];
- llama_model_desc(model.get(), model_desc, sizeof(model_desc));
+ llama_model_desc(model, model_desc, sizeof(model_desc));
  return std::string("[\"") + model_desc + std::string("\",") +
- std::to_string(llama_model_size(model.get())) + std::string(",") +
- std::to_string(llama_model_n_params(model.get())) + std::string(",") +
+ std::to_string(llama_model_size(model)) + std::string(",") +
+ std::to_string(llama_model_n_params(model)) + std::string(",") +
  std::to_string(pp_avg) + std::string(",") +
  std::to_string(pp_std) + std::string(",") +
  std::to_string(tg_avg) + std::string(",") +
@@ -736,7 +716,27 @@ struct llama_rn_context
  std::string("]");
  }

-
+ int applyLoraAdapters(std::vector<common_lora_adapter_info> lora) {
+ for (auto &la : lora) {
+ la.ptr = llama_lora_adapter_init(model, la.path.c_str());
+ if (la.ptr == nullptr) {
+ LOG_ERROR("failed to apply lora adapter '%s'\n", la.path.c_str());
+ return -1;
+ }
+ }
+ this->lora = lora;
+ common_lora_adapters_apply(ctx, lora);
+ return 0;
+ }
+
+ void removeLoraAdapters() {
+ this->lora.clear();
+ common_lora_adapters_apply(ctx, this->lora); // apply empty list
+ }
+
+ std::vector<common_lora_adapter_info> getLoadedLoraAdapters() {
+ return this->lora;
+ }
  // Context Shifting from KoboldCpp <https://github.com/LostRuins/koboldcpp>
  // Implementation obtained with special permission from @concedo

@@ -899,6 +899,7 @@ void purge_missing_tokens(llama_context * ctx, std::vector<int> &current_context
  }

  // End Context Shifting
+
  };

  }
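The `applyLoraAdapters` / `removeLoraAdapters` / `getLoadedLoraAdapters` members added above expose llama.cpp's LoRA hot-swapping to the JNI and Objective-C bridges. A hedged sketch of a native caller (the `scale` field of `common_lora_adapter_info` and the call order are assumptions, not shown in this diff):

```cpp
// Hypothetical caller; not part of the diff.
#include <string>
#include <vector>
#include "rn-llama.hpp"

void swap_adapter(rnllama::llama_rn_context & rn_ctx, const std::string & path) {
    std::vector<common_lora_adapter_info> adapters;
    common_lora_adapter_info la;
    la.path  = path;   // e.g. a .gguf LoRA file in app storage
    la.scale = 1.0f;   // assumed field; scaling factor for the adapter
    adapters.push_back(la);

    if (rn_ctx.applyLoraAdapters(adapters) != 0) {
        return;        // adapter failed to load; context keeps running without it
    }
    // ... run completions with the adapter applied ...
    rn_ctx.removeLoraAdapters();  // re-applies an empty adapter list
}
```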
package/cpp/sgemm.h CHANGED
@@ -1,14 +1,14 @@
- #pragma once
- #include <stdint.h>
- #include <stdbool.h>
- #ifdef __cplusplus
- extern "C" {
- #endif
-
- bool llamafile_sgemm(const struct lm_lm_ggml_compute_params * params, int64_t, int64_t, int64_t,
- const void *, int64_t, const void *, int64_t, void *, int64_t,
- int, int, int);
-
- #ifdef __cplusplus
- }
- #endif
+ #pragma once
+ #include <stdint.h>
+ #include <stdbool.h>
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t, int64_t, int64_t,
+ const void *, int64_t, const void *, int64_t, void *, int64_t,
+ int, int, int);
+
+ #ifdef __cplusplus
+ }
+ #endif
package/ios/RNLlama.mm CHANGED
@@ -271,6 +271,53 @@ RCT_EXPORT_METHOD(bench:(double)contextId
  }
  }

+ RCT_EXPORT_METHOD(applyLoraAdapters:(double)contextId
+ withLoraAdapters:(NSArray *)loraAdapters
+ withResolver:(RCTPromiseResolveBlock)resolve
+ withRejecter:(RCTPromiseRejectBlock)reject)
+ {
+ RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]];
+ if (context == nil) {
+ reject(@"llama_error", @"Context not found", nil);
+ return;
+ }
+ if ([context isPredicting]) {
+ reject(@"llama_error", @"Context is busy", nil);
+ return;
+ }
+ [context applyLoraAdapters:loraAdapters];
+ resolve(nil);
+ }
+
+ RCT_EXPORT_METHOD(removeLoraAdapters:(double)contextId
+ withResolver:(RCTPromiseResolveBlock)resolve
+ withRejecter:(RCTPromiseRejectBlock)reject)
+ {
+ RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]];
+ if (context == nil) {
+ reject(@"llama_error", @"Context not found", nil);
+ return;
+ }
+ if ([context isPredicting]) {
+ reject(@"llama_error", @"Context is busy", nil);
+ return;
+ }
+ [context removeLoraAdapters];
+ resolve(nil);
+ }
+
+ RCT_EXPORT_METHOD(getLoadedLoraAdapters:(double)contextId
+ withResolver:(RCTPromiseResolveBlock)resolve
+ withRejecter:(RCTPromiseRejectBlock)reject)
+ {
+ RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]];
+ if (context == nil) {
+ reject(@"llama_error", @"Context not found", nil);
+ return;
+ }
+ resolve([context getLoadedLoraAdapters]);
+ }
+
  RCT_EXPORT_METHOD(releaseContext:(double)contextId
  withResolver:(RCTPromiseResolveBlock)resolve
  withRejecter:(RCTPromiseRejectBlock)reject)
package/ios/RNLlamaContext.h CHANGED
@@ -33,7 +33,9 @@
  - (NSDictionary *)loadSession:(NSString *)path;
  - (int)saveSession:(NSString *)path size:(int)size;
  - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr;
-
+ - (void)applyLoraAdapters:(NSArray *)loraAdapters;
+ - (void)removeLoraAdapters;
+ - (NSArray *)getLoadedLoraAdapters;
  - (void)invalidate;

  @end