cui-llama.rn 1.2.6 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. package/README.md +3 -2
  2. package/android/src/main/CMakeLists.txt +20 -5
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +222 -34
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/common.cpp +1682 -2114
  9. package/cpp/common.h +600 -613
  10. package/cpp/ggml-aarch64.c +129 -3478
  11. package/cpp/ggml-aarch64.h +19 -39
  12. package/cpp/ggml-alloc.c +1040 -1040
  13. package/cpp/ggml-alloc.h +76 -76
  14. package/cpp/ggml-backend-impl.h +216 -216
  15. package/cpp/ggml-backend-reg.cpp +195 -0
  16. package/cpp/ggml-backend.cpp +1997 -2661
  17. package/cpp/ggml-backend.h +328 -314
  18. package/cpp/ggml-common.h +1853 -1853
  19. package/cpp/ggml-cpp.h +38 -38
  20. package/cpp/ggml-cpu-aarch64.c +3560 -0
  21. package/cpp/ggml-cpu-aarch64.h +30 -0
  22. package/cpp/ggml-cpu-impl.h +371 -614
  23. package/cpp/ggml-cpu-quants.c +10822 -0
  24. package/cpp/ggml-cpu-quants.h +63 -0
  25. package/cpp/ggml-cpu.c +13975 -13720
  26. package/cpp/ggml-cpu.cpp +663 -0
  27. package/cpp/ggml-cpu.h +177 -150
  28. package/cpp/ggml-impl.h +550 -296
  29. package/cpp/ggml-metal.h +66 -66
  30. package/cpp/ggml-metal.m +4294 -3933
  31. package/cpp/ggml-quants.c +5247 -15739
  32. package/cpp/ggml-quants.h +100 -147
  33. package/cpp/ggml-threading.cpp +12 -0
  34. package/cpp/ggml-threading.h +12 -0
  35. package/cpp/ggml.c +8180 -8390
  36. package/cpp/ggml.h +2411 -2441
  37. package/cpp/llama-grammar.cpp +1138 -1138
  38. package/cpp/llama-grammar.h +144 -144
  39. package/cpp/llama-impl.h +181 -181
  40. package/cpp/llama-sampling.cpp +2348 -2345
  41. package/cpp/llama-sampling.h +48 -48
  42. package/cpp/llama-vocab.cpp +1984 -1984
  43. package/cpp/llama-vocab.h +170 -170
  44. package/cpp/llama.cpp +22132 -22046
  45. package/cpp/llama.h +1253 -1255
  46. package/cpp/log.cpp +401 -401
  47. package/cpp/log.h +121 -121
  48. package/cpp/rn-llama.hpp +83 -19
  49. package/cpp/sampling.cpp +466 -466
  50. package/cpp/sgemm.cpp +1884 -1276
  51. package/ios/RNLlama.mm +43 -20
  52. package/ios/RNLlamaContext.h +9 -3
  53. package/ios/RNLlamaContext.mm +133 -33
  54. package/jest/mock.js +0 -1
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  56. package/lib/commonjs/index.js +52 -15
  57. package/lib/commonjs/index.js.map +1 -1
  58. package/lib/module/NativeRNLlama.js.map +1 -1
  59. package/lib/module/index.js +51 -15
  60. package/lib/module/index.js.map +1 -1
  61. package/lib/typescript/NativeRNLlama.d.ts +29 -5
  62. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  63. package/lib/typescript/index.d.ts +12 -5
  64. package/lib/typescript/index.d.ts.map +1 -1
  65. package/package.json +1 -1
  66. package/src/NativeRNLlama.ts +41 -6
  67. package/src/index.ts +82 -27
  68. package/cpp/json-schema-to-grammar.cpp +0 -1045
  69. package/cpp/json-schema-to-grammar.h +0 -8
  70. package/cpp/json.hpp +0 -24766
@@ -4,13 +4,15 @@
4
4
  #include <android/log.h>
5
5
  #include <cstdlib>
6
6
  #include <ctime>
7
+ #include <ctime>
7
8
  #include <sys/sysinfo.h>
8
9
  #include <string>
9
10
  #include <thread>
10
11
  #include <unordered_map>
11
12
  #include "llama.h"
12
- #include "rn-llama.hpp"
13
+ #include "llama-impl.h"
13
14
  #include "ggml.h"
