cui-llama.rn 1.3.5 → 1.4.0

This diff covers the content of publicly released package versions as published to the supported registries, and is provided for informational purposes only.
Files changed (80)
  1. package/README.md +22 -1
  2. package/android/src/main/CMakeLists.txt +25 -20
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
  5. package/android/src/main/jni-utils.h +94 -0
  6. package/android/src/main/jni.cpp +108 -37
  7. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
  8. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
  9. package/cpp/common.cpp +1982 -1965
  10. package/cpp/common.h +665 -657
  11. package/cpp/ggml-backend-reg.cpp +5 -0
  12. package/cpp/ggml-backend.cpp +5 -2
  13. package/cpp/ggml-cpp.h +1 -0
  14. package/cpp/ggml-cpu-aarch64.cpp +6 -1
  15. package/cpp/ggml-cpu-quants.c +5 -1
  16. package/cpp/ggml-cpu.c +14122 -14122
  17. package/cpp/ggml-cpu.cpp +627 -627
  18. package/cpp/ggml-impl.h +11 -16
  19. package/cpp/ggml-metal-impl.h +288 -0
  20. package/cpp/ggml-metal.m +2 -2
  21. package/cpp/ggml-opt.cpp +854 -0
  22. package/cpp/ggml-opt.h +216 -0
  23. package/cpp/ggml.c +0 -1276
  24. package/cpp/ggml.h +0 -140
  25. package/cpp/gguf.cpp +1325 -0
  26. package/cpp/gguf.h +202 -0
  27. package/cpp/llama-adapter.cpp +346 -0
  28. package/cpp/llama-adapter.h +73 -0
  29. package/cpp/llama-arch.cpp +1434 -0
  30. package/cpp/llama-arch.h +395 -0
  31. package/cpp/llama-batch.cpp +368 -0
  32. package/cpp/llama-batch.h +88 -0
  33. package/cpp/llama-chat.cpp +567 -0
  34. package/cpp/llama-chat.h +51 -0
  35. package/cpp/llama-context.cpp +1771 -0
  36. package/cpp/llama-context.h +128 -0
  37. package/cpp/llama-cparams.cpp +1 -0
  38. package/cpp/llama-cparams.h +37 -0
  39. package/cpp/llama-cpp.h +30 -0
  40. package/cpp/llama-grammar.cpp +1 -0
  41. package/cpp/llama-grammar.h +3 -1
  42. package/cpp/llama-hparams.cpp +71 -0
  43. package/cpp/llama-hparams.h +140 -0
  44. package/cpp/llama-impl.cpp +167 -0
  45. package/cpp/llama-impl.h +16 -136
  46. package/cpp/llama-kv-cache.cpp +718 -0
  47. package/cpp/llama-kv-cache.h +218 -0
  48. package/cpp/llama-mmap.cpp +589 -0
  49. package/cpp/llama-mmap.h +67 -0
  50. package/cpp/llama-model-loader.cpp +1011 -0
  51. package/cpp/llama-model-loader.h +158 -0
  52. package/cpp/llama-model.cpp +2202 -0
  53. package/cpp/llama-model.h +391 -0
  54. package/cpp/llama-sampling.cpp +117 -4
  55. package/cpp/llama-vocab.cpp +21 -28
  56. package/cpp/llama-vocab.h +13 -1
  57. package/cpp/llama.cpp +12547 -23528
  58. package/cpp/llama.h +31 -6
  59. package/cpp/rn-llama.hpp +90 -87
  60. package/cpp/sgemm.cpp +776 -70
  61. package/cpp/sgemm.h +14 -14
  62. package/cpp/unicode.cpp +6 -0
  63. package/ios/RNLlama.mm +47 -0
  64. package/ios/RNLlamaContext.h +3 -1
  65. package/ios/RNLlamaContext.mm +71 -14
  66. package/jest/mock.js +15 -3
  67. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  68. package/lib/commonjs/index.js +33 -37
  69. package/lib/commonjs/index.js.map +1 -1
  70. package/lib/module/NativeRNLlama.js.map +1 -1
  71. package/lib/module/index.js +31 -35
  72. package/lib/module/index.js.map +1 -1
  73. package/lib/typescript/NativeRNLlama.d.ts +26 -6
  74. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  75. package/lib/typescript/index.d.ts +21 -36
  76. package/lib/typescript/index.d.ts.map +1 -1
  77. package/llama-rn.podspec +4 -18
  78. package/package.json +2 -3
  79. package/src/NativeRNLlama.ts +32 -13
  80. package/src/index.ts +52 -47
package/README.md CHANGED
@@ -53,12 +53,23 @@ For get a GGUF model or quantize manually, see [`Prepare and Quantize`](https://
 
 ## Usage
 
+Load model info only:
+
+```js
+import { loadLlamaModelInfo } from 'llama.rn'
+
+const modelPath = 'file://<path to gguf model>'
+console.log('Model Info:', await loadLlamaModelInfo(modelPath))
+```
+
+Initialize a Llama context & do completion:
+
 ```js
 import { initLlama } from 'llama.rn'
 
 // Initial a Llama context with the model (may take a while)
 const context = await initLlama({
-  model: 'file://<path to gguf model>',
+  model: modelPath,
   use_mlock: true,
   n_ctx: 2048,
   n_gpu_layers: 1, // > 0: enable Metal on iOS
@@ -318,6 +329,16 @@ Android:
 
 See the [contributing guide](CONTRIBUTING.md) to learn how to contribute to the repository and the development workflow.
 
+## Apps using `llama.rn`
+
+- [BRICKS](https://bricks.tools): Our product for building interactive signage in simple way. We provide LLM functions as Generator LLM/Assistant.
+- [ChatterUI](https://github.com/Vali-98/ChatterUI): Simple frontend for LLMs built in react-native.
+- [PocketPal AI](https://github.com/a-ghorbani/pocketpal-ai): An app that brings language models directly to your phone.
+
+## Node.js binding
+
+- [llama.node](https://github.com/mybigday/llama.node): An another Node.js binding of `llama.cpp` but made API same as `llama.rn`.
+
 ## License
 
 MIT
package/android/src/main/CMakeLists.txt CHANGED
@@ -9,39 +9,44 @@ include_directories(${RNLLAMA_LIB_DIR})
 
 set(
     SOURCE_FILES
-    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
-    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
-    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
-    ${RNLLAMA_LIB_DIR}/log.cpp
-
-    #${RNLLAMA_LIB_DIR}/amx/amx.cpp
-    #${RNLLAMA_LIB_DIR}/amx/mmq.cpp
-
-    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
-    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
-    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
-    ${RNLLAMA_LIB_DIR}/log.cpp
-    ${RNLLAMA_LIB_DIR}/json.hpp
-    ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
-
+    ${RNLLAMA_LIB_DIR}/ggml.c
     ${RNLLAMA_LIB_DIR}/ggml-alloc.c
     ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
     ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
-    ${RNLLAMA_LIB_DIR}/ggml.c
     ${RNLLAMA_LIB_DIR}/ggml-cpu.c
     ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
     ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
-    ${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
     ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
+    ${RNLLAMA_LIB_DIR}/ggml-opt.cpp
     ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
     ${RNLLAMA_LIB_DIR}/ggml-quants.c
-    ${RNLLAMA_LIB_DIR}/common.cpp
+    ${RNLLAMA_LIB_DIR}/gguf.cpp
+    ${RNLLAMA_LIB_DIR}/log.cpp
+    ${RNLLAMA_LIB_DIR}/llama-impl.cpp
+    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
+    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
+    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
+    ${RNLLAMA_LIB_DIR}/llama-adapter.cpp
+    ${RNLLAMA_LIB_DIR}/llama-chat.cpp
+    ${RNLLAMA_LIB_DIR}/llama-context.cpp
+    ${RNLLAMA_LIB_DIR}/llama-kv-cache.cpp
+    ${RNLLAMA_LIB_DIR}/llama-arch.cpp
+    ${RNLLAMA_LIB_DIR}/llama-batch.cpp
+    ${RNLLAMA_LIB_DIR}/llama-cparams.cpp
+    ${RNLLAMA_LIB_DIR}/llama-hparams.cpp
+    ${RNLLAMA_LIB_DIR}/llama.cpp
+    ${RNLLAMA_LIB_DIR}/llama-model.cpp
+    ${RNLLAMA_LIB_DIR}/llama-model-loader.cpp
+    ${RNLLAMA_LIB_DIR}/llama-mmap.cpp
+    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
     ${RNLLAMA_LIB_DIR}/sampling.cpp
     ${RNLLAMA_LIB_DIR}/unicode-data.cpp
     ${RNLLAMA_LIB_DIR}/unicode.cpp
-    ${RNLLAMA_LIB_DIR}/llama.cpp
     ${RNLLAMA_LIB_DIR}/sgemm.cpp
+    ${RNLLAMA_LIB_DIR}/common.cpp
     ${RNLLAMA_LIB_DIR}/rn-llama.hpp
+    ${CMAKE_SOURCE_DIR}/jni-utils.h
     ${CMAKE_SOURCE_DIR}/jni.cpp
 )
 
@@ -56,7 +61,7 @@ function(build_library target_name cpu_flags)
 
     target_link_libraries(${target_name} ${LOG_LIB} android)
 
-    target_compile_options(${target_name} PRIVATE -pthread ${cpu_flags} -DLM_GGML_USE_CPU -DLM_GGML_USE_CPU_AARCH64)
+    target_compile_options(${target_name} PRIVATE -DLM_GGML_USE_CPU -DLM_GGML_USE_CPU_AARCH64 -pthread ${cpu_flags})
 
     if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
         target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING)
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -108,6 +108,8 @@ public class LlamaContext {
       params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
       // int n_batch,
       params.hasKey("n_batch") ? params.getInt("n_batch") : 512,
+      // int n_ubatch,
+      params.hasKey("n_ubatch") ? params.getInt("n_ubatch") : 512,
       // int n_threads,
       params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
       // int n_gpu_layers, // TODO: Support this
@@ -115,9 +117,9 @@ public class LlamaContext {
       // boolean flash_attn,
       params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
       // String cache_type_k,
-      params.hasKey("cache_type_k") ? params.getInt("cache_type_k") : 1,
+      params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
       // String cache_type_v,
-      params.hasKey("cache_type_v") ? params.getInt("cache_type_v") : 1,
+      params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
       // boolean use_mlock,
       params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
       // boolean use_mmap,
@@ -128,6 +130,8 @@ public class LlamaContext {
       params.hasKey("lora") ? params.getString("lora") : "",
       // float lora_scaled,
       params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
+      // ReadableArray lora_adapters,
+      params.hasKey("lora_list") ? params.getArray("lora_list") : null,
       // float rope_freq_base,
       params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
       // float rope_freq_scale
@@ -168,7 +172,7 @@ public class LlamaContext {
       WritableMap event = Arguments.createMap();
       event.putInt("contextId", LlamaContext.this.id);
       event.putInt("progress", progress);
-      eventEmitter.emit("@RNLlama_onInitContextProgress", event);
+      eventEmitter.emit("@RNLlama_onContextProgress", event);
     }
 
   private static class LoadProgressCallback {
@@ -273,8 +277,6 @@ public class LlamaContext {
       params.hasKey("mirostat_tau") ? (float) params.getDouble("mirostat_tau") : 5.00f,
       // float mirostat_eta,
       params.hasKey("mirostat_eta") ? (float) params.getDouble("mirostat_eta") : 0.10f,
-      // boolean penalize_nl,
-      params.hasKey("penalize_nl") ? params.getBoolean("penalize_nl") : false,
       // int top_k,
       params.hasKey("top_k") ? params.getInt("top_k") : 40,
       // float top_p,
@@ -359,6 +361,22 @@ public class LlamaContext {
     return bench(this.context, pp, tg, pl, nr);
   }
 
+  public int applyLoraAdapters(ReadableArray loraAdapters) {
+    int result = applyLoraAdapters(this.context, loraAdapters);
+    if (result != 0) {
+      throw new IllegalStateException("Failed to apply lora adapters");
+    }
+    return result;
+  }
+
+  public void removeLoraAdapters() {
+    removeLoraAdapters(this.context);
+  }
+
+  public WritableArray getLoadedLoraAdapters() {
+    return getLoadedLoraAdapters(this.context);
+  }
+
   public void release() {
     freeContext(context);
   }
@@ -460,16 +478,18 @@ public class LlamaContext {
     int embd_normalize,
     int n_ctx,
     int n_batch,
+    int n_ubatch,
    int n_threads,
     int n_gpu_layers, // TODO: Support this
     boolean flash_attn,
-    int cache_type_k,
-    int cache_type_v,
+    String cache_type_k,
+    String cache_type_v,
     boolean use_mlock,
     boolean use_mmap,
     boolean vocab_only,
     String lora,
     float lora_scaled,
+    ReadableArray lora_list,
     float rope_freq_base,
     float rope_freq_scale,
     int pooling_type,
@@ -508,7 +528,6 @@ public class LlamaContext {
     float mirostat,
     float mirostat_tau,
     float mirostat_eta,
-    boolean penalize_nl,
     int top_k,
     float top_p,
     float min_p,
@@ -521,7 +540,7 @@ public class LlamaContext {
     double[][] logit_bias,
     float dry_multiplier,
     float dry_base,
-    int dry_allowed_length,
+    int dry_allowed_length,
     int dry_penalty_last_n,
     String[] dry_sequence_breakers,
     PartialCompletionCallback partial_completion_callback
@@ -537,6 +556,9 @@ public class LlamaContext {
     int embd_normalize
   );
   protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
+  protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters);
+  protected static native void removeLoraAdapters(long contextPtr);
+  protected static native WritableArray getLoadedLoraAdapters(long contextPtr);
   protected static native void freeContext(long contextPtr);
   protected static native void logToAndroid();
 }
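
Taken together, the `LlamaContext.java` changes surface three new or changed init options: `n_ubatch`, string-valued `cache_type_k`/`cache_type_v` (previously int enums defaulting to `1`, now strings defaulting to `"f16"`), and `lora_list` for loading several LoRA adapters at init time. A minimal TypeScript sketch of a call site, assuming the JS option names map 1:1 onto the native parameters above (the authoritative typings live in `package/src/NativeRNLlama.ts`, whose diff is not shown here):

```ts
import { initLlama } from 'llama.rn'

// Illustrative only: option names mirror the native parameters added above.
const context = await initLlama({
  model: 'file://<path to gguf model>',
  n_ctx: 2048,
  n_batch: 512,
  n_ubatch: 512, // new in 1.4.0: micro-batch size
  cache_type_k: 'f16', // now a string (e.g. 'f16', 'q8_0') instead of an int
  cache_type_v: 'f16',
  // new in 1.4.0: multiple LoRA adapters, each with an optional scale
  lora_list: [{ path: 'file://<path to lora gguf>', scaled: 1.0 }],
})
```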
package/android/src/main/java/com/rnllama/RNLlama.java CHANGED
@@ -462,6 +462,104 @@ public class RNLlama implements LifecycleEventListener {
     tasks.put(task, "bench-" + contextId);
   }
 
+  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
+    final int contextId = (int) id;
+    AsyncTask task = new AsyncTask<Void, Void, Void>() {
+      private Exception exception;
+
+      @Override
+      protected Void doInBackground(Void... voids) {
+        try {
+          LlamaContext context = contexts.get(contextId);
+          if (context == null) {
+            throw new Exception("Context not found");
+          }
+          if (context.isPredicting()) {
+            throw new Exception("Context is busy");
+          }
+          context.applyLoraAdapters(loraAdapters);
+        } catch (Exception e) {
+          exception = e;
+        }
+        return null;
+      }
+
+      @Override
+      protected void onPostExecute(Void result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+      }
+    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+    tasks.put(task, "applyLoraAdapters-" + contextId);
+  }
+
+  public void removeLoraAdapters(double id, final Promise promise) {
+    final int contextId = (int) id;
+    AsyncTask task = new AsyncTask<Void, Void, Void>() {
+      private Exception exception;
+
+      @Override
+      protected Void doInBackground(Void... voids) {
+        try {
+          LlamaContext context = contexts.get(contextId);
+          if (context == null) {
+            throw new Exception("Context not found");
+          }
+          if (context.isPredicting()) {
+            throw new Exception("Context is busy");
+          }
+          context.removeLoraAdapters();
+        } catch (Exception e) {
+          exception = e;
+        }
+        return null;
+      }
+
+      @Override
+      protected void onPostExecute(Void result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+        promise.resolve(null);
+      }
+    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+    tasks.put(task, "removeLoraAdapters-" + contextId);
+  }
+
+  public void getLoadedLoraAdapters(double id, final Promise promise) {
+    final int contextId = (int) id;
+    AsyncTask task = new AsyncTask<Void, Void, ReadableArray>() {
+      private Exception exception;
+
+      @Override
+      protected ReadableArray doInBackground(Void... voids) {
+        try {
+          LlamaContext context = contexts.get(contextId);
+          if (context == null) {
+            throw new Exception("Context not found");
+          }
+          return context.getLoadedLoraAdapters();
+        } catch (Exception e) {
+          exception = e;
+        }
+        return null;
+      }
+
+      @Override
+      protected void onPostExecute(ReadableArray result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+        promise.resolve(result);
+      }
+    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+    tasks.put(task, "getLoadedLoraAdapters-" + contextId);
+  }
+
   public void releaseContext(double id, Promise promise) {
     final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, Void>() {
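
These bridge methods back runtime LoRA management on an existing context. Continuing the sketch above (`context` as returned by `initLlama`), and assuming the JS wrapper in `package/src/index.ts` exposes them under the same names as the native methods:

```ts
// Hypothetical usage; method and field names mirror the native bridge above.
await context.applyLoraAdapters([
  { path: 'file://<path to lora gguf>', scaled: 1.0 },
])

const adapters = await context.getLoadedLoraAdapters()
// e.g. [{ path: 'file://<path to lora gguf>', scaled: 1 }]

await context.removeLoraAdapters()
```

Note that `applyLoraAdapters` and `removeLoraAdapters` reject with "Context is busy" while a completion is running, so callers should sequence them between completions.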
package/android/src/main/jni-utils.h ADDED
@@ -0,0 +1,94 @@
+#include <jni.h>
+
+// ReadableMap utils
+
+namespace readablearray {
+
+int size(JNIEnv *env, jobject readableArray) {
+    jclass arrayClass = env->GetObjectClass(readableArray);
+    jmethodID sizeMethod = env->GetMethodID(arrayClass, "size", "()I");
+    return env->CallIntMethod(readableArray, sizeMethod);
+}
+
+jobject getMap(JNIEnv *env, jobject readableArray, int index) {
+    jclass arrayClass = env->GetObjectClass(readableArray);
+    jmethodID getMapMethod = env->GetMethodID(arrayClass, "getMap", "(I)Lcom/facebook/react/bridge/ReadableMap;");
+    return env->CallObjectMethod(readableArray, getMapMethod, index);
+}
+
+// Other methods not used yet
+
+}
+
+namespace readablemap {
+
+bool hasKey(JNIEnv *env, jobject readableMap, const char *key) {
+    jclass mapClass = env->GetObjectClass(readableMap);
+    jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
+    jstring jKey = env->NewStringUTF(key);
+    jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey);
+    env->DeleteLocalRef(jKey);
+    return result;
+}
+
+int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
+    if (!hasKey(env, readableMap, key)) {
+        return defaultValue;
+    }
+    jclass mapClass = env->GetObjectClass(readableMap);
+    jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I");
+    jstring jKey = env->NewStringUTF(key);
+    jint result = env->CallIntMethod(readableMap, getIntMethod, jKey);
+    env->DeleteLocalRef(jKey);
+    return result;
+}
+
+bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
+    if (!hasKey(env, readableMap, key)) {
+        return defaultValue;
+    }
+    jclass mapClass = env->GetObjectClass(readableMap);
+    jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z");
+    jstring jKey = env->NewStringUTF(key);
+    jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey);
+    env->DeleteLocalRef(jKey);
+    return result;
+}
+
+long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
+    if (!hasKey(env, readableMap, key)) {
+        return defaultValue;
+    }
+    jclass mapClass = env->GetObjectClass(readableMap);
+    jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J");
+    jstring jKey = env->NewStringUTF(key);
+    jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey);
+    env->DeleteLocalRef(jKey);
+    return result;
+}
+
+float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
+    if (!hasKey(env, readableMap, key)) {
+        return defaultValue;
+    }
+    jclass mapClass = env->GetObjectClass(readableMap);
+    jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D");
+    jstring jKey = env->NewStringUTF(key);
+    jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey);
+    env->DeleteLocalRef(jKey);
+    return result;
+}
+
+jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) {
+    if (!hasKey(env, readableMap, key)) {
+        return defaultValue;
+    }
+    jclass mapClass = env->GetObjectClass(readableMap);
+    jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
+    jstring jKey = env->NewStringUTF(key);
+    jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey);
+    env->DeleteLocalRef(jKey);
+    return result;
+}
+
+}
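
These helpers give the JNI layer typed reads (with defaults) out of React Native's `ReadableMap`/`ReadableArray`. In TypeScript terms, the map shape they end up consuming for LoRA adapters (see the `lora_list` handling in `jni.cpp` below) is roughly the following; this type is illustrative, not copied from the package's typings:

```ts
// Field names taken from the readablemap::getString/getFloat calls in jni.cpp.
type LoraAdapterItem = {
  path: string // entries without a path are skipped on the native side
  scaled?: number // optional scale; the native default is 1.0
}
```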
package/android/src/main/jni.cpp CHANGED
@@ -11,15 +11,17 @@
 #include <unordered_map>
 #include "llama.h"
 #include "llama-impl.h"
-#include "ggml.h"
+#include "llama-context.h"
+#include "gguf.h"
 #include "rn-llama.hpp"
+#include "jni-utils.h"
 
 #define UNUSED(x) (void)(x)
 #define TAG "RNLLAMA_ANDROID_JNI"
 
 #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
 #define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
-
+#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
 static inline int min(int a, int b) {
     return (a < b) ? a : b;
 }
@@ -128,7 +130,7 @@ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
 // Method to push WritableMap into WritableArray
 static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
     jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
-    jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/WritableMap;)V");
+    jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
 
     env->CallVoidMethod(arr, pushMapMethod, value);
 }
@@ -198,7 +200,7 @@ Java_com_rnllama_LlamaContext_modelInfo(
             continue;
         }
 
-        const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
+        const std::string value = lm_gguf_kv_to_str(ctx, i);
         putString(env, info, key, value.c_str());
     }
 }
@@ -233,16 +235,18 @@ Java_com_rnllama_LlamaContext_initContext(
     jint embd_normalize,
     jint n_ctx,
     jint n_batch,
+    jint n_ubatch,
    jint n_threads,
     jint n_gpu_layers, // TODO: Support this
     jboolean flash_attn,
-    jint cache_type_k,
-    jint cache_type_v,
+    jstring cache_type_k,
+    jstring cache_type_v,
     jboolean use_mlock,
     jboolean use_mmap,
     jboolean vocab_only,
     jstring lora_str,
     jfloat lora_scaled,
+    jobject lora_list,
     jfloat rope_freq_base,
     jfloat rope_freq_scale,
     jint pooling_type,
@@ -262,6 +266,7 @@ Java_com_rnllama_LlamaContext_initContext(
 
     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
+    defaultParams.n_ubatch = n_ubatch;
 
     if (pooling_type != -1) {
         defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
@@ -284,19 +289,14 @@ Java_com_rnllama_LlamaContext_initContext(
     // defaultParams.n_gpu_layers = n_gpu_layers;
     defaultParams.flash_attn = flash_attn;
 
-    // const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
-    // const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
-    defaultParams.cache_type_k = (lm_ggml_type) cache_type_k;
-    defaultParams.cache_type_v = (lm_ggml_type) cache_type_v;
+    const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
+    const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
+    defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
+    defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);
 
     defaultParams.use_mlock = use_mlock;
     defaultParams.use_mmap = use_mmap;
 
-    const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
-    if (lora_chars != nullptr && lora_chars[0] != '\0') {
-        defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
-    }
-
     defaultParams.rope_freq_base = rope_freq_base;
     defaultParams.rope_freq_scale = rope_freq_scale;
 
@@ -330,20 +330,52 @@ Java_com_rnllama_LlamaContext_initContext(
     bool is_model_loaded = llama->loadModel(defaultParams);
 
     env->ReleaseStringUTFChars(model_path_str, model_path_chars);
-    env->ReleaseStringUTFChars(lora_str, lora_chars);
-    // env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
-    // env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
+    env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
+    env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
 
     LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
     if (is_model_loaded) {
-      if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
-          LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
-          llama_free(llama->ctx);
-          return -1;
-      }
-      context_map[(long) llama->ctx] = llama;
+        if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
+            LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
+            llama_free(llama->ctx);
+            return -1;
+        }
+        context_map[(long) llama->ctx] = llama;
     } else {
+        llama_free(llama->ctx);
+    }
+
+    std::vector<common_lora_adapter_info> lora;
+    const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
+    if (lora_chars != nullptr && lora_chars[0] != '\0') {
+        common_lora_adapter_info la;
+        la.path = lora_chars;
+        la.scale = lora_scaled;
+        lora.push_back(la);
+    }
+
+    if (lora_list != nullptr) {
+        // lora_adapters: ReadableArray<ReadableMap>
+        int lora_list_size = readablearray::size(env, lora_list);
+        for (int i = 0; i < lora_list_size; i++) {
+            jobject lora_adapter = readablearray::getMap(env, lora_list, i);
+            jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
+            if (path != nullptr) {
+                const char *path_chars = env->GetStringUTFChars(path, nullptr);
+                common_lora_adapter_info la;
+                la.path = path_chars;
+                la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
+                lora.push_back(la);
+                env->ReleaseStringUTFChars(path, path_chars);
+            }
+        }
+    }
+    env->ReleaseStringUTFChars(lora_str, lora_chars);
+    int result = llama->applyLoraAdapters(lora);
+    if (result != 0) {
+        LOGI("[RNLlama] Failed to apply lora adapters");
        llama_free(llama->ctx);
+        return -1;
     }
 
     return reinterpret_cast<jlong>(llama->ctx);
@@ -532,7 +564,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jfloat mirostat,
     jfloat mirostat_tau,
     jfloat mirostat_eta,
-    jboolean penalize_nl,
     jint top_k,
     jfloat top_p,
     jfloat min_p,
@@ -545,7 +576,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jobjectArray logit_bias,
     jfloat dry_multiplier,
     jfloat dry_base,
-    jint dry_allowed_length,
+    jint dry_allowed_length,
     jint dry_penalty_last_n,
     jobjectArray dry_sequence_breakers,
     jobject partial_completion_callback
@@ -577,7 +608,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.mirostat = mirostat;
     sparams.mirostat_tau = mirostat_tau;
     sparams.mirostat_eta = mirostat_eta;
-    // sparams.penalize_nl = penalize_nl;
     sparams.top_k = top_k;
     sparams.top_p = top_p;
     sparams.min_p = min_p;
@@ -884,23 +914,64 @@ Java_com_rnllama_LlamaContext_bench(
     return env->NewStringUTF(result.c_str());
 }
 
+JNIEXPORT jint JNICALL
+Java_com_rnllama_LlamaContext_applyLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+
+    // lora_adapters: ReadableArray<ReadableMap>
+    std::vector<common_lora_adapter_info> lora_adapters;
+    int lora_adapters_size = readablearray::size(env, loraAdapters);
+    for (int i = 0; i < lora_adapters_size; i++) {
+        jobject lora_adapter = readablearray::getMap(env, loraAdapters, i);
+        jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
+        if (path != nullptr) {
+            const char *path_chars = env->GetStringUTFChars(path, nullptr);
+            env->ReleaseStringUTFChars(path, path_chars);
+            float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
+            common_lora_adapter_info la;
+            la.path = path_chars;
+            la.scale = scaled;
+            lora_adapters.push_back(la);
+        }
+    }
+    return llama->applyLoraAdapters(lora_adapters);
+}
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_removeLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(env);
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    llama->removeLoraAdapters();
+}
+
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
+    JNIEnv *env, jobject thiz, jlong context_ptr) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
+    auto result = createWritableArray(env);
+    for (common_lora_adapter_info &la : loaded_lora_adapters) {
+        auto map = createWriteableMap(env);
+        putString(env, map, "path", la.path.c_str());
+        putDouble(env, map, "scaled", la.scale);
+        pushMap(env, result, map);
+    }
+    return result;
+}
+
 JNIEXPORT void JNICALL
 Java_com_rnllama_LlamaContext_freeContext(
     JNIEnv *env, jobject thiz, jlong context_ptr) {
     UNUSED(env);
     UNUSED(thiz);
     auto llama = context_map[(long) context_ptr];
-    if (llama->model) {
-        llama_free_model(llama->model);
-    }
-    if (llama->ctx) {
-        llama_free(llama->ctx);
-    }
-    if (llama->ctx_sampling != nullptr)
-    {
-        common_sampler_free(llama->ctx_sampling);
-    }
     context_map.erase((long) llama->ctx);
+    delete llama;
 }
 
 JNIEXPORT void JNICALL
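
The `freeContext` change above replaces piecemeal frees of the model, context, and sampler with a single `delete llama;`, moving teardown into the context object's destructor. The JS-facing call should be unchanged; a minimal sketch:

```ts
// Same JS call as before; only the native teardown path changed in 1.4.0.
await context.release()
```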