cui-llama.rn 1.0.3 → 1.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. package/README.md +35 -39
  2. package/android/src/main/CMakeLists.txt +12 -2
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +29 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +33 -1
  5. package/android/src/main/jni.cpp +62 -8
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  8. package/cpp/common.cpp +3237 -3231
  9. package/cpp/common.h +469 -468
  10. package/cpp/ggml-aarch64.c +2193 -2193
  11. package/cpp/ggml-aarch64.h +39 -39
  12. package/cpp/ggml-alloc.c +1036 -1042
  13. package/cpp/ggml-backend-impl.h +153 -153
  14. package/cpp/ggml-backend.c +2240 -2234
  15. package/cpp/ggml-backend.h +238 -238
  16. package/cpp/ggml-common.h +1833 -1829
  17. package/cpp/ggml-impl.h +755 -655
  18. package/cpp/ggml-metal.h +65 -65
  19. package/cpp/ggml-metal.m +3269 -3269
  20. package/cpp/ggml-quants.c +14872 -14860
  21. package/cpp/ggml-quants.h +132 -132
  22. package/cpp/ggml.c +22055 -22044
  23. package/cpp/ggml.h +2453 -2447
  24. package/cpp/llama-grammar.cpp +539 -0
  25. package/cpp/llama-grammar.h +39 -0
  26. package/cpp/llama-impl.h +26 -0
  27. package/cpp/llama-sampling.cpp +635 -0
  28. package/cpp/llama-sampling.h +56 -0
  29. package/cpp/llama-vocab.cpp +1721 -0
  30. package/cpp/llama-vocab.h +130 -0
  31. package/cpp/llama.cpp +19171 -21892
  32. package/cpp/llama.h +1240 -1217
  33. package/cpp/log.h +737 -737
  34. package/cpp/rn-llama.hpp +207 -29
  35. package/cpp/sampling.cpp +460 -460
  36. package/cpp/sgemm.cpp +1027 -1027
  37. package/cpp/sgemm.h +14 -14
  38. package/cpp/unicode.cpp +6 -0
  39. package/cpp/unicode.h +3 -0
  40. package/ios/RNLlama.mm +15 -6
  41. package/ios/RNLlamaContext.h +2 -8
  42. package/ios/RNLlamaContext.mm +41 -34
  43. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  44. package/lib/commonjs/chat.js +37 -0
  45. package/lib/commonjs/chat.js.map +1 -0
  46. package/lib/commonjs/index.js +14 -1
  47. package/lib/commonjs/index.js.map +1 -1
  48. package/lib/module/NativeRNLlama.js.map +1 -1
  49. package/lib/module/chat.js +31 -0
  50. package/lib/module/chat.js.map +1 -0
  51. package/lib/module/index.js +14 -1
  52. package/lib/module/index.js.map +1 -1
  53. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  54. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  55. package/lib/typescript/chat.d.ts +10 -0
  56. package/lib/typescript/chat.d.ts.map +1 -0
  57. package/lib/typescript/index.d.ts +9 -2
  58. package/lib/typescript/index.d.ts.map +1 -1
  59. package/package.json +1 -1
  60. package/src/NativeRNLlama.ts +10 -1
  61. package/src/chat.ts +44 -0
  62. package/src/index.ts +31 -4
package/README.md CHANGED
@@ -46,38 +46,7 @@ Add proguard rule if it's enabled in project (android/app/proguard-rules.pro):
46
46
 
47
47
  You can search HuggingFace for available models (Keyword: [`GGUF`](https://huggingface.co/search/full-text?q=GGUF&type=model)).
48
48
 
49
- For create a GGUF model manually, for example in Llama 2:
50
-
51
- Download the Llama 2 model
52
-
53
- 1. Request access from [here](https://ai.meta.com/llama)
54
- 2. Download the model from HuggingFace [here](https://huggingface.co/meta-llama/Llama-2-7b-chat) (`Llama-2-7b-chat`)
55
-
56
- Convert the model to ggml format
57
-
58
- ```bash
59
- # Start with submodule in this repo (or you can clone the repo https://github.com/ggerganov/llama.cpp.git)
60
- yarn && yarn bootstrap
61
- cd llama.cpp
62
-
63
- # install Python dependencies
64
- python3 -m pip install -r requirements.txt
65
-
66
- # Move the Llama model weights to the models folder
67
- mv <path to Llama-2-7b-chat> ./models/7B
68
-
69
- # convert the 7B model to ggml FP16 format
70
- python3 convert.py models/7B/ --outtype f16
71
-
72
- # Build the quantize tool
73
- make quantize
74
-
75
- # quantize the model to 2-bits (using q2_k method)
76
- ./quantize ./models/7B/ggml-model-f16.gguf ./models/7B/ggml-model-q2_k.gguf q2_k
77
-
78
- # quantize the model to 4-bits (using q4_0 method)
79
- ./quantize ./models/7B/ggml-model-f16.gguf ./models/7B/ggml-model-q4_0.gguf q4_0
80
- ```
49
+ To get a GGUF model or quantize one manually, see the [`Prepare and Quantize`](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#prepare-and-quantize) section in llama.cpp.
81
50
 
82
51
  ## Usage
83
52
 
@@ -93,27 +62,54 @@ const context = await initLlama({
93
62
  // embedding: true, // use embedding
94
63
  })
95
64
 
96
- // Do completion
97
- const { text, timings } = await context.completion(
65
+ const stopWords = ['</s>', '<|end|>', '<|eot_id|>', '<|end_of_text|>', '<|im_end|>', '<|EOT|>', '<|END_OF_TURN_TOKEN|>', '<|end_of_turn|>', '<|endoftext|>']
66
+
67
+ // Do chat completion
68
+ const msgResult = await context.completion(
69
+ {
70
+ messages: [
71
+ {
72
+ role: 'system',
73
+ content: 'This is a conversation between user and assistant, a friendly chatbot.',
74
+ },
75
+ {
76
+ role: 'user',
77
+ content: 'Hello!',
78
+ },
79
+ ],
80
+ n_predict: 100,
81
+ stop: stopWords,
82
+ // ...other params
83
+ },
84
+ (data) => {
85
+ // This is a partial completion callback
86
+ const { token } = data
87
+ },
88
+ )
89
+ console.log('Result:', msgResult.text)
90
+ console.log('Timings:', msgResult.timings)
91
+
92
+ // Or do text completion
93
+ const textResult = await context.completion(
98
94
  {
99
95
  prompt:
100
96
  'This is a conversation between user and llama, a friendly chatbot. respond in simple markdown.\n\nUser: Hello!\nLlama:',
101
97
  n_predict: 100,
102
- stop: ['</s>', 'Llama:', 'User:'],
103
- // n_threads: 4,
98
+ stop: [...stopWords, 'Llama:', 'User:'],
99
+ // ...other params
104
100
  },
105
101
  (data) => {
106
102
  // This is a partial completion callback
107
103
  const { token } = data
108
104
  },
109
105
  )
110
- console.log('Result:', text)
111
- console.log('Timings:', timings)
106
+ console.log('Result:', textResult.text)
107
+ console.log('Timings:', textResult.timings)
112
108
  ```
113
109
 
114
110
  The binding’s design is inspired by the [server.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) example in llama.cpp, so you can map its API to LlamaContext:
115
111
 
116
- - `/completion`: `context.completion(params, partialCompletionCallback)`
112
+ - `/completion` and `/chat/completions`: `context.completion(params, partialCompletionCallback)`
117
113
  - `/tokenize`: `context.tokenize(content)`
118
114
  - `/detokenize`: `context.detokenize(tokens)`
119
115
  - `/embedding`: `context.embedding(content)`
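A minimal TypeScript sketch of how these mapped methods might be called. It assumes the package exports `initLlama` as in the usage example above, and that `tokenize` and `embedding` resolve to objects carrying `tokens` and `embedding` arrays respectively (shapes inferred from the native changes in this diff, not verified against the published typings):

```ts
import { initLlama } from 'cui-llama.rn'

async function demoMappedApi(modelPath: string) {
  // Same initialization as the usage example; embedding must be enabled at init
  // time for context.embedding() to work.
  const context = await initLlama({ model: modelPath, embedding: true })

  // `/tokenize` -> context.tokenize(content)
  const { tokens } = await context.tokenize('Hello, llama!')
  console.log('Token count:', tokens.length)

  // `/detokenize` -> context.detokenize(tokens)
  const roundTripped = await context.detokenize(tokens)
  console.log('Round-tripped text:', roundTripped)

  // `/embedding` -> context.embedding(content)
  const { embedding } = await context.embedding('Hello, llama!')
  console.log('Embedding dimensions:', embedding.length)
}
```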
package/android/src/main/CMakeLists.txt CHANGED
@@ -9,6 +9,10 @@ include_directories(${RNLLAMA_LIB_DIR})
9
9
 
10
10
  set(
11
11
  SOURCE_FILES
12
+ ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
13
+ ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
14
+ ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
15
+
12
16
  ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
13
17
  ${RNLLAMA_LIB_DIR}/ggml-alloc.c
14
18
  ${RNLLAMA_LIB_DIR}/ggml-backend.c
@@ -22,6 +26,9 @@ set(
22
26
  ${RNLLAMA_LIB_DIR}/unicode-data.cpp
23
27
  ${RNLLAMA_LIB_DIR}/unicode.cpp
24
28
  ${RNLLAMA_LIB_DIR}/llama.cpp
29
+ ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
30
+ ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
31
+ ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
25
32
  ${RNLLAMA_LIB_DIR}/sgemm.cpp
26
33
  ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
27
34
  ${RNLLAMA_LIB_DIR}/rn-llama.hpp
@@ -45,7 +52,9 @@ function(build_library target_name cpu_flags)
45
52
  target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING)
46
53
  endif ()
47
54
 
48
- #if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
55
+ # NOTE: If you want to debug the native code, you can uncomment the `if` and `endif` lines
56
+ # Note that the resulting build will be extremely slow
57
+ # if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
49
58
  target_compile_options(${target_name} PRIVATE -O3 -DNDEBUG)
50
59
  target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden)
51
60
  target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections)
@@ -53,7 +62,7 @@ function(build_library target_name cpu_flags)
53
62
  target_link_options(${target_name} PRIVATE -Wl,--gc-sections)
54
63
  target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL)
55
64
  target_link_options(${target_name} PRIVATE -flto)
56
- #endif ()
65
+ # endif ()
57
66
  endfunction()
58
67
 
59
68
  # Default target (no specific CPU features)
@@ -61,6 +70,7 @@ build_library("rnllama" "")
61
70
 
62
71
  if (${ANDROID_ABI} STREQUAL "arm64-v8a")
63
72
  # ARM64 targets
73
+ build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
64
74
  build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
65
75
  build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
66
76
  build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -94,8 +94,6 @@ public class LlamaContext {
94
94
  params.hasKey("lora") ? params.getString("lora") : "",
95
95
  // float lora_scaled,
96
96
  params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
97
- // String lora_base,
98
- params.hasKey("lora_base") ? params.getString("lora_base") : "",
99
97
  // float rope_freq_base,
100
98
  params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
101
99
  // float rope_freq_scale
@@ -114,6 +112,14 @@ public class LlamaContext {
114
112
  return modelDetails;
115
113
  }
116
114
 
115
+ public String getFormattedChat(ReadableArray messages, String chatTemplate) {
116
+ ReadableMap[] msgs = new ReadableMap[messages.size()];
117
+ for (int i = 0; i < messages.size(); i++) {
118
+ msgs[i] = messages.getMap(i);
119
+ }
120
+ return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
121
+ }
122
+
117
123
  private void emitPartialCompletion(WritableMap tokenResult) {
118
124
  WritableMap event = Arguments.createMap();
119
125
  event.putInt("contextId", LlamaContext.this.id);
@@ -176,7 +182,7 @@ public class LlamaContext {
176
182
  }
177
183
  }
178
184
 
179
- return doCompletion(
185
+ WritableMap result = doCompletion(
180
186
  this.context,
181
187
  // String prompt,
182
188
  params.getString("prompt"),
@@ -230,6 +236,10 @@ public class LlamaContext {
230
236
  params.hasKey("emit_partial_completion") ? params.getBoolean("emit_partial_completion") : false
231
237
  )
232
238
  );
239
+ if (result.hasKey("error")) {
240
+ throw new IllegalStateException(result.getString("error"));
241
+ }
242
+ return result;
233
243
  }
234
244
 
235
245
  public void stopCompletion() {
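With this change, a native completion failure is thrown as an `IllegalStateException` instead of being returned inside the result map. Assuming the completion bridge in `RNLlama.java` follows the same catch-and-reject pattern shown for `getFormattedChat` below, this surfaces on the JS side as a rejected promise; a hedged sketch (using a `context` created with `initLlama` as in the README):

```ts
// Illustration only: completion errors now reject the promise rather than
// resolving with an error payload, so callers can rely on try/catch.
try {
  const { text } = await context.completion({ prompt: 'Hello!', n_predict: 32 })
  console.log('Completed:', text)
} catch (err) {
  // err carries the message from the IllegalStateException thrown above
  console.error('Completion failed:', err)
}
```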
@@ -254,12 +264,14 @@ public class LlamaContext {
254
264
  return detokenize(this.context, toks);
255
265
  }
256
266
 
257
- public WritableMap embedding(String text) {
267
+ public WritableMap getEmbedding(String text) {
258
268
  if (isEmbeddingEnabled(this.context) == false) {
259
269
  throw new IllegalStateException("Embedding is not enabled");
260
270
  }
261
- WritableMap result = Arguments.createMap();
262
- result.putArray("embedding", embedding(this.context, text));
271
+ WritableMap result = embedding(this.context, text);
272
+ if (result.hasKey("error")) {
273
+ throw new IllegalStateException(result.getString("error"));
274
+ }
263
275
  return result;
264
276
  }
265
277
 
@@ -281,8 +293,12 @@ public class LlamaContext {
281
293
  boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
282
294
  boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
283
295
  boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
296
+ boolean hasInt8Matmul = cpuFeatures.contains("i8mm");
284
297
 
285
- if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
298
+ if (isAtLeastArmV84 && hasFp16 && hasDotProd && hasInt8Matmul) {
299
+ Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
300
+ System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
301
+ } else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
286
302
  Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so");
287
303
  System.loadLibrary("rnllama_v8_4_fp16_dotprod");
288
304
  } else if (isAtLeastArmV82 && hasFp16 && hasDotProd) {
@@ -344,13 +360,17 @@ public class LlamaContext {
344
360
  boolean vocab_only,
345
361
  String lora,
346
362
  float lora_scaled,
347
- String lora_base,
348
363
  float rope_freq_base,
349
364
  float rope_freq_scale
350
365
  );
351
366
  protected static native WritableMap loadModelDetails(
352
367
  long contextPtr
353
368
  );
369
+ protected static native String getFormattedChat(
370
+ long contextPtr,
371
+ ReadableMap[] messages,
372
+ String chatTemplate
373
+ );
354
374
  protected static native WritableMap loadSession(
355
375
  long contextPtr,
356
376
  String path
@@ -392,7 +412,7 @@ public class LlamaContext {
392
412
  protected static native WritableArray tokenize(long contextPtr, String text);
393
413
  protected static native String detokenize(long contextPtr, int[] tokens);
394
414
  protected static native boolean isEmbeddingEnabled(long contextPtr);
395
- protected static native WritableArray embedding(long contextPtr, String text);
415
+ protected static native WritableMap embedding(long contextPtr, String text);
396
416
  protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
397
417
  protected static native void freeContext(long contextPtr);
398
418
  }
package/android/src/main/java/com/rnllama/RNLlama.java CHANGED
@@ -80,6 +80,38 @@ public class RNLlama implements LifecycleEventListener {
80
80
  tasks.put(task, "initContext");
81
81
  }
82
82
 
83
+ public void getFormattedChat(double id, final ReadableArray messages, final String chatTemplate, Promise promise) {
84
+ final int contextId = (int) id;
85
+ AsyncTask task = new AsyncTask<Void, Void, String>() {
86
+ private Exception exception;
87
+
88
+ @Override
89
+ protected String doInBackground(Void... voids) {
90
+ try {
91
+ LlamaContext context = contexts.get(contextId);
92
+ if (context == null) {
93
+ throw new Exception("Context not found");
94
+ }
95
+ return context.getFormattedChat(messages, chatTemplate);
96
+ } catch (Exception e) {
97
+ exception = e;
98
+ return null;
99
+ }
100
+ }
101
+
102
+ @Override
103
+ protected void onPostExecute(String result) {
104
+ if (exception != null) {
105
+ promise.reject(exception);
106
+ return;
107
+ }
108
+ promise.resolve(result);
109
+ tasks.remove(this);
110
+ }
111
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
112
+ tasks.put(task, "getFormattedChat-" + contextId);
113
+ }
114
+
83
115
  public void loadSession(double id, final String path, Promise promise) {
84
116
  final int contextId = (int) id;
85
117
  AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
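The new `getFormattedChat` bridge takes an array of `{ role, content }` messages plus an optional template string. A hedged TypeScript sketch of how this might be used from JS; only the native `getFormattedChat(id, messages, chatTemplate)` signature appears in this diff, so the context-level wrapper name is an assumption:

```ts
import { initLlama } from 'cui-llama.rn'

async function formatAndComplete(modelPath: string): Promise<string> {
  const context = await initLlama({ model: modelPath })

  const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' },
  ]

  // Assumed wrapper over the native getFormattedChat(id, messages, chatTemplate);
  // omitting the template falls back to the model's built-in chat template.
  const prompt = await context.getFormattedChat(messages)

  // The formatted prompt can then be fed to a plain text completion.
  const { text } = await context.completion({ prompt, n_predict: 100 })
  return text
}
```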
@@ -307,7 +339,7 @@ public class RNLlama implements LifecycleEventListener {
307
339
  if (context == null) {
308
340
  throw new Exception("Context not found");
309
341
  }
310
- return context.embedding(text);
342
+ return context.getEmbedding(text);
311
343
  } catch (Exception e) {
312
344
  exception = e;
313
345
  }
package/android/src/main/jni.cpp CHANGED
@@ -62,6 +62,16 @@ static inline void putDouble(JNIEnv *env, jobject map, const char *key, double v
62
62
  env->CallVoidMethod(map, putDoubleMethod, jKey, value);
63
63
  }
64
64
 
65
+ // Method to put boolean into WritableMap
66
+ static inline void putBoolean(JNIEnv *env, jobject map, const char *key, bool value) {
67
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
68
+ jmethodID putBooleanMethod = env->GetMethodID(mapClass, "putBoolean", "(Ljava/lang/String;Z)V");
69
+
70
+ jstring jKey = env->NewStringUTF(key);
71
+
72
+ env->CallVoidMethod(map, putBooleanMethod, jKey, value);
73
+ }
74
+
65
75
  // Method to put WriteableMap into WritableMap
66
76
  static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
67
77
  jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
@@ -132,7 +142,6 @@ Java_com_rnllama_LlamaContext_initContext(
132
142
  jboolean vocab_only,
133
143
  jstring lora_str,
134
144
  jfloat lora_scaled,
135
- jstring lora_base_str,
136
145
  jfloat rope_freq_base,
137
146
  jfloat rope_freq_scale
138
147
  ) {
@@ -164,10 +173,8 @@ Java_com_rnllama_LlamaContext_initContext(
164
173
  defaultParams.use_mmap = use_mmap;
165
174
 
166
175
  const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
167
- const char *lora_base_chars = env->GetStringUTFChars(lora_base_str, nullptr);
168
176
  if (lora_chars != nullptr && lora_chars[0] != '\0') {
169
177
  defaultParams.lora_adapter.push_back({lora_chars, lora_scaled});
170
- defaultParams.lora_base = lora_base_chars;
171
178
  defaultParams.use_mmap = false;
172
179
  }
173
180
 
@@ -186,7 +193,6 @@ Java_com_rnllama_LlamaContext_initContext(
186
193
 
187
194
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
188
195
  env->ReleaseStringUTFChars(lora_str, lora_chars);
189
- env->ReleaseStringUTFChars(lora_base_str, lora_base_chars);
190
196
 
191
197
  return reinterpret_cast<jlong>(llama->ctx);
192
198
  }
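Since `lora_base` is dropped from the native init path, only the adapter path and its scale remain. A hedged sketch of the corresponding `initLlama` call, with TypeScript param names assumed to mirror the `lora` / `lora_scaled` keys read in `LlamaContext.java`:

```ts
import { initLlama } from 'cui-llama.rn'

const context = await initLlama({
  model: '/path/to/model.gguf',
  lora: '/path/to/lora-adapter.gguf', // adapter file; mmap is disabled when set
  lora_scaled: 1.0,                   // adapter scale
  // lora_base is no longer accepted in this version
})
```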
@@ -218,11 +224,52 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
218
224
  putString(env, result, "desc", desc);
219
225
  putDouble(env, result, "size", llama_model_size(llama->model));
220
226
  putDouble(env, result, "nParams", llama_model_n_params(llama->model));
227
+ putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate());
221
228
  putMap(env, result, "metadata", meta);
222
229
 
223
230
  return reinterpret_cast<jobject>(result);
224
231
  }
225
232
 
233
+ JNIEXPORT jobject JNICALL
234
+ Java_com_rnllama_LlamaContext_getFormattedChat(
235
+ JNIEnv *env,
236
+ jobject thiz,
237
+ jlong context_ptr,
238
+ jobjectArray messages,
239
+ jstring chat_template
240
+ ) {
241
+ UNUSED(thiz);
242
+ auto llama = context_map[(long) context_ptr];
243
+
244
+ std::vector<llama_chat_msg> chat;
245
+
246
+ int messages_len = env->GetArrayLength(messages);
247
+ for (int i = 0; i < messages_len; i++) {
248
+ jobject msg = env->GetObjectArrayElement(messages, i);
249
+ jclass msgClass = env->GetObjectClass(msg);
250
+
251
+ jmethodID getRoleMethod = env->GetMethodID(msgClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
252
+ jstring roleKey = env->NewStringUTF("role");
253
+ jstring contentKey = env->NewStringUTF("content");
254
+
255
+ jstring role_str = (jstring) env->CallObjectMethod(msg, getRoleMethod, roleKey);
256
+ jstring content_str = (jstring) env->CallObjectMethod(msg, getRoleMethod, contentKey);
257
+
258
+ const char *role = env->GetStringUTFChars(role_str, nullptr);
259
+ const char *content = env->GetStringUTFChars(content_str, nullptr);
260
+
261
+ chat.push_back({ role, content });
262
+
263
+ env->ReleaseStringUTFChars(role_str, role);
264
+ env->ReleaseStringUTFChars(content_str, content);
265
+ }
266
+
267
+ const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
268
+ std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
269
+
270
+ return env->NewStringUTF(formatted_chat.c_str());
271
+ }
272
+
226
273
  JNIEXPORT jobject JNICALL
227
274
  Java_com_rnllama_LlamaContext_loadSession(
228
275
  JNIEnv *env,
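`loadModelDetails` now reports `isChatTemplateSupported`, and the new JNI `getFormattedChat` applies the model's template via `llama_chat_apply_template`. A hedged sketch of how an app might branch on that flag; where exactly the flag is surfaced on the JS context object is an assumption here:

```ts
// Assumed: the model details returned at init (including isChatTemplateSupported)
// are exposed on the context object, e.g. as context.model.
const supportsTemplate = Boolean((context as any).model?.isChatTemplateSupported)

const { text } = supportsTemplate
  ? // Structured messages are formatted natively with the model's chat template.
    await context.completion({
      messages: [{ role: 'user', content: 'Hello!' }],
      n_predict: 100,
    })
  : // Fall back to a hand-built prompt, as in the text-completion example.
    await context.completion({
      prompt: 'User: Hello!\nAssistant:',
      n_predict: 100,
    })

console.log(text)
```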
@@ -416,7 +463,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
416
463
 
417
464
  while (llama->has_next_token && !llama->is_interrupted) {
418
465
  const rnllama::completion_token_output token_with_probs = llama->doCompletion();
419
- if (token_with_probs.tok == -1 || llama->multibyte_pending > 0) {
466
+ if (token_with_probs.tok == -1 || llama->incomplete) {
420
467
  continue;
421
468
  }
422
469
  const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
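The `multibyte_pending` counter is replaced by an `incomplete` flag, but the effect for JS callers is the same: tokens are withheld until a multi-byte UTF-8 sequence is complete, so the partial-completion callback only ever sees whole characters. A short sketch, again assuming a `context` created with `initLlama`:

```ts
let streamed = ''

const result = await context.completion(
  { prompt: 'Write a haiku about llamas.', n_predict: 64 },
  ({ token }) => {
    // Safe to append directly: the native loop never emits a half-finished
    // UTF-8 sequence thanks to the `incomplete` check above.
    streamed += token
  },
)

console.log('Streamed:', streamed)
console.log('Final:   ', result.text)
```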
@@ -587,17 +634,24 @@ Java_com_rnllama_LlamaContext_embedding(
587
634
  llama->params.prompt = text_chars;
588
635
 
589
636
  llama->params.n_predict = 0;
637
+
638
+ auto result = createWriteableMap(env);
639
+ if (!llama->initSampling()) {
640
+ putString(env, result, "error", "Failed to initialize sampling");
641
+ return reinterpret_cast<jobject>(result);
642
+ }
643
+
590
644
  llama->beginCompletion();
591
645
  llama->loadPrompt();
592
646
  llama->doCompletion();
593
647
 
594
648
  std::vector<float> embedding = llama->getEmbedding();
595
649
 
596
- jobject result = createWritableArray(env);
597
-
650
+ auto embeddings = createWritableArray(env);
598
651
  for (const auto &val : embedding) {
599
- pushDouble(env, result, (double) val);
652
+ pushDouble(env, embeddings, (double) val);
600
653
  }
654
+ putArray(env, result, "embedding", embeddings);
601
655
 
602
656
  env->ReleaseStringUTFChars(text, text_chars);
603
657
  return result;
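The embedding bridge now returns a map (`{ embedding: [...] }`) and reports failures through an `error` field, which the Java `getEmbedding` wrapper converts into an exception and hence a rejected promise. A short sketch of the JS-side effect, assuming a context initialized with `embedding: true`:

```ts
try {
  const { embedding } = await context.embedding('Hello, llama!')
  console.log('Embedding dimensions:', embedding.length)
} catch (err) {
  // e.g. "Embedding is not enabled" or "Failed to initialize sampling"
  console.error('Embedding failed:', err)
}
```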
package/android/src/newarch/java/com/rnllama/RNLlamaModule.java CHANGED
@@ -43,6 +43,11 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
43
43
  rnllama.initContext(params, promise);
44
44
  }
45
45
 
46
+ @ReactMethod
47
+ public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
48
+ rnllama.getFormattedChat(id, messages, chatTemplate, promise);
49
+ }
50
+
46
51
  @ReactMethod
47
52
  public void loadSession(double id, String path, Promise promise) {
48
53
  rnllama.loadSession(id, path, promise);
package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java CHANGED
@@ -44,6 +44,11 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
44
44
  rnllama.initContext(params, promise);
45
45
  }
46
46
 
47
+ @ReactMethod
48
+ public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
49
+ rnllama.getFormattedChat(id, messages, chatTemplate, promise);
50
+ }
51
+
47
52
  @ReactMethod
48
53
  public void loadSession(double id, String path, Promise promise) {
49
54
  rnllama.loadSession(id, path, promise);