15
+ #include "rn-llama.hpp"
14
16
 
15
17
  #define UNUSED(x) (void)(x)
16
18
  #define TAG "RNLLAMA_ANDROID_JNI"
@@ -22,6 +24,13 @@ static inline int min(int a, int b) {
22
24
  return (a < b) ? a : b;
23
25
  }
24
26
 
27
+ static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
28
+ if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
29
+ else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
30
+ else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
31
+ else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
32
+ }
33
+
25
34
  extern "C" {
26
35
 
27
36
  // Method to create WritableMap
@@ -107,6 +116,15 @@ static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
107
116
  env->CallVoidMethod(arr, pushDoubleMethod, value);
108
117
  }
109
118
 
119
+ // Method to push string into WritableArray
120
+ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
121
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
122
+ jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
123
+
124
+ jstring jValue = env->NewStringUTF(value);
125
+ env->CallVoidMethod(arr, pushStringMethod, jValue);
126
+ }
127
+
110
128
  // Method to push WritableMap into WritableArray
111
129
  static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
112
130
  jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
@@ -125,6 +143,77 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
125
143
  env->CallVoidMethod(map, putArrayMethod, jKey, value);
126
144
  }
127
145
 
146
+ JNIEXPORT jobject JNICALL
147
+ Java_com_rnllama_LlamaContext_modelInfo(
148
+ JNIEnv *env,
149
+ jobject thiz,
150
+ jstring model_path_str,
151
+ jobjectArray skip
152
+ ) {
153
+ UNUSED(thiz);
154
+
155
+ const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
156
+
157
+ std::vector<std::string> skip_vec;
158
+ int skip_len = env->GetArrayLength(skip);
159
+ for (int i = 0; i < skip_len; i++) {
160
+ jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
161
+ const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
162
+ skip_vec.push_back(skip_chars);
163
+ env->ReleaseStringUTFChars(skip_str, skip_chars);
164
+ }
165
+
166
+ struct lm_gguf_init_params params = {
167
+ /*.no_alloc = */ false,
168
+ /*.ctx = */ NULL,
169
+ };
170
+ struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);
171
+
172
+ if (!ctx) {
173
+ LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
174
+ return nullptr;
175
+ }
176
+
177
+ auto info = createWriteableMap(env);
178
+ putInt(env, info, "version", lm_gguf_get_version(ctx));
179
+ putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
180
+ putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
181
+ {
182
+ const int n_kv = lm_gguf_get_n_kv(ctx);
183
+
184
+ for (int i = 0; i < n_kv; ++i) {
185
+ const char * key = lm_gguf_get_key(ctx, i);
186
+
187
+ bool skipped = false;
188
+ if (skip_len > 0) {
189
+ for (int j = 0; j < skip_len; j++) {
190
+ if (skip_vec[j] == key) {
191
+ skipped = true;
192
+ break;
193
+ }
194
+ }
195
+ }
196
+
197
+ if (skipped) {
198
+ continue;
199
+ }
200
+
201
+ const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
202
+ putString(env, info, key, value.c_str());
203
+ }
204
+ }
205
+
206
+ env->ReleaseStringUTFChars(model_path_str, model_path_chars);
207
+ lm_gguf_free(ctx);
208
+
209
+ return reinterpret_cast<jobject>(info);
210
+ }
211
+
212
+ struct callback_context {
213
+ JNIEnv *env;
214
+ rnllama::llama_rn_context *llama;
215
+ jobject callback;
216
+ };
128
217
 
129
218
  std::unordered_map<long, rnllama::llama_rn_context *> context_map;
130
219
 
