cui-llama.rn 1.2.4 → 1.3.0

Files changed (70)
  1. package/README.md +3 -4
  2. package/android/src/main/CMakeLists.txt +21 -5
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -30
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +222 -36
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/common.cpp +1682 -2122
  9. package/cpp/common.h +600 -594
  10. package/cpp/ggml-aarch64.c +129 -3209
  11. package/cpp/ggml-aarch64.h +19 -39
  12. package/cpp/ggml-alloc.c +1040 -1040
  13. package/cpp/ggml-alloc.h +76 -76
  14. package/cpp/ggml-backend-impl.h +216 -227
  15. package/cpp/ggml-backend-reg.cpp +195 -0
  16. package/cpp/ggml-backend.cpp +1997 -2625
  17. package/cpp/ggml-backend.h +328 -326
  18. package/cpp/ggml-common.h +1853 -1853
  19. package/cpp/ggml-cpp.h +38 -0
  20. package/cpp/ggml-cpu-aarch64.c +3560 -0
  21. package/cpp/ggml-cpu-aarch64.h +30 -0
  22. package/cpp/ggml-cpu-impl.h +371 -614
  23. package/cpp/ggml-cpu-quants.c +10822 -0
  24. package/cpp/ggml-cpu-quants.h +63 -0
  25. package/cpp/ggml-cpu.c +13975 -0
  26. package/cpp/ggml-cpu.cpp +663 -0
  27. package/cpp/ggml-cpu.h +177 -0
  28. package/cpp/ggml-impl.h +550 -209
  29. package/cpp/ggml-metal.h +66 -66
  30. package/cpp/ggml-metal.m +4294 -3819
  31. package/cpp/ggml-quants.c +5247 -15752
  32. package/cpp/ggml-quants.h +100 -147
  33. package/cpp/ggml-threading.cpp +12 -0
  34. package/cpp/ggml-threading.h +12 -0
  35. package/cpp/ggml.c +8180 -23464
  36. package/cpp/ggml.h +2411 -2562
  37. package/cpp/llama-grammar.cpp +1138 -1138
  38. package/cpp/llama-grammar.h +144 -144
  39. package/cpp/llama-impl.h +181 -181
  40. package/cpp/llama-sampling.cpp +2348 -2194
  41. package/cpp/llama-sampling.h +48 -30
  42. package/cpp/llama-vocab.cpp +1984 -1968
  43. package/cpp/llama-vocab.h +170 -165
  44. package/cpp/llama.cpp +22132 -21969
  45. package/cpp/llama.h +1253 -1253
  46. package/cpp/log.cpp +401 -401
  47. package/cpp/log.h +121 -121
  48. package/cpp/rn-llama.hpp +83 -19
  49. package/cpp/sampling.cpp +466 -458
  50. package/cpp/sgemm.cpp +1884 -1219
  51. package/ios/RNLlama.mm +43 -20
  52. package/ios/RNLlamaContext.h +9 -3
  53. package/ios/RNLlamaContext.mm +133 -33
  54. package/jest/mock.js +0 -1
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  56. package/lib/commonjs/index.js +52 -15
  57. package/lib/commonjs/index.js.map +1 -1
  58. package/lib/module/NativeRNLlama.js.map +1 -1
  59. package/lib/module/index.js +51 -15
  60. package/lib/module/index.js.map +1 -1
  61. package/lib/typescript/NativeRNLlama.d.ts +29 -6
  62. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  63. package/lib/typescript/index.d.ts +12 -5
  64. package/lib/typescript/index.d.ts.map +1 -1
  65. package/package.json +1 -1
  66. package/src/NativeRNLlama.ts +41 -7
  67. package/src/index.ts +82 -27
  68. package/cpp/json-schema-to-grammar.cpp +0 -1045
  69. package/cpp/json-schema-to-grammar.h +0 -8
  70. package/cpp/json.hpp +0 -24766
package/android/src/main/jni.cpp

@@ -4,13 +4,15 @@
 #include <android/log.h>
 #include <cstdlib>
 #include <ctime>
+#include <ctime>
 #include <sys/sysinfo.h>
 #include <string>
 #include <thread>
 #include <unordered_map>
 #include "llama.h"
-#include "rn-llama.hpp"
+#include "llama-impl.h"
 #include "ggml.h"
+#include "rn-llama.hpp"
 
 #define UNUSED(x) (void)(x)
 #define TAG "RNLLAMA_ANDROID_JNI"
@@ -22,6 +24,13 @@ static inline int min(int a, int b) {
     return (a < b) ? a : b;
 }
 
+static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
+    if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+}
+
 extern "C" {
 
 // Method to create WritableMap
@@ -107,6 +116,15 @@ static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
     env->CallVoidMethod(arr, pushDoubleMethod, value);
 }
 
+// Method to push string into WritableArray
+static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
+    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
+    jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
+
+    jstring jValue = env->NewStringUTF(value);
+    env->CallVoidMethod(arr, pushStringMethod, jValue);
+}
+
 // Method to push WritableMap into WritableArray
 static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
@@ -125,6 +143,77 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
     env->CallVoidMethod(map, putArrayMethod, jKey, value);
 }
 
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_modelInfo(
+    JNIEnv *env,
+    jobject thiz,
+    jstring model_path_str,
+    jobjectArray skip
+) {
+    UNUSED(thiz);
+
+    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
+
+    std::vector<std::string> skip_vec;
+    int skip_len = env->GetArrayLength(skip);
+    for (int i = 0; i < skip_len; i++) {
+        jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
+        const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
+        skip_vec.push_back(skip_chars);
+        env->ReleaseStringUTFChars(skip_str, skip_chars);
+    }
+
+    struct lm_gguf_init_params params = {
+        /*.no_alloc = */ false,
+        /*.ctx      = */ NULL,
+    };
+    struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);
+
+    if (!ctx) {
+        LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
+        return nullptr;
+    }
+
+    auto info = createWriteableMap(env);
+    putInt(env, info, "version", lm_gguf_get_version(ctx));
+    putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
+    putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
+    {
+        const int n_kv = lm_gguf_get_n_kv(ctx);
+
+        for (int i = 0; i < n_kv; ++i) {
+            const char * key = lm_gguf_get_key(ctx, i);
+
+            bool skipped = false;
+            if (skip_len > 0) {
+                for (int j = 0; j < skip_len; j++) {
+                    if (skip_vec[j] == key) {
+                        skipped = true;
+                        break;
+                    }
+                }
+            }
+
+            if (skipped) {
+                continue;
+            }
+
+            const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
+            putString(env, info, key, value.c_str());
+        }
+    }
+
+    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+    lm_gguf_free(ctx);
+
+    return reinterpret_cast<jobject>(info);
+}
+
+struct callback_context {
+    JNIEnv *env;
+    rnllama::llama_rn_context *llama;
+    jobject callback;
+};
 
 std::unordered_map<long, rnllama::llama_rn_context *> context_map;
 
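The new modelInfo entry point opens the GGUF file directly (no llama context needed), copies the header fields, and emits every key/value pair except those named in the skip list. Below is a minimal sketch of the Java-side declaration and call this implies; the corresponding LlamaContext.java hunk is not shown in this diff, so the static modifier, the example path, and the skip key are assumptions:

    package com.rnllama;

    import com.facebook.react.bridge.WritableMap;

    public class LlamaContext {
        // Hypothetical declaration matching the JNI symbol
        // Java_com_rnllama_LlamaContext_modelInfo: (String, String[]) -> WritableMap.
        protected static native WritableMap modelInfo(String modelPath, String[] skip);

        static WritableMap readGgufHeader() {
            // Illustrative call: fetch header metadata but omit the bulky
            // token list ("tokenizer.ggml.tokens" is an example key).
            return modelInfo(
                "/data/local/tmp/model.gguf",             // hypothetical path
                new String[] { "tokenizer.ggml.tokens" }  // keys to skip
            );
            // The returned map carries "version", "alignment", "data_offset"
            // plus the remaining GGUF key/value pairs stringified by
            // rnllama::lm_gguf_kv_to_str.
        }
    }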
@@ -141,10 +230,14 @@ Java_com_rnllama_LlamaContext_initContext(
     jobject thiz,
     jstring model_path_str,
     jboolean embedding,
+    jint embd_normalize,
     jint n_ctx,
     jint n_batch,
     jint n_threads,
     jint n_gpu_layers, // TODO: Support this
+    jboolean flash_attn,
+    jstring cache_type_k,
+    jstring cache_type_v,
     jboolean use_mlock,
     jboolean use_mmap,
     jboolean vocab_only,
@@ -152,7 +245,8 @@ Java_com_rnllama_LlamaContext_initContext(
     jfloat lora_scaled,
     jfloat rope_freq_base,
     jfloat rope_freq_scale,
-    jobject javaLlamaContext
+    jint pooling_type,
+    jobject load_progress_callback
 ) {
     UNUSED(thiz);
 
@@ -166,64 +260,109 @@ Java_com_rnllama_LlamaContext_initContext(
     const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
     defaultParams.model = model_path_chars;
 
-    defaultParams.embedding = embedding;
-
     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
 
+    if (pooling_type != -1) {
+        defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
+    }
+
+    defaultParams.embedding = embedding;
+    if (embd_normalize != -1) {
+        defaultParams.embd_normalize = embd_normalize;
+    }
+    if (embedding) {
+        // For non-causal models, batch size must be equal to ubatch size
+        defaultParams.n_ubatch = defaultParams.n_batch;
+    }
+
     int max_threads = std::thread::hardware_concurrency();
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
     defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
     defaultParams.n_gpu_layers = n_gpu_layers;
-
+    defaultParams.flash_attn = flash_attn;
+
+    const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
+    const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
+    defaultParams.cache_type_k = cache_type_k_chars;
+    defaultParams.cache_type_v = cache_type_v_chars;
+
     defaultParams.use_mlock = use_mlock;
     defaultParams.use_mmap = use_mmap;
 
     const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
     if (lora_chars != nullptr && lora_chars[0] != '\0') {
         defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
-        defaultParams.use_mmap = false;
     }
 
     defaultParams.rope_freq_base = rope_freq_base;
     defaultParams.rope_freq_scale = rope_freq_scale;
 
-    // progress callback when loading
-    jclass llamaContextClass = env->GetObjectClass(javaLlamaContext);
-    jmethodID sendProgressMethod = env->GetMethodID(llamaContextClass, "emitModelProgressUpdate", "(I)V");
-
-    CallbackContext callbackctx = {env, javaLlamaContext, sendProgressMethod, 0};
-
-    defaultParams.progress_callback_user_data = &callbackctx;
-    defaultParams.progress_callback = [](float progress, void * ctx) {
-        unsigned percentage = (unsigned) (100 * progress);
-        CallbackContext * cbctx = static_cast<CallbackContext*>(ctx);
-        // reduce call frequency by only calling method when value changes
-        if (percentage <= cbctx->current) return true;
-        cbctx->current = percentage;
-        cbctx->env->CallVoidMethod(cbctx->thiz, cbctx->sendProgressMethod, percentage);
-        return true;
-    };
-
-
     auto llama = new rnllama::llama_rn_context();
+    llama->is_load_interrupted = false;
+    llama->loading_progress = 0;
+
+    if (load_progress_callback != nullptr) {
+        defaultParams.progress_callback = [](float progress, void * user_data) {
+            callback_context *cb_ctx = (callback_context *)user_data;
+            JNIEnv *env = cb_ctx->env;
+            auto llama = cb_ctx->llama;
+            jobject callback = cb_ctx->callback;
+            int percentage = (int) (100 * progress);
+            if (percentage > llama->loading_progress) {
+                llama->loading_progress = percentage;
+                jclass callback_class = env->GetObjectClass(callback);
+                jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
+                env->CallVoidMethod(callback, onLoadProgress, percentage);
+            }
+            return !llama->is_load_interrupted;
+        };
+
+        callback_context *cb_ctx = new callback_context;
+        cb_ctx->env = env;
+        cb_ctx->llama = llama;
+        cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
+        defaultParams.progress_callback_user_data = cb_ctx;
+    }
+
     bool is_model_loaded = llama->loadModel(defaultParams);
 
+    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+    env->ReleaseStringUTFChars(lora_str, lora_chars);
+    env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
+    env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
+
     LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
     if (is_model_loaded) {
+        if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
+            LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
+            llama_free(llama->ctx);
+            return -1;
+        }
         context_map[(long) llama->ctx] = llama;
     } else {
         llama_free(llama->ctx);
     }
 
-    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
-    env->ReleaseStringUTFChars(lora_str, lora_chars);
-
     return reinterpret_cast<jlong>(llama->ctx);
 }
 
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_interruptLoad(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    if (llama) {
+        llama->is_load_interrupted = true;
+    }
+}
+
 JNIEXPORT jobject JNICALL
 Java_com_rnllama_LlamaContext_loadModelDetails(
     JNIEnv *env,
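initContext now threads an optional Java callback through a heap-allocated callback_context, reports loading progress only when the integer percentage increases, and keeps loading for as long as the lambda returns true; interruptLoad flips is_load_interrupted so the next progress tick returns false and aborts the load. The GetMethodID call pins the Java contract to onLoadProgress with signature (I)V. A sketch of the callback type this implies, assuming the interface name (only the method name and signature are confirmed by the hunk):

    package com.rnllama;

    // Hypothetical interface name; the JNI lambda above resolves
    // "onLoadProgress" with signature "(I)V" on whatever object is passed in.
    public interface LoadProgressCallback {
        // Receives 0..100; invoked only when the percentage increases,
        // which throttles JNI calls during a long model load.
        void onLoadProgress(int progress);
    }

Note that nothing in this hunk frees cb_ctx or deletes the global ref on the callback; whether that happens elsewhere is not visible in this diff.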
@@ -397,14 +536,18 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jint top_k,
     jfloat top_p,
     jfloat min_p,
-    jfloat xtc_t,
-    jfloat xtc_p,
-    jfloat tfs_z,
+    jfloat xtc_threshold,
+    jfloat xtc_probability,
     jfloat typical_p,
     jint seed,
     jobjectArray stop,
     jboolean ignore_eos,
     jobjectArray logit_bias,
+    jfloat dry_multiplier,
+    jfloat dry_base,
+    jint dry_allowed_length,
+    jint dry_penalty_last_n,
+    jobjectArray dry_sequence_breakers,
     jobject partial_completion_callback
 ) {
     UNUSED(thiz);
@@ -438,18 +581,37 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.top_k = top_k;
     sparams.top_p = top_p;
     sparams.min_p = min_p;
-    sparams.tfs_z = tfs_z;
     sparams.typ_p = typical_p;
     sparams.n_probs = n_probs;
     sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
-    sparams.xtc_t = xtc_t;
-    sparams.xtc_p = xtc_p;
+    sparams.xtc_threshold = xtc_threshold;
+    sparams.xtc_probability = xtc_probability;
+    sparams.dry_multiplier = dry_multiplier;
+    sparams.dry_base = dry_base;
+    sparams.dry_allowed_length = dry_allowed_length;
+    sparams.dry_penalty_last_n = dry_penalty_last_n;
 
     sparams.logit_bias.clear();
     if (ignore_eos) {
         sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
     }
 
+    // dry break seq
+
+    jint size = env->GetArrayLength(dry_sequence_breakers);
+    std::vector<std::string> dry_sequence_breakers_vector;
+
+    for (jint i = 0; i < size; i++) {
+        jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
+        const char *nativeString = env->GetStringUTFChars(javaString, 0);
+        dry_sequence_breakers_vector.push_back(std::string(nativeString));
+        env->ReleaseStringUTFChars(javaString, nativeString);
+        env->DeleteLocalRef(javaString);
+    }
+
+    sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
+
+    // logit bias
     const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
     jsize logit_bias_len = env->GetArrayLength(logit_bias);
 
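The DRY (Don't Repeat Yourself) sampler fields pass straight through into common's sampling params, and the sequence breakers arrive as a Java String[] converted element by element into a std::vector<std::string>. The diff itself sets no defaults; for orientation only, the values below are the defaults these fields carry in upstream llama.cpp's common sampler params (stated from upstream, not from this package):

    // Upstream llama.cpp defaults for the DRY sampler fields (orientation only).
    final class DryDefaults {
        static final float MULTIPLIER = 0.0f;   // 0.0f disables DRY entirely
        static final float BASE = 1.75f;        // exponential base of the penalty
        static final int ALLOWED_LENGTH = 2;    // repeats up to this length are unpenalized
        static final int PENALTY_LAST_N = -1;   // -1 = scan the whole context
        static final String[] SEQUENCE_BREAKERS = { "\n", ":", "\"", "*" };
    }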
@@ -653,16 +815,27 @@ Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
 
 JNIEXPORT jobject JNICALL
 Java_com_rnllama_LlamaContext_embedding(
-    JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
+    JNIEnv *env, jobject thiz,
+    jlong context_ptr,
+    jstring text,
+    jint embd_normalize
+) {
     UNUSED(thiz);
     auto llama = context_map[(long) context_ptr];
 
+    common_params embdParams;
+    embdParams.embedding = true;
+    embdParams.embd_normalize = llama->params.embd_normalize;
+    if (embd_normalize != -1) {
+        embdParams.embd_normalize = embd_normalize;
+    }
+
     const char *text_chars = env->GetStringUTFChars(text, nullptr);
 
     llama->rewind();
 
     llama_perf_context_reset(llama->ctx);
-
+
     llama->params.prompt = text_chars;
 
     llama->params.n_predict = 0;
@@ -677,7 +850,7 @@ Java_com_rnllama_LlamaContext_embedding(
     llama->loadPrompt();
     llama->doCompletion();
 
-    std::vector<float> embedding = llama->getEmbedding();
+    std::vector<float> embedding = llama->getEmbedding(embdParams);
 
     auto embeddings = createWritableArray(env);
     for (const auto &val : embedding) {
@@ -685,6 +858,12 @@ Java_com_rnllama_LlamaContext_embedding(
     }
     putArray(env, result, "embedding", embeddings);
 
+    auto promptTokens = createWritableArray(env);
+    for (const auto &tok : llama->embd) {
+        pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
+    }
+    putArray(env, result, "prompt_tokens", promptTokens);
+
     env->ReleaseStringUTFChars(text, text_chars);
     return result;
 }
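embedding() gains a per-call embd_normalize override (-1 defers to the value the context was initialized with) and now returns the tokenized prompt pieces alongside the vector. A sketch of the Java-side declaration the widened JNI signature implies; the wrapper class, method placement, and example text are assumptions:

    package com.rnllama;

    import com.facebook.react.bridge.WritableMap;

    public class LlamaContext {
        // Hypothetical declaration matching (long, String, int) -> WritableMap.
        protected static native WritableMap embedding(long contextPtr, String text, int embdNormalize);

        static WritableMap embed(long ctxPtr) {
            // -1 keeps the normalization mode chosen at init time
            // (falls back to llama->params.embd_normalize in the JNI code).
            return embedding(ctxPtr, "hello world", -1);
            // The map holds "embedding" (one double pushed per float) and the
            // new "prompt_tokens" string array.
        }
    }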
@@ -724,4 +903,11 @@ Java_com_rnllama_LlamaContext_freeContext(
     context_map.erase((long) llama->ctx);
 }
 
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
+    UNUSED(env);
+    UNUSED(thiz);
+    llama_log_set(log_callback, NULL);
+}
+
 } // extern "C"
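logToAndroid registers the log_callback defined near the top of this file via llama_log_set, so llama.cpp/ggml output is forwarded to logcat under the RNLLAMA_ANDROID_JNI tag at a matching priority. A plausible Java-side hook; the library name "rnllama" and the static-initializer placement are assumptions:

    package com.rnllama;

    public class LlamaContext {
        static {
            System.loadLibrary("rnllama"); // assumed native library name
            logToAndroid();                // installs log_callback via llama_log_set
        }

        protected static native void logToAndroid();
    }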
package/android/src/newarch/java/com/rnllama/RNLlamaModule.java

@@ -39,8 +39,13 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
   }
 
   @ReactMethod
-  public void initContext(final ReadableMap params, final Promise promise) {
-    rnllama.initContext(params, promise);
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    rnllama.modelInfo(model, skip, promise);
+  }
+
+  @ReactMethod
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    rnllama.initContext(id, params, promise);
   }
 
   @ReactMethod
@@ -89,8 +94,8 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
   }
 
   @ReactMethod
-  public void embedding(double id, final String text, final Promise promise) {
-    rnllama.embedding(id, text, promise);
+  public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
+    rnllama.embedding(id, text, params, promise);
   }
 
   @ReactMethod
package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java

@@ -40,8 +40,13 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
   }
 
   @ReactMethod
-  public void initContext(final ReadableMap params, final Promise promise) {
-    rnllama.initContext(params, promise);
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    rnllama.modelInfo(model, skip, promise);
+  }
+
+  @ReactMethod
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    rnllama.initContext(id, params, promise);
   }
 
   @ReactMethod
@@ -90,8 +95,8 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
   }
 
   @ReactMethod
-  public void embedding(double id, final String text, final Promise promise) {
-    rnllama.embedding(id, text, promise);
+  public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
+    rnllama.embedding(id, text, params, promise);
   }
 
   @ReactMethod