cui-llama.rn 1.3.6 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. package/README.md +22 -1
  2. package/android/src/main/CMakeLists.txt +25 -26
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
  5. package/android/src/main/jni-utils.h +94 -0
  6. package/android/src/main/jni.cpp +132 -62
  7. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
  8. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
  9. package/cpp/common.cpp +1982 -1982
  10. package/cpp/common.h +665 -664
  11. package/cpp/ggml-cpu.c +14122 -14122
  12. package/cpp/ggml-cpu.cpp +627 -627
  13. package/cpp/ggml-metal-impl.h +288 -0
  14. package/cpp/ggml-opt.cpp +854 -0
  15. package/cpp/ggml-opt.h +216 -0
  16. package/cpp/llama-mmap.cpp +589 -589
  17. package/cpp/llama.cpp +12547 -12544
  18. package/cpp/rn-llama.hpp +117 -116
  19. package/cpp/sgemm.h +14 -14
  20. package/ios/RNLlama.mm +47 -0
  21. package/ios/RNLlamaContext.h +3 -1
  22. package/ios/RNLlamaContext.mm +71 -14
  23. package/jest/mock.js +15 -3
  24. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  25. package/lib/commonjs/index.js +33 -37
  26. package/lib/commonjs/index.js.map +1 -1
  27. package/lib/module/NativeRNLlama.js.map +1 -1
  28. package/lib/module/index.js +31 -35
  29. package/lib/module/index.js.map +1 -1
  30. package/lib/typescript/NativeRNLlama.d.ts +26 -6
  31. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  32. package/lib/typescript/index.d.ts +21 -36
  33. package/lib/typescript/index.d.ts.map +1 -1
  34. package/llama-rn.podspec +4 -18
  35. package/package.json +2 -3
  36. package/src/NativeRNLlama.ts +32 -13
  37. package/src/index.ts +52 -47
  38. package/cpp/llama.cpp.rej +0 -23
package/README.md CHANGED
@@ -53,12 +53,23 @@ For get a GGUF model or quantize manually, see [`Prepare and Quantize`](https://
 
  ## Usage
 
+ Load model info only:
+
+ ```js
+ import { loadLlamaModelInfo } from 'llama.rn'
+
+ const modelPath = 'file://<path to gguf model>'
+ console.log('Model Info:', await loadLlamaModelInfo(modelPath))
+ ```
+
+ Initialize a Llama context & do completion:
+
  ```js
  import { initLlama } from 'llama.rn'
 
  // Initial a Llama context with the model (may take a while)
  const context = await initLlama({
- model: 'file://<path to gguf model>',
+ model: modelPath,
  use_mlock: true,
  n_ctx: 2048,
  n_gpu_layers: 1, // > 0: enable Metal on iOS
@@ -318,6 +329,16 @@ Android:
 
  See the [contributing guide](CONTRIBUTING.md) to learn how to contribute to the repository and the development workflow.
 
+ ## Apps using `llama.rn`
+
+ - [BRICKS](https://bricks.tools): Our product for building interactive signage in simple way. We provide LLM functions as Generator LLM/Assistant.
+ - [ChatterUI](https://github.com/Vali-98/ChatterUI): Simple frontend for LLMs built in react-native.
+ - [PocketPal AI](https://github.com/a-ghorbani/pocketpal-ai): An app that brings language models directly to your phone.
+
+ ## Node.js binding
+
+ - [llama.node](https://github.com/mybigday/llama.node): An another Node.js binding of `llama.cpp` but made API same as `llama.rn`.
+
  ## License
 
  MIT
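The README hunk above stops right after the `initLlama` options, before the completion step it mentions. For orientation, here is a minimal sketch of that "do completion" step, assuming the upstream `llama.rn` completion API; the prompt, stop strings, and sampling values are illustrative and not taken from this diff:

```js
// Continues from the `context` returned by initLlama() in the hunk above.
const { text, timings } = await context.completion(
  {
    prompt: 'User: Hello!\nLlama:', // illustrative prompt; format depends on the model
    n_predict: 100,
    stop: ['</s>', 'User:'],        // illustrative stop strings
  },
  (data) => {
    // Partial-completion callback, invoked for each generated token
    const { token } = data
  },
)
console.log('Completion:', text)
console.log('Timings:', timings)
```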
package/android/src/main/CMakeLists.txt CHANGED
@@ -9,45 +9,44 @@ include_directories(${RNLLAMA_LIB_DIR})
 
  set(
  SOURCE_FILES
-
- ${RNLLAMA_LIB_DIR}/common.cpp
- ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
- ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
- ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
- ${RNLLAMA_LIB_DIR}/llama-chat.cpp
- ${RNLLAMA_LIB_DIR}/llama-mmap.cpp
- ${RNLLAMA_LIB_DIR}/llama-context.cpp
- ${RNLLAMA_LIB_DIR}/llama-kv-cache.cpp
- ${RNLLAMA_LIB_DIR}/llama-model-loader.cpp
- ${RNLLAMA_LIB_DIR}/llama-model.cpp
- ${RNLLAMA_LIB_DIR}/llama-batch.cpp
- ${RNLLAMA_LIB_DIR}/llama-arch.cpp
- ${RNLLAMA_LIB_DIR}/llama-cparams.cpp
- ${RNLLAMA_LIB_DIR}/llama-hparams.cpp
- ${RNLLAMA_LIB_DIR}/llama-adapter.cpp
- ${RNLLAMA_LIB_DIR}/llama-impl.cpp
- ${RNLLAMA_LIB_DIR}/log.cpp
- ${RNLLAMA_LIB_DIR}/json.hpp
- ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
-
+ ${RNLLAMA_LIB_DIR}/ggml.c
  ${RNLLAMA_LIB_DIR}/ggml-alloc.c
  ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
  ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
- ${RNLLAMA_LIB_DIR}/ggml.c
- ${RNLLAMA_LIB_DIR}/gguf.cpp
  ${RNLLAMA_LIB_DIR}/ggml-cpu.c
  ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
  ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
- ${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
  ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
+ ${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
+ ${RNLLAMA_LIB_DIR}/ggml-opt.cpp
  ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
  ${RNLLAMA_LIB_DIR}/ggml-quants.c
+ ${RNLLAMA_LIB_DIR}/gguf.cpp
+ ${RNLLAMA_LIB_DIR}/log.cpp
+ ${RNLLAMA_LIB_DIR}/llama-impl.cpp
+ ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
+ ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
+ ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
+ ${RNLLAMA_LIB_DIR}/llama-adapter.cpp
+ ${RNLLAMA_LIB_DIR}/llama-chat.cpp
+ ${RNLLAMA_LIB_DIR}/llama-context.cpp
+ ${RNLLAMA_LIB_DIR}/llama-kv-cache.cpp
+ ${RNLLAMA_LIB_DIR}/llama-arch.cpp
+ ${RNLLAMA_LIB_DIR}/llama-batch.cpp
+ ${RNLLAMA_LIB_DIR}/llama-cparams.cpp
+ ${RNLLAMA_LIB_DIR}/llama-hparams.cpp
+ ${RNLLAMA_LIB_DIR}/llama.cpp
+ ${RNLLAMA_LIB_DIR}/llama-model.cpp
+ ${RNLLAMA_LIB_DIR}/llama-model-loader.cpp
+ ${RNLLAMA_LIB_DIR}/llama-mmap.cpp
+ ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
  ${RNLLAMA_LIB_DIR}/sampling.cpp
  ${RNLLAMA_LIB_DIR}/unicode-data.cpp
  ${RNLLAMA_LIB_DIR}/unicode.cpp
- ${RNLLAMA_LIB_DIR}/llama.cpp
  ${RNLLAMA_LIB_DIR}/sgemm.cpp
+ ${RNLLAMA_LIB_DIR}/common.cpp
  ${RNLLAMA_LIB_DIR}/rn-llama.hpp
+ ${CMAKE_SOURCE_DIR}/jni-utils.h
  ${CMAKE_SOURCE_DIR}/jni.cpp
  )
 
@@ -62,7 +61,7 @@ function(build_library target_name cpu_flags)
 
  target_link_libraries(${target_name} ${LOG_LIB} android)
 
- target_compile_options(${target_name} PRIVATE -pthread ${cpu_flags} -DLM_GGML_USE_CPU -DLM_GGML_USE_CPU_AARCH64)
+ target_compile_options(${target_name} PRIVATE -DLM_GGML_USE_CPU -DLM_GGML_USE_CPU_AARCH64 -pthread ${cpu_flags})
 
  if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
  target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING)
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -108,6 +108,8 @@ public class LlamaContext {
  params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
  // int n_batch,
  params.hasKey("n_batch") ? params.getInt("n_batch") : 512,
+ // int n_ubatch,
+ params.hasKey("n_ubatch") ? params.getInt("n_ubatch") : 512,
  // int n_threads,
  params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
  // int n_gpu_layers, // TODO: Support this
@@ -115,9 +117,9 @@ public class LlamaContext {
  // boolean flash_attn,
  params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
  // String cache_type_k,
- params.hasKey("cache_type_k") ? params.getInt("cache_type_k") : 1,
+ params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
  // String cache_type_v,
- params.hasKey("cache_type_v") ? params.getInt("cache_type_v") : 1,
+ params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
  // boolean use_mlock,
  params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
  // boolean use_mmap,
@@ -128,6 +130,8 @@ public class LlamaContext {
  params.hasKey("lora") ? params.getString("lora") : "",
  // float lora_scaled,
  params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
+ // ReadableArray lora_adapters,
+ params.hasKey("lora_list") ? params.getArray("lora_list") : null,
  // float rope_freq_base,
  params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
  // float rope_freq_scale
@@ -168,7 +172,7 @@ public class LlamaContext {
  WritableMap event = Arguments.createMap();
  event.putInt("contextId", LlamaContext.this.id);
  event.putInt("progress", progress);
- eventEmitter.emit("@RNLlama_onInitContextProgress", event);
+ eventEmitter.emit("@RNLlama_onContextProgress", event);
  }
 
  private static class LoadProgressCallback {
@@ -273,8 +277,6 @@ public class LlamaContext {
  params.hasKey("mirostat_tau") ? (float) params.getDouble("mirostat_tau") : 5.00f,
  // float mirostat_eta,
  params.hasKey("mirostat_eta") ? (float) params.getDouble("mirostat_eta") : 0.10f,
- // boolean penalize_nl,
- params.hasKey("penalize_nl") ? params.getBoolean("penalize_nl") : false,
  // int top_k,
  params.hasKey("top_k") ? params.getInt("top_k") : 40,
  // float top_p,
@@ -359,6 +361,22 @@ public class LlamaContext {
  return bench(this.context, pp, tg, pl, nr);
  }
 
+ public int applyLoraAdapters(ReadableArray loraAdapters) {
+ int result = applyLoraAdapters(this.context, loraAdapters);
+ if (result != 0) {
+ throw new IllegalStateException("Failed to apply lora adapters");
+ }
+ return result;
+ }
+
+ public void removeLoraAdapters() {
+ removeLoraAdapters(this.context);
+ }
+
+ public WritableArray getLoadedLoraAdapters() {
+ return getLoadedLoraAdapters(this.context);
+ }
+
  public void release() {
  freeContext(context);
  }
@@ -460,16 +478,18 @@ public class LlamaContext {
  int embd_normalize,
  int n_ctx,
  int n_batch,
+ int n_ubatch,
  int n_threads,
  int n_gpu_layers, // TODO: Support this
  boolean flash_attn,
- int cache_type_k,
- int cache_type_v,
+ String cache_type_k,
+ String cache_type_v,
  boolean use_mlock,
  boolean use_mmap,
  boolean vocab_only,
  String lora,
  float lora_scaled,
+ ReadableArray lora_list,
  float rope_freq_base,
  float rope_freq_scale,
  int pooling_type,
@@ -508,7 +528,6 @@ public class LlamaContext {
  float mirostat,
  float mirostat_tau,
  float mirostat_eta,
- boolean penalize_nl,
  int top_k,
  float top_p,
  float min_p,
@@ -521,7 +540,7 @@ public class LlamaContext {
  double[][] logit_bias,
  float dry_multiplier,
  float dry_base,
- int dry_allowed_length,
+ int dry_allowed_length,
  int dry_penalty_last_n,
  String[] dry_sequence_breakers,
  PartialCompletionCallback partial_completion_callback
@@ -537,6 +556,9 @@ public class LlamaContext {
  int embd_normalize
  );
  protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
+ protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters);
+ protected static native void removeLoraAdapters(long contextPtr);
+ protected static native WritableArray getLoadedLoraAdapters(long contextPtr);
  protected static native void freeContext(long contextPtr);
  protected static native void logToAndroid();
  }
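Seen from JavaScript, the native signature changes above (new `n_ubatch`, string-typed `cache_type_k`/`cache_type_v`, and the `lora_list` array) correspond to context-init options. A minimal sketch, assuming the JS option names mirror the native parameter names in this hunk and that each `lora_list` entry carries `path`/`scaled` keys (the entry shape is not shown in this diff):

```js
import { initLlama } from 'llama.rn'

const context = await initLlama({
  model: 'file://<path to gguf model>',
  n_ctx: 2048,
  n_batch: 512,
  n_ubatch: 512,        // new: micro-batch size, defaults to 512 on the native side
  cache_type_k: 'f16',  // now a string (the old int-typed parameter defaulted to 1)
  cache_type_v: 'f16',
  lora_list: [
    // assumed entry shape: one adapter path plus a per-adapter scale
    { path: 'file://<path to lora adapter>', scaled: 1.0 },
  ],
})
```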
package/android/src/main/java/com/rnllama/RNLlama.java CHANGED
@@ -462,6 +462,104 @@ public class RNLlama implements LifecycleEventListener {
  tasks.put(task, "bench-" + contextId);
  }
 
+ public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
+ final int contextId = (int) id;
+ AsyncTask task = new AsyncTask<Void, Void, Void>() {
+ private Exception exception;
+
+ @Override
+ protected Void doInBackground(Void... voids) {
+ try {
+ LlamaContext context = contexts.get(contextId);
+ if (context == null) {
+ throw new Exception("Context not found");
+ }
+ if (context.isPredicting()) {
+ throw new Exception("Context is busy");
+ }
+ context.applyLoraAdapters(loraAdapters);
+ } catch (Exception e) {
+ exception = e;
+ }
+ return null;
+ }
+
+ @Override
+ protected void onPostExecute(Void result) {
+ if (exception != null) {
+ promise.reject(exception);
+ return;
+ }
+ }
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+ tasks.put(task, "applyLoraAdapters-" + contextId);
+ }
+
+ public void removeLoraAdapters(double id, final Promise promise) {
+ final int contextId = (int) id;
+ AsyncTask task = new AsyncTask<Void, Void, Void>() {
+ private Exception exception;
+
+ @Override
+ protected Void doInBackground(Void... voids) {
+ try {
+ LlamaContext context = contexts.get(contextId);
+ if (context == null) {
+ throw new Exception("Context not found");
+ }
+ if (context.isPredicting()) {
+ throw new Exception("Context is busy");
+ }
+ context.removeLoraAdapters();
+ } catch (Exception e) {
+ exception = e;
+ }
+ return null;
+ }
+
+ @Override
+ protected void onPostExecute(Void result) {
+ if (exception != null) {
+ promise.reject(exception);
+ return;
+ }
+ promise.resolve(null);
+ }
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+ tasks.put(task, "removeLoraAdapters-" + contextId);
+ }
+
+ public void getLoadedLoraAdapters(double id, final Promise promise) {
+ final int contextId = (int) id;
+ AsyncTask task = new AsyncTask<Void, Void, ReadableArray>() {
+ private Exception exception;
+
+ @Override
+ protected ReadableArray doInBackground(Void... voids) {
+ try {
+ LlamaContext context = contexts.get(contextId);
+ if (context == null) {
+ throw new Exception("Context not found");
+ }
+ return context.getLoadedLoraAdapters();
+ } catch (Exception e) {
+ exception = e;
+ }
+ return null;
+ }
+
+ @Override
+ protected void onPostExecute(ReadableArray result) {
+ if (exception != null) {
+ promise.reject(exception);
+ return;
+ }
+ promise.resolve(result);
+ }
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+ tasks.put(task, "getLoadedLoraAdapters-" + contextId);
+ }
+
  public void releaseContext(double id, Promise promise) {
  final int contextId = (int) id;
  AsyncTask task = new AsyncTask<Void, Void, Void>() {
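These bridge methods reject the promise when the context is missing or busy and otherwise delegate to the `LlamaContext` helpers shown earlier. A rough sketch of how they might be driven from JS, assuming the context object exposes methods named after the bridge functions (the JS-side names live in `src/index.ts`, which is not shown in this diff):

```js
// Hypothetical JS-side usage; method and key names assumed to mirror the bridge methods above.
await context.applyLoraAdapters([
  { path: 'file://<path to lora adapter>', scaled: 1.0 },
])

const loaded = await context.getLoadedLoraAdapters()
console.log('Loaded LoRA adapters:', loaded)

await context.removeLoraAdapters()
```

Because the native side throws "Context is busy" while predicting, adapters should be applied or removed only when no completion is in flight.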
package/android/src/main/jni-utils.h ADDED
@@ -0,0 +1,94 @@
+ #include <jni.h>
+
+ // ReadableMap utils
+
+ namespace readablearray {
+
+ int size(JNIEnv *env, jobject readableArray) {
+ jclass arrayClass = env->GetObjectClass(readableArray);
+ jmethodID sizeMethod = env->GetMethodID(arrayClass, "size", "()I");
+ return env->CallIntMethod(readableArray, sizeMethod);
+ }
+
+ jobject getMap(JNIEnv *env, jobject readableArray, int index) {
+ jclass arrayClass = env->GetObjectClass(readableArray);
+ jmethodID getMapMethod = env->GetMethodID(arrayClass, "getMap", "(I)Lcom/facebook/react/bridge/ReadableMap;");
+ return env->CallObjectMethod(readableArray, getMapMethod, index);
+ }
+
+ // Other methods not used yet
+
+ }
+
+ namespace readablemap {
+
+ bool hasKey(JNIEnv *env, jobject readableMap, const char *key) {
+ jclass mapClass = env->GetObjectClass(readableMap);
+ jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
+ jstring jKey = env->NewStringUTF(key);
+ jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey);
+ env->DeleteLocalRef(jKey);
+ return result;
+ }
+
+ int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
+ if (!hasKey(env, readableMap, key)) {
+ return defaultValue;
+ }
+ jclass mapClass = env->GetObjectClass(readableMap);
+ jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I");
+ jstring jKey = env->NewStringUTF(key);
+ jint result = env->CallIntMethod(readableMap, getIntMethod, jKey);
+ env->DeleteLocalRef(jKey);
+ return result;
+ }
+
+ bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
+ if (!hasKey(env, readableMap, key)) {
+ return defaultValue;
+ }
+ jclass mapClass = env->GetObjectClass(readableMap);
+ jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z");
+ jstring jKey = env->NewStringUTF(key);
+ jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey);
+ env->DeleteLocalRef(jKey);
+ return result;
+ }
+
+ long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
+ if (!hasKey(env, readableMap, key)) {
+ return defaultValue;
+ }
+ jclass mapClass = env->GetObjectClass(readableMap);
+ jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J");
+ jstring jKey = env->NewStringUTF(key);
+ jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey);
+ env->DeleteLocalRef(jKey);
+ return result;
+ }
+
+ float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
+ if (!hasKey(env, readableMap, key)) {
+ return defaultValue;
+ }
+ jclass mapClass = env->GetObjectClass(readableMap);
+ jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D");
+ jstring jKey = env->NewStringUTF(key);
+ jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey);
+ env->DeleteLocalRef(jKey);
+ return result;
+ }
+
+ jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) {
+ if (!hasKey(env, readableMap, key)) {
+ return defaultValue;
+ }
+ jclass mapClass = env->GetObjectClass(readableMap);
+ jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
+ jstring jKey = env->NewStringUTF(key);
+ jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey);
+ env->DeleteLocalRef(jKey);
+ return result;
+ }
+
+ }