@@ -141,10 +230,14 @@ Java_com_rnllama_LlamaContext_initContext(
141
230
  jobject thiz,
142
231
  jstring model_path_str,
143
232
  jboolean embedding,
233
+ jint embd_normalize,
144
234
  jint n_ctx,
145
235
  jint n_batch,
146
236
  jint n_threads,
147
237
  jint n_gpu_layers, // TODO: Support this
238
+ jboolean flash_attn,
239
+ jstring cache_type_k,
240
+ jstring cache_type_v,
148
241
  jboolean use_mlock,
149
242
  jboolean use_mmap,
150
243
  jboolean vocab_only,
@@ -152,7 +245,8 @@ Java_com_rnllama_LlamaContext_initContext(
152
245
  jfloat lora_scaled,
153
246
  jfloat rope_freq_base,
154
247
  jfloat rope_freq_scale,
155
- jobject javaLlamaContext
248
+ jint pooling_type,
249
+ jobject load_progress_callback
156
250
  ) {
157
251
  UNUSED(thiz);
158
252
 
@@ -166,64 +260,109 @@ Java_com_rnllama_LlamaContext_initContext(
166
260
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
167
261
  defaultParams.model = model_path_chars;
168
262
 
169
- defaultParams.embedding = embedding;
170
-
171
263
  defaultParams.n_ctx = n_ctx;
172
264
  defaultParams.n_batch = n_batch;
173
265
 
266
+ if (pooling_type != -1) {
267
+ defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
268
+ }
269
+
270
+ defaultParams.embedding = embedding;
271
+ if (embd_normalize != -1) {
272
+ defaultParams.embd_normalize = embd_normalize;
273
+ }
274
+ if (embedding) {
275
+ // For non-causal models, batch size must be equal to ubatch size
276
+ defaultParams.n_ubatch = defaultParams.n_batch;
277
+ }
278
+
174
279
  int max_threads = std::thread::hardware_concurrency();
175
280
  // Use 2 threads by default on 4-core devices, 4 threads on more cores
176
281
  int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
177
282
  defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
178
283
 
179
284
  defaultParams.n_gpu_layers = n_gpu_layers;
180
-
285
+ defaultParams.flash_attn = flash_attn;
286
+
287
+ const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
288
+ const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
289
+ defaultParams.cache_type_k = cache_type_k_chars;
290
+ defaultParams.cache_type_v = cache_type_v_chars;
291
+
181
292
  defaultParams.use_mlock = use_mlock;
182
293
  defaultParams.use_mmap = use_mmap;
183
294
 
184
295
  const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
185
296
  if (lora_chars != nullptr && lora_chars[0] != '\0') {
186
297
  defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
187
- defaultParams.use_mmap = false;
188
298
  }
189
299
 
190
300
  defaultParams.rope_freq_base = rope_freq_base;
191
301
  defaultParams.rope_freq_scale = rope_freq_scale;
192
302
 
193
- // progress callback when loading
194
- jclass llamaContextClass = env->GetObjectClass(javaLlamaContext);
195
- jmethodID sendProgressMethod = env->GetMethodID(llamaContextClass, "emitModelProgressUpdate", "(I)V");
196
-
197
- CallbackContext callbackctx = {env, javaLlamaContext, sendProgressMethod, 0};
198
-
199
- defaultParams.progress_callback_user_data = &callbackctx;
200
- defaultParams.progress_callback = [](float progress, void * ctx) {
201
- unsigned percentage = (unsigned) (100 * progress);
202
- CallbackContext * cbctx = static_cast<CallbackContext*>(ctx);
203
- // reduce call frequency by only calling method when value changes
204
- if (percentage <= cbctx->current) return true;
205
- cbctx->current = percentage;
206
- cbctx->env->CallVoidMethod(cbctx->thiz, cbctx->sendProgressMethod, percentage);
207
- return true;
208
- };
209
-
210
-
211
303
  auto llama = new rnllama::llama_rn_context();
304
+ llama->is_load_interrupted = false;
305
+ llama->loading_progress = 0;
306
+
307
+ if (load_progress_callback != nullptr) {
308
+ defaultParams.progress_callback = [](float progress, void * user_data) {
309
+ callback_context *cb_ctx = (callback_context *)user_data;
310
+ JNIEnv *env = cb_ctx->env;
311
+ auto llama = cb_ctx->llama;
312
+ jobject callback = cb_ctx->callback;
313
+ int percentage = (int) (100 * progress);
314
+ if (percentage > llama->loading_progress) {
315
+ llama->loading_progress = percentage;
316
+ jclass callback_class = env->GetObjectClass(callback);
317
+ jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
318
+ env->CallVoidMethod(callback, onLoadProgress, percentage);
319
+ }
320
+ return !llama->is_load_interrupted;
321
+ };
322
+
323
+ callback_context *cb_ctx = new callback_context;
324
+ cb_ctx->env = env;
325
+ cb_ctx->llama = llama;
326
+ cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
327
+ defaultParams.progress_callback_user_data = cb_ctx;
328
+ }
329
+
212
330
  bool is_model_loaded = llama->loadModel(defaultParams);
213
331
 
332
+ env->ReleaseStringUTFChars(model_path_str, model_path_chars);
333
+ env->ReleaseStringUTFChars(lora_str, lora_chars);
334
+ env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
335
+ env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
336
+
214
337
  LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
215
338
  if (is_model_loaded) {
339
+ if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
340
+ LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
341
+ llama_free(llama->ctx);
342
+ return -1;
343
+ }
216
344
  context_map[(long) llama->ctx] = llama;
217
345
  } else {
218
346
  llama_free(llama->ctx);
219
347
  }
220
348
 
221
- env->ReleaseStringUTFChars(model_path_str, model_path_chars);
222
- env->ReleaseStringUTFChars(lora_str, lora_chars);
223
-
224
349
  return reinterpret_cast<jlong>(llama->ctx);
225
350
  }
226
351
 
352
+
353
+ JNIEXPORT void JNICALL
354
+ Java_com_rnllama_LlamaContext_interruptLoad(
355
+ JNIEnv *env,
356
+ jobject thiz,
357
+ jlong context_ptr
358
+ ) {
359
+ UNUSED(thiz);
360
+ auto llama = context_map[(long) context_ptr];
361
+ if (llama) {
362
+ llama->is_load_interrupted = true;
363
+ }
364
+ }
365
+
227
366
  JNIEXPORT jobject JNICALL
228
367
  Java_com_rnllama_LlamaContext_loadModelDetails(
229
368
  JNIEnv *env,
@@ -397,13 +536,18 @@ Java_com_rnllama_LlamaContext_doCompletion(
397
536
  jint top_k,
398
537
  jfloat top_p,
399
538
  jfloat min_p,
400
- jfloat xtc_t,
401
- jfloat xtc_p,
539
+ jfloat xtc_threshold,
540
+ jfloat xtc_probability,
402
541
  jfloat typical_p,
403
542
  jint seed,
404
543
  jobjectArray stop,
405
544
  jboolean ignore_eos,
406
545
  jobjectArray logit_bias,
546
+ jfloat dry_multiplier,
547
+ jfloat dry_base,
548
+ jint dry_allowed_length,
549
+ jint dry_penalty_last_n,
550
+ jobjectArray dry_sequence_breakers,
407
551
  jobject partial_completion_callback
408
552
  ) {
409
553
  UNUSED(thiz);
@@ -440,14 +584,34 @@ Java_com_rnllama_LlamaContext_doCompletion(
440
584
  sparams.typ_p = typical_p;
441
585
  sparams.n_probs = n_probs;
442
586
  sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
443
- sparams.xtc_threshold = xtc_t;
444
- sparams.xtc_probability = xtc_p;
587
+ sparams.xtc_threshold = xtc_threshold;
588
+ sparams.xtc_probability = xtc_probability;
589
+ sparams.dry_multiplier = dry_multiplier;
590
+ sparams.dry_base = dry_base;
591
+ sparams.dry_allowed_length = dry_allowed_length;
592
+ sparams.dry_penalty_last_n = dry_penalty_last_n;
445
593
 
446
594
  sparams.logit_bias.clear();
447
595
  if (ignore_eos) {
448
596
  sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
449
597
  }
450
598
 
599
+ // dry break seq
600
+
601
+ jint size = env->GetArrayLength(dry_sequence_breakers);
602
+ std::vector<std::string> dry_sequence_breakers_vector;
603
+
604
+ for (jint i = 0; i < size; i++) {
605
+ jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
606
+ const char *nativeString = env->GetStringUTFChars(javaString, 0);
607
+ dry_sequence_breakers_vector.push_back(std::string(nativeString));
608
+ env->ReleaseStringUTFChars(javaString, nativeString);
609
+ env->DeleteLocalRef(javaString);
610
+ }
611
+
612
+ sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
613
+
614
+ // logit bias
451
615
  const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
452
616
  jsize logit_bias_len = env->GetArrayLength(logit_bias);
453
617
 
@@ -651,16 +815,27 @@ Java_com_rnllama_LlamaContext_isEmbeddingEnabled(
651
815
 
652
816
  JNIEXPORT jobject JNICALL
653
817
  Java_com_rnllama_LlamaContext_embedding(
654
- JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
818
+ JNIEnv *env, jobject thiz,
819
+ jlong context_ptr,
820
+ jstring text,
821
+ jint embd_normalize
822
+ ) {
655
823
  UNUSED(thiz);
656
824
  auto llama = context_map[(long) context_ptr];
657
825
 
826
+ common_params embdParams;
827
+ embdParams.embedding = true;
828
+ embdParams.embd_normalize = llama->params.embd_normalize;
829
+ if (embd_normalize != -1) {
830
+ embdParams.embd_normalize = embd_normalize;
831
+ }
832
+
658
833
  const char *text_chars = env->GetStringUTFChars(text, nullptr);
659
834
 
660
835
  llama->rewind();
661
836
 
662
837
  llama_perf_context_reset(llama->ctx);
663
-
838
+
664
839
  llama->params.prompt = text_chars;
665
840
 
666
841
  llama->params.n_predict = 0;
@@ -675,7 +850,7 @@ Java_com_rnllama_LlamaContext_embedding(
675
850
  llama->loadPrompt();
676
851
  llama->doCompletion();
677
852
 
678
- std::vector<float> embedding = llama->getEmbedding();
853
+ std::vector<float> embedding = llama->getEmbedding(embdParams);
679
854
 
680
855
  auto embeddings = createWritableArray(env);
681
856
  for (const auto &val : embedding) {
@@ -683,6 +858,12 @@ Java_com_rnllama_LlamaContext_embedding(
683
858
  }
684
859
  putArray(env, result, "embedding", embeddings);
685
860
 
861
+ auto promptTokens = createWritableArray(env);
862
+ for (const auto &tok : llama->embd) {
863
+ pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
864
+ }
865
+ putArray(env, result, "prompt_tokens", promptTokens);
866
+
686
867
  env->ReleaseStringUTFChars(text, text_chars);
687
868
  return result;
688
869
  }
@@ -722,4 +903,11 @@ Java_com_rnllama_LlamaContext_freeContext(
722
903
  context_map.erase((long) llama->ctx);
723
904
  }
724
905
 
906
+ JNIEXPORT void JNICALL
907
+ Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
908
+ UNUSED(env);
909
+ UNUSED(thiz);
910
+ llama_log_set(log_callback, NULL);
911
+ }
912
+
725
913
  } // extern "C"
@@ -39,8 +39,13 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
39
39
  }
40
40
 
41
41
  @ReactMethod
42
- public void initContext(final ReadableMap params, final Promise promise) {
43
- rnllama.initContext(params, promise);
42
+ public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
43
+ rnllama.modelInfo(model, skip, promise);
44
+ }
45
+
46
+ @ReactMethod
47
+ public void initContext(double id, final ReadableMap params, final Promise promise) {
48
+ rnllama.initContext(id, params, promise);
44
49
  }
45
50
 
46
51
  @ReactMethod
@@ -89,8 +94,8 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
89
94
  }
90
95
 
91
96
  @ReactMethod
92
- public void embedding(double id, final String text, final Promise promise) {
93
- rnllama.embedding(id, text, promise);
97
+ public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
98
+ rnllama.embedding(id, text, params, promise);
94
99
  }
95
100
 
96
101
  @ReactMethod
@@ -40,8 +40,13 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
40
40
  }
41
41
 
42
42
  @ReactMethod
43
- public void initContext(final ReadableMap params, final Promise promise) {
44
- rnllama.initContext(params, promise);
43
+ public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
44
+ rnllama.modelInfo(model, skip, promise);
45
+ }
46
+
47
+ @ReactMethod
48
+ public void initContext(double id, final ReadableMap params, final Promise promise) {
49
+ rnllama.initContext(id, params, promise);
45
50
  }
46
51
 
47
52
  @ReactMethod
@@ -90,8 +95,8 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
90
95
  }
91
96
 
92
97
  @ReactMethod
93
- public void embedding(double id, final String text, final Promise promise) {
94
- rnllama.embedding(id, text, promise);
98
+ public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
99
+ rnllama.embedding(id, text, params, promise);
95
100
  }
96
101
 
97
102
  @ReactMethod