cui-llama.rn 1.3.5 → 1.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +22 -1
- package/android/src/main/CMakeLists.txt +25 -20
- package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
- package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
- package/android/src/main/jni-utils.h +94 -0
- package/android/src/main/jni.cpp +108 -37
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
- package/cpp/common.cpp +1982 -1965
- package/cpp/common.h +665 -657
- package/cpp/ggml-backend-reg.cpp +5 -0
- package/cpp/ggml-backend.cpp +5 -2
- package/cpp/ggml-cpp.h +1 -0
- package/cpp/ggml-cpu-aarch64.cpp +6 -1
- package/cpp/ggml-cpu-quants.c +5 -1
- package/cpp/ggml-cpu.c +14122 -14122
- package/cpp/ggml-cpu.cpp +627 -627
- package/cpp/ggml-impl.h +11 -16
- package/cpp/ggml-metal-impl.h +288 -0
- package/cpp/ggml-metal.m +2 -2
- package/cpp/ggml-opt.cpp +854 -0
- package/cpp/ggml-opt.h +216 -0
- package/cpp/ggml.c +0 -1276
- package/cpp/ggml.h +0 -140
- package/cpp/gguf.cpp +1325 -0
- package/cpp/gguf.h +202 -0
- package/cpp/llama-adapter.cpp +346 -0
- package/cpp/llama-adapter.h +73 -0
- package/cpp/llama-arch.cpp +1434 -0
- package/cpp/llama-arch.h +395 -0
- package/cpp/llama-batch.cpp +368 -0
- package/cpp/llama-batch.h +88 -0
- package/cpp/llama-chat.cpp +567 -0
- package/cpp/llama-chat.h +51 -0
- package/cpp/llama-context.cpp +1771 -0
- package/cpp/llama-context.h +128 -0
- package/cpp/llama-cparams.cpp +1 -0
- package/cpp/llama-cparams.h +37 -0
- package/cpp/llama-cpp.h +30 -0
- package/cpp/llama-grammar.cpp +1 -0
- package/cpp/llama-grammar.h +3 -1
- package/cpp/llama-hparams.cpp +71 -0
- package/cpp/llama-hparams.h +140 -0
- package/cpp/llama-impl.cpp +167 -0
- package/cpp/llama-impl.h +16 -136
- package/cpp/llama-kv-cache.cpp +718 -0
- package/cpp/llama-kv-cache.h +218 -0
- package/cpp/llama-mmap.cpp +589 -0
- package/cpp/llama-mmap.h +67 -0
- package/cpp/llama-model-loader.cpp +1011 -0
- package/cpp/llama-model-loader.h +158 -0
- package/cpp/llama-model.cpp +2202 -0
- package/cpp/llama-model.h +391 -0
- package/cpp/llama-sampling.cpp +117 -4
- package/cpp/llama-vocab.cpp +21 -28
- package/cpp/llama-vocab.h +13 -1
- package/cpp/llama.cpp +12547 -23528
- package/cpp/llama.h +31 -6
- package/cpp/rn-llama.hpp +90 -87
- package/cpp/sgemm.cpp +776 -70
- package/cpp/sgemm.h +14 -14
- package/cpp/unicode.cpp +6 -0
- package/ios/RNLlama.mm +47 -0
- package/ios/RNLlamaContext.h +3 -1
- package/ios/RNLlamaContext.mm +71 -14
- package/jest/mock.js +15 -3
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +33 -37
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +31 -35
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +26 -6
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +21 -36
- package/lib/typescript/index.d.ts.map +1 -1
- package/llama-rn.podspec +4 -18
- package/package.json +2 -3
- package/src/NativeRNLlama.ts +32 -13
- package/src/index.ts +52 -47
package/README.md
CHANGED
@@ -53,12 +53,23 @@ For get a GGUF model or quantize manually, see [`Prepare and Quantize`](https://
|
|
53
53
|
|
54
54
|
## Usage
|
55
55
|
|
56
|
+
Load model info only:
|
57
|
+
|
58
|
+
```js
|
59
|
+
import { loadLlamaModelInfo } from 'llama.rn'
|
60
|
+
|
61
|
+
const modelPath = 'file://<path to gguf model>'
|
62
|
+
console.log('Model Info:', await loadLlamaModelInfo(modelPath))
|
63
|
+
```
|
64
|
+
|
65
|
+
Initialize a Llama context & do completion:
|
66
|
+
|
56
67
|
```js
|
57
68
|
import { initLlama } from 'llama.rn'
|
58
69
|
|
59
70
|
// Initial a Llama context with the model (may take a while)
|
60
71
|
const context = await initLlama({
|
61
|
-
model:
|
72
|
+
model: modelPath,
|
62
73
|
use_mlock: true,
|
63
74
|
n_ctx: 2048,
|
64
75
|
n_gpu_layers: 1, // > 0: enable Metal on iOS
|
@@ -318,6 +329,16 @@ Android:
|
|
318
329
|
|
319
330
|
See the [contributing guide](CONTRIBUTING.md) to learn how to contribute to the repository and the development workflow.
|
320
331
|
|
332
|
+
## Apps using `llama.rn`
|
333
|
+
|
334
|
+
- [BRICKS](https://bricks.tools): Our product for building interactive signage in a simple way. We provide LLM functions as Generator LLM/Assistant.
|
335
|
+
- [ChatterUI](https://github.com/Vali-98/ChatterUI): Simple frontend for LLMs built in react-native.
|
336
|
+
- [PocketPal AI](https://github.com/a-ghorbani/pocketpal-ai): An app that brings language models directly to your phone.
|
337
|
+
|
338
|
+
## Node.js binding
|
339
|
+
|
340
|
+
- [llama.node](https://github.com/mybigday/llama.node): Another Node.js binding for `llama.cpp`, with the same API as `llama.rn`.
|
341
|
+
|
321
342
|
## License
|
322
343
|
|
323
344
|
MIT
|
@@ -9,39 +9,44 @@ include_directories(${RNLLAMA_LIB_DIR})
|
|
9
9
|
|
10
10
|
set(
|
11
11
|
SOURCE_FILES
|
12
|
-
${RNLLAMA_LIB_DIR}/
|
13
|
-
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
|
14
|
-
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
|
15
|
-
${RNLLAMA_LIB_DIR}/log.cpp
|
16
|
-
|
17
|
-
#${RNLLAMA_LIB_DIR}/amx/amx.cpp
|
18
|
-
#${RNLLAMA_LIB_DIR}/amx/mmq.cpp
|
19
|
-
|
20
|
-
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
|
21
|
-
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
|
22
|
-
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
|
23
|
-
${RNLLAMA_LIB_DIR}/log.cpp
|
24
|
-
${RNLLAMA_LIB_DIR}/json.hpp
|
25
|
-
${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
|
26
|
-
|
12
|
+
${RNLLAMA_LIB_DIR}/ggml.c
|
27
13
|
${RNLLAMA_LIB_DIR}/ggml-alloc.c
|
28
14
|
${RNLLAMA_LIB_DIR}/ggml-backend.cpp
|
29
15
|
${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
|
30
|
-
${RNLLAMA_LIB_DIR}/ggml.c
|
31
16
|
${RNLLAMA_LIB_DIR}/ggml-cpu.c
|
32
17
|
${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
|
33
18
|
${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
|
34
|
-
${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
|
35
19
|
${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
|
20
|
+
${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
|
21
|
+
${RNLLAMA_LIB_DIR}/ggml-opt.cpp
|
36
22
|
${RNLLAMA_LIB_DIR}/ggml-threading.cpp
|
37
23
|
${RNLLAMA_LIB_DIR}/ggml-quants.c
|
38
|
-
${RNLLAMA_LIB_DIR}/
|
24
|
+
${RNLLAMA_LIB_DIR}/gguf.cpp
|
25
|
+
${RNLLAMA_LIB_DIR}/log.cpp
|
26
|
+
${RNLLAMA_LIB_DIR}/llama-impl.cpp
|
27
|
+
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
|
28
|
+
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
|
29
|
+
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
|
30
|
+
${RNLLAMA_LIB_DIR}/llama-adapter.cpp
|
31
|
+
${RNLLAMA_LIB_DIR}/llama-chat.cpp
|
32
|
+
${RNLLAMA_LIB_DIR}/llama-context.cpp
|
33
|
+
${RNLLAMA_LIB_DIR}/llama-kv-cache.cpp
|
34
|
+
${RNLLAMA_LIB_DIR}/llama-arch.cpp
|
35
|
+
${RNLLAMA_LIB_DIR}/llama-batch.cpp
|
36
|
+
${RNLLAMA_LIB_DIR}/llama-cparams.cpp
|
37
|
+
${RNLLAMA_LIB_DIR}/llama-hparams.cpp
|
38
|
+
${RNLLAMA_LIB_DIR}/llama.cpp
|
39
|
+
${RNLLAMA_LIB_DIR}/llama-model.cpp
|
40
|
+
${RNLLAMA_LIB_DIR}/llama-model-loader.cpp
|
41
|
+
${RNLLAMA_LIB_DIR}/llama-mmap.cpp
|
42
|
+
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
|
39
43
|
${RNLLAMA_LIB_DIR}/sampling.cpp
|
40
44
|
${RNLLAMA_LIB_DIR}/unicode-data.cpp
|
41
45
|
${RNLLAMA_LIB_DIR}/unicode.cpp
|
42
|
-
${RNLLAMA_LIB_DIR}/llama.cpp
|
43
46
|
${RNLLAMA_LIB_DIR}/sgemm.cpp
|
47
|
+
${RNLLAMA_LIB_DIR}/common.cpp
|
44
48
|
${RNLLAMA_LIB_DIR}/rn-llama.hpp
|
49
|
+
${CMAKE_SOURCE_DIR}/jni-utils.h
|
45
50
|
${CMAKE_SOURCE_DIR}/jni.cpp
|
46
51
|
)
|
47
52
|
|
@@ -56,7 +61,7 @@ function(build_library target_name cpu_flags)
|
|
56
61
|
|
57
62
|
target_link_libraries(${target_name} ${LOG_LIB} android)
|
58
63
|
|
59
|
-
target_compile_options(${target_name} PRIVATE -pthread ${cpu_flags}
|
64
|
+
target_compile_options(${target_name} PRIVATE -DLM_GGML_USE_CPU -DLM_GGML_USE_CPU_AARCH64 -pthread ${cpu_flags})
|
60
65
|
|
61
66
|
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
62
67
|
target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING)
|
@@ -108,6 +108,8 @@ public class LlamaContext {
|
|
108
108
|
params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
|
109
109
|
// int n_batch,
|
110
110
|
params.hasKey("n_batch") ? params.getInt("n_batch") : 512,
|
111
|
+
// int n_ubatch,
|
112
|
+
params.hasKey("n_ubatch") ? params.getInt("n_ubatch") : 512,
|
111
113
|
// int n_threads,
|
112
114
|
params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
|
113
115
|
// int n_gpu_layers, // TODO: Support this
|
@@ -115,9 +117,9 @@ public class LlamaContext {
|
|
115
117
|
// boolean flash_attn,
|
116
118
|
params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
|
117
119
|
// String cache_type_k,
|
118
|
-
params.hasKey("cache_type_k") ? params.
|
120
|
+
params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
|
119
121
|
// String cache_type_v,
|
120
|
-
params.hasKey("cache_type_v") ? params.
|
122
|
+
params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
|
121
123
|
// boolean use_mlock,
|
122
124
|
params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
|
123
125
|
// boolean use_mmap,
|
@@ -128,6 +130,8 @@ public class LlamaContext {
|
|
128
130
|
params.hasKey("lora") ? params.getString("lora") : "",
|
129
131
|
// float lora_scaled,
|
130
132
|
params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
|
133
|
+
// ReadableArray lora_adapters,
|
134
|
+
params.hasKey("lora_list") ? params.getArray("lora_list") : null,
|
131
135
|
// float rope_freq_base,
|
132
136
|
params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
|
133
137
|
// float rope_freq_scale
|
@@ -168,7 +172,7 @@ public class LlamaContext {
|
|
168
172
|
WritableMap event = Arguments.createMap();
|
169
173
|
event.putInt("contextId", LlamaContext.this.id);
|
170
174
|
event.putInt("progress", progress);
|
171
|
-
eventEmitter.emit("@
|
175
|
+
eventEmitter.emit("@RNLlama_onContextProgress", event);
|
172
176
|
}
|
173
177
|
|
174
178
|
private static class LoadProgressCallback {
|
@@ -273,8 +277,6 @@ public class LlamaContext {
|
|
273
277
|
params.hasKey("mirostat_tau") ? (float) params.getDouble("mirostat_tau") : 5.00f,
|
274
278
|
// float mirostat_eta,
|
275
279
|
params.hasKey("mirostat_eta") ? (float) params.getDouble("mirostat_eta") : 0.10f,
|
276
|
-
// boolean penalize_nl,
|
277
|
-
params.hasKey("penalize_nl") ? params.getBoolean("penalize_nl") : false,
|
278
280
|
// int top_k,
|
279
281
|
params.hasKey("top_k") ? params.getInt("top_k") : 40,
|
280
282
|
// float top_p,
|
@@ -359,6 +361,22 @@ public class LlamaContext {
|
|
359
361
|
return bench(this.context, pp, tg, pl, nr);
|
360
362
|
}
|
361
363
|
|
364
|
+
public int applyLoraAdapters(ReadableArray loraAdapters) {
|
365
|
+
int result = applyLoraAdapters(this.context, loraAdapters);
|
366
|
+
if (result != 0) {
|
367
|
+
throw new IllegalStateException("Failed to apply lora adapters");
|
368
|
+
}
|
369
|
+
return result;
|
370
|
+
}
|
371
|
+
|
372
|
+
public void removeLoraAdapters() {
|
373
|
+
removeLoraAdapters(this.context);
|
374
|
+
}
|
375
|
+
|
376
|
+
public WritableArray getLoadedLoraAdapters() {
|
377
|
+
return getLoadedLoraAdapters(this.context);
|
378
|
+
}
|
379
|
+
|
362
380
|
public void release() {
|
363
381
|
freeContext(context);
|
364
382
|
}
|
@@ -460,16 +478,18 @@ public class LlamaContext {
|
|
460
478
|
int embd_normalize,
|
461
479
|
int n_ctx,
|
462
480
|
int n_batch,
|
481
|
+
int n_ubatch,
|
463
482
|
int n_threads,
|
464
483
|
int n_gpu_layers, // TODO: Support this
|
465
484
|
boolean flash_attn,
|
466
|
-
|
467
|
-
|
485
|
+
String cache_type_k,
|
486
|
+
String cache_type_v,
|
468
487
|
boolean use_mlock,
|
469
488
|
boolean use_mmap,
|
470
489
|
boolean vocab_only,
|
471
490
|
String lora,
|
472
491
|
float lora_scaled,
|
492
|
+
ReadableArray lora_list,
|
473
493
|
float rope_freq_base,
|
474
494
|
float rope_freq_scale,
|
475
495
|
int pooling_type,
|
@@ -508,7 +528,6 @@ public class LlamaContext {
|
|
508
528
|
float mirostat,
|
509
529
|
float mirostat_tau,
|
510
530
|
float mirostat_eta,
|
511
|
-
boolean penalize_nl,
|
512
531
|
int top_k,
|
513
532
|
float top_p,
|
514
533
|
float min_p,
|
@@ -521,7 +540,7 @@ public class LlamaContext {
|
|
521
540
|
double[][] logit_bias,
|
522
541
|
float dry_multiplier,
|
523
542
|
float dry_base,
|
524
|
-
int dry_allowed_length,
|
543
|
+
int dry_allowed_length,
|
525
544
|
int dry_penalty_last_n,
|
526
545
|
String[] dry_sequence_breakers,
|
527
546
|
PartialCompletionCallback partial_completion_callback
|
@@ -537,6 +556,9 @@ public class LlamaContext {
|
|
537
556
|
int embd_normalize
|
538
557
|
);
|
539
558
|
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
|
559
|
+
protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters);
|
560
|
+
protected static native void removeLoraAdapters(long contextPtr);
|
561
|
+
protected static native WritableArray getLoadedLoraAdapters(long contextPtr);
|
540
562
|
protected static native void freeContext(long contextPtr);
|
541
563
|
protected static native void logToAndroid();
|
542
564
|
}
|
@@ -462,6 +462,104 @@ public class RNLlama implements LifecycleEventListener {
|
|
462
462
|
tasks.put(task, "bench-" + contextId);
|
463
463
|
}
|
464
464
|
|
465
|
+
public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
|
466
|
+
final int contextId = (int) id;
|
467
|
+
AsyncTask task = new AsyncTask<Void, Void, Void>() {
|
468
|
+
private Exception exception;
|
469
|
+
|
470
|
+
@Override
|
471
|
+
protected Void doInBackground(Void... voids) {
|
472
|
+
try {
|
473
|
+
LlamaContext context = contexts.get(contextId);
|
474
|
+
if (context == null) {
|
475
|
+
throw new Exception("Context not found");
|
476
|
+
}
|
477
|
+
if (context.isPredicting()) {
|
478
|
+
throw new Exception("Context is busy");
|
479
|
+
}
|
480
|
+
context.applyLoraAdapters(loraAdapters);
|
481
|
+
} catch (Exception e) {
|
482
|
+
exception = e;
|
483
|
+
}
|
484
|
+
return null;
|
485
|
+
}
|
486
|
+
|
487
|
+
@Override
|
488
|
+
protected void onPostExecute(Void result) {
|
489
|
+
if (exception != null) {
|
490
|
+
promise.reject(exception);
|
491
|
+
return;
|
492
|
+
}
|
493
|
+
}
|
494
|
+
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
|
495
|
+
tasks.put(task, "applyLoraAdapters-" + contextId);
|
496
|
+
}
|
497
|
+
|
498
|
+
public void removeLoraAdapters(double id, final Promise promise) {
|
499
|
+
final int contextId = (int) id;
|
500
|
+
AsyncTask task = new AsyncTask<Void, Void, Void>() {
|
501
|
+
private Exception exception;
|
502
|
+
|
503
|
+
@Override
|
504
|
+
protected Void doInBackground(Void... voids) {
|
505
|
+
try {
|
506
|
+
LlamaContext context = contexts.get(contextId);
|
507
|
+
if (context == null) {
|
508
|
+
throw new Exception("Context not found");
|
509
|
+
}
|
510
|
+
if (context.isPredicting()) {
|
511
|
+
throw new Exception("Context is busy");
|
512
|
+
}
|
513
|
+
context.removeLoraAdapters();
|
514
|
+
} catch (Exception e) {
|
515
|
+
exception = e;
|
516
|
+
}
|
517
|
+
return null;
|
518
|
+
}
|
519
|
+
|
520
|
+
@Override
|
521
|
+
protected void onPostExecute(Void result) {
|
522
|
+
if (exception != null) {
|
523
|
+
promise.reject(exception);
|
524
|
+
return;
|
525
|
+
}
|
526
|
+
promise.resolve(null);
|
527
|
+
}
|
528
|
+
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
|
529
|
+
tasks.put(task, "removeLoraAdapters-" + contextId);
|
530
|
+
}
|
531
|
+
|
532
|
+
public void getLoadedLoraAdapters(double id, final Promise promise) {
|
533
|
+
final int contextId = (int) id;
|
534
|
+
AsyncTask task = new AsyncTask<Void, Void, ReadableArray>() {
|
535
|
+
private Exception exception;
|
536
|
+
|
537
|
+
@Override
|
538
|
+
protected ReadableArray doInBackground(Void... voids) {
|
539
|
+
try {
|
540
|
+
LlamaContext context = contexts.get(contextId);
|
541
|
+
if (context == null) {
|
542
|
+
throw new Exception("Context not found");
|
543
|
+
}
|
544
|
+
return context.getLoadedLoraAdapters();
|
545
|
+
} catch (Exception e) {
|
546
|
+
exception = e;
|
547
|
+
}
|
548
|
+
return null;
|
549
|
+
}
|
550
|
+
|
551
|
+
@Override
|
552
|
+
protected void onPostExecute(ReadableArray result) {
|
553
|
+
if (exception != null) {
|
554
|
+
promise.reject(exception);
|
555
|
+
return;
|
556
|
+
}
|
557
|
+
promise.resolve(result);
|
558
|
+
}
|
559
|
+
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
|
560
|
+
tasks.put(task, "getLoadedLoraAdapters-" + contextId);
|
561
|
+
}
|
562
|
+
|
465
563
|
public void releaseContext(double id, Promise promise) {
|
466
564
|
final int contextId = (int) id;
|
467
565
|
AsyncTask task = new AsyncTask<Void, Void, Void>() {
|
@@ -0,0 +1,94 @@
|
|
1
|
+
#include <jni.h>
|
2
|
+
|
3
|
+
// ReadableMap utils
|
4
|
+
|
5
|
+
namespace readablearray {
|
6
|
+
|
7
|
+
int size(JNIEnv *env, jobject readableArray) {
|
8
|
+
jclass arrayClass = env->GetObjectClass(readableArray);
|
9
|
+
jmethodID sizeMethod = env->GetMethodID(arrayClass, "size", "()I");
|
10
|
+
return env->CallIntMethod(readableArray, sizeMethod);
|
11
|
+
}
|
12
|
+
|
13
|
+
jobject getMap(JNIEnv *env, jobject readableArray, int index) {
|
14
|
+
jclass arrayClass = env->GetObjectClass(readableArray);
|
15
|
+
jmethodID getMapMethod = env->GetMethodID(arrayClass, "getMap", "(I)Lcom/facebook/react/bridge/ReadableMap;");
|
16
|
+
return env->CallObjectMethod(readableArray, getMapMethod, index);
|
17
|
+
}
|
18
|
+
|
19
|
+
// Other methods not used yet
|
20
|
+
|
21
|
+
}
|
22
|
+
|
23
|
+
namespace readablemap {
|
24
|
+
|
25
|
+
bool hasKey(JNIEnv *env, jobject readableMap, const char *key) {
|
26
|
+
jclass mapClass = env->GetObjectClass(readableMap);
|
27
|
+
jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
|
28
|
+
jstring jKey = env->NewStringUTF(key);
|
29
|
+
jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey);
|
30
|
+
env->DeleteLocalRef(jKey);
|
31
|
+
return result;
|
32
|
+
}
|
33
|
+
|
34
|
+
int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
|
35
|
+
if (!hasKey(env, readableMap, key)) {
|
36
|
+
return defaultValue;
|
37
|
+
}
|
38
|
+
jclass mapClass = env->GetObjectClass(readableMap);
|
39
|
+
jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I");
|
40
|
+
jstring jKey = env->NewStringUTF(key);
|
41
|
+
jint result = env->CallIntMethod(readableMap, getIntMethod, jKey);
|
42
|
+
env->DeleteLocalRef(jKey);
|
43
|
+
return result;
|
44
|
+
}
|
45
|
+
|
46
|
+
bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
|
47
|
+
if (!hasKey(env, readableMap, key)) {
|
48
|
+
return defaultValue;
|
49
|
+
}
|
50
|
+
jclass mapClass = env->GetObjectClass(readableMap);
|
51
|
+
jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z");
|
52
|
+
jstring jKey = env->NewStringUTF(key);
|
53
|
+
jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey);
|
54
|
+
env->DeleteLocalRef(jKey);
|
55
|
+
return result;
|
56
|
+
}
|
57
|
+
|
58
|
+
long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
|
59
|
+
if (!hasKey(env, readableMap, key)) {
|
60
|
+
return defaultValue;
|
61
|
+
}
|
62
|
+
jclass mapClass = env->GetObjectClass(readableMap);
|
63
|
+
jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J");
|
64
|
+
jstring jKey = env->NewStringUTF(key);
|
65
|
+
jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey);
|
66
|
+
env->DeleteLocalRef(jKey);
|
67
|
+
return result;
|
68
|
+
}
|
69
|
+
|
70
|
+
float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
|
71
|
+
if (!hasKey(env, readableMap, key)) {
|
72
|
+
return defaultValue;
|
73
|
+
}
|
74
|
+
jclass mapClass = env->GetObjectClass(readableMap);
|
75
|
+
jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D");
|
76
|
+
jstring jKey = env->NewStringUTF(key);
|
77
|
+
jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey);
|
78
|
+
env->DeleteLocalRef(jKey);
|
79
|
+
return result;
|
80
|
+
}
|
81
|
+
|
82
|
+
jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) {
|
83
|
+
if (!hasKey(env, readableMap, key)) {
|
84
|
+
return defaultValue;
|
85
|
+
}
|
86
|
+
jclass mapClass = env->GetObjectClass(readableMap);
|
87
|
+
jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
|
88
|
+
jstring jKey = env->NewStringUTF(key);
|
89
|
+
jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey);
|
90
|
+
env->DeleteLocalRef(jKey);
|
91
|
+
return result;
|
92
|
+
}
|
93
|
+
|
94
|
+
}
|
package/android/src/main/jni.cpp
CHANGED
@@ -11,15 +11,17 @@
|
|
11
11
|
#include <unordered_map>
|
12
12
|
#include "llama.h"
|
13
13
|
#include "llama-impl.h"
|
14
|
-
#include "
|
14
|
+
#include "llama-context.h"
|
15
|
+
#include "gguf.h"
|
15
16
|
#include "rn-llama.hpp"
|
17
|
+
#include "jni-utils.h"
|
16
18
|
|
17
19
|
#define UNUSED(x) (void)(x)
|
18
20
|
#define TAG "RNLLAMA_ANDROID_JNI"
|
19
21
|
|
20
22
|
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
21
23
|
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
22
|
-
|
24
|
+
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
|
23
25
|
// Returns the smaller of the two ints; when they compare equal, b is returned.
static inline int min(int a, int b) {
    if (a < b) {
        return a;
    }
    return b;
}
|
@@ -128,7 +130,7 @@ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
|
|
128
130
|
// Method to push WritableMap into WritableArray
|
129
131
|
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
|
130
132
|
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
|
131
|
-
jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/
|
133
|
+
jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
|
132
134
|
|
133
135
|
env->CallVoidMethod(arr, pushMapMethod, value);
|
134
136
|
}
|
@@ -198,7 +200,7 @@ Java_com_rnllama_LlamaContext_modelInfo(
|
|
198
200
|
continue;
|
199
201
|
}
|
200
202
|
|
201
|
-
const std::string value =
|
203
|
+
const std::string value = lm_gguf_kv_to_str(ctx, i);
|
202
204
|
putString(env, info, key, value.c_str());
|
203
205
|
}
|
204
206
|
}
|
@@ -233,16 +235,18 @@ Java_com_rnllama_LlamaContext_initContext(
|
|
233
235
|
jint embd_normalize,
|
234
236
|
jint n_ctx,
|
235
237
|
jint n_batch,
|
238
|
+
jint n_ubatch,
|
236
239
|
jint n_threads,
|
237
240
|
jint n_gpu_layers, // TODO: Support this
|
238
241
|
jboolean flash_attn,
|
239
|
-
|
240
|
-
|
242
|
+
jstring cache_type_k,
|
243
|
+
jstring cache_type_v,
|
241
244
|
jboolean use_mlock,
|
242
245
|
jboolean use_mmap,
|
243
246
|
jboolean vocab_only,
|
244
247
|
jstring lora_str,
|
245
248
|
jfloat lora_scaled,
|
249
|
+
jobject lora_list,
|
246
250
|
jfloat rope_freq_base,
|
247
251
|
jfloat rope_freq_scale,
|
248
252
|
jint pooling_type,
|
@@ -262,6 +266,7 @@ Java_com_rnllama_LlamaContext_initContext(
|
|
262
266
|
|
263
267
|
defaultParams.n_ctx = n_ctx;
|
264
268
|
defaultParams.n_batch = n_batch;
|
269
|
+
defaultParams.n_ubatch = n_ubatch;
|
265
270
|
|
266
271
|
if (pooling_type != -1) {
|
267
272
|
defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
|
@@ -284,19 +289,14 @@ Java_com_rnllama_LlamaContext_initContext(
|
|
284
289
|
// defaultParams.n_gpu_layers = n_gpu_layers;
|
285
290
|
defaultParams.flash_attn = flash_attn;
|
286
291
|
|
287
|
-
|
288
|
-
|
289
|
-
defaultParams.cache_type_k = (
|
290
|
-
defaultParams.cache_type_v = (
|
292
|
+
const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
|
293
|
+
const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
|
294
|
+
defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
|
295
|
+
defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);
|
291
296
|
|
292
297
|
defaultParams.use_mlock = use_mlock;
|
293
298
|
defaultParams.use_mmap = use_mmap;
|
294
299
|
|
295
|
-
const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
|
296
|
-
if (lora_chars != nullptr && lora_chars[0] != '\0') {
|
297
|
-
defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
|
298
|
-
}
|
299
|
-
|
300
300
|
defaultParams.rope_freq_base = rope_freq_base;
|
301
301
|
defaultParams.rope_freq_scale = rope_freq_scale;
|
302
302
|
|
@@ -330,20 +330,52 @@ Java_com_rnllama_LlamaContext_initContext(
|
|
330
330
|
bool is_model_loaded = llama->loadModel(defaultParams);
|
331
331
|
|
332
332
|
env->ReleaseStringUTFChars(model_path_str, model_path_chars);
|
333
|
-
env->ReleaseStringUTFChars(
|
334
|
-
|
335
|
-
// env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
|
333
|
+
env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
|
334
|
+
env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
|
336
335
|
|
337
336
|
LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
|
338
337
|
if (is_model_loaded) {
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
338
|
+
if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
|
339
|
+
LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
|
340
|
+
llama_free(llama->ctx);
|
341
|
+
return -1;
|
342
|
+
}
|
343
|
+
context_map[(long) llama->ctx] = llama;
|
345
344
|
} else {
|
345
|
+
llama_free(llama->ctx);
|
346
|
+
}
|
347
|
+
|
348
|
+
std::vector<common_lora_adapter_info> lora;
|
349
|
+
const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
|
350
|
+
if (lora_chars != nullptr && lora_chars[0] != '\0') {
|
351
|
+
common_lora_adapter_info la;
|
352
|
+
la.path = lora_chars;
|
353
|
+
la.scale = lora_scaled;
|
354
|
+
lora.push_back(la);
|
355
|
+
}
|
356
|
+
|
357
|
+
if (lora_list != nullptr) {
|
358
|
+
// lora_adapters: ReadableArray<ReadableMap>
|
359
|
+
int lora_list_size = readablearray::size(env, lora_list);
|
360
|
+
for (int i = 0; i < lora_list_size; i++) {
|
361
|
+
jobject lora_adapter = readablearray::getMap(env, lora_list, i);
|
362
|
+
jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
|
363
|
+
if (path != nullptr) {
|
364
|
+
const char *path_chars = env->GetStringUTFChars(path, nullptr);
|
365
|
+
common_lora_adapter_info la;
|
366
|
+
la.path = path_chars;
|
367
|
+
la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
|
368
|
+
lora.push_back(la);
|
369
|
+
env->ReleaseStringUTFChars(path, path_chars);
|
370
|
+
}
|
371
|
+
}
|
372
|
+
}
|
373
|
+
env->ReleaseStringUTFChars(lora_str, lora_chars);
|
374
|
+
int result = llama->applyLoraAdapters(lora);
|
375
|
+
if (result != 0) {
|
376
|
+
LOGI("[RNLlama] Failed to apply lora adapters");
|
346
377
|
llama_free(llama->ctx);
|
378
|
+
return -1;
|
347
379
|
}
|
348
380
|
|
349
381
|
return reinterpret_cast<jlong>(llama->ctx);
|
@@ -532,7 +564,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
532
564
|
jfloat mirostat,
|
533
565
|
jfloat mirostat_tau,
|
534
566
|
jfloat mirostat_eta,
|
535
|
-
jboolean penalize_nl,
|
536
567
|
jint top_k,
|
537
568
|
jfloat top_p,
|
538
569
|
jfloat min_p,
|
@@ -545,7 +576,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
545
576
|
jobjectArray logit_bias,
|
546
577
|
jfloat dry_multiplier,
|
547
578
|
jfloat dry_base,
|
548
|
-
jint dry_allowed_length,
|
579
|
+
jint dry_allowed_length,
|
549
580
|
jint dry_penalty_last_n,
|
550
581
|
jobjectArray dry_sequence_breakers,
|
551
582
|
jobject partial_completion_callback
|
@@ -577,7 +608,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
577
608
|
sparams.mirostat = mirostat;
|
578
609
|
sparams.mirostat_tau = mirostat_tau;
|
579
610
|
sparams.mirostat_eta = mirostat_eta;
|
580
|
-
// sparams.penalize_nl = penalize_nl;
|
581
611
|
sparams.top_k = top_k;
|
582
612
|
sparams.top_p = top_p;
|
583
613
|
sparams.min_p = min_p;
|
@@ -884,23 +914,64 @@ Java_com_rnllama_LlamaContext_bench(
|
|
884
914
|
return env->NewStringUTF(result.c_str());
|
885
915
|
}
|
886
916
|
|
917
|
+
// JNI entry point for LlamaContext.applyLoraAdapters(long, ReadableArray).
// Builds a list of adapter descriptors from the Java array (each element is read as a map
// with a "path" string and an optional "scaled" number defaulting to 1.0) and forwards it
// to the context's applyLoraAdapters(); that call's result is returned to Java.
// Elements without a "path" entry are skipped.
JNIEXPORT jint JNICALL
Java_com_rnllama_LlamaContext_applyLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

    // lora_adapters: ReadableArray<ReadableMap>
    std::vector<common_lora_adapter_info> lora_adapters;
    int lora_adapters_size = readablearray::size(env, loraAdapters);
    for (int i = 0; i < lora_adapters_size; i++) {
        jobject lora_adapter = readablearray::getMap(env, loraAdapters, i);
        jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
        if (path != nullptr) {
            const char *path_chars = env->GetStringUTFChars(path, nullptr);
            if (path_chars != nullptr) {
                common_lora_adapter_info la;
                // Copy the path while the UTF chars are still valid; they must only be
                // released after the copy has been made.
                la.path = path_chars;
                la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
                lora_adapters.push_back(la);
                env->ReleaseStringUTFChars(path, path_chars);
            }
            env->DeleteLocalRef(path);
        }
        // Drop the per-element reference so a long list does not pile up local references.
        env->DeleteLocalRef(lora_adapter);
    }
    return llama->applyLoraAdapters(lora_adapters);
}
|
941
|
+
|
942
|
+
// JNI entry point for LlamaContext.removeLoraAdapters(long).
// Looks up the context registered under `context_ptr` and calls removeLoraAdapters() on it.
JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_removeLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    // Neither JNI argument is needed here.
    UNUSED(env);
    UNUSED(thiz);
    context_map[(long) context_ptr]->removeLoraAdapters();
}
|
950
|
+
|
951
|
+
// JNI entry point for LlamaContext.getLoadedLoraAdapters(long).
// Returns a writable array holding one map per adapter reported by the context, each with
// a "path" string and a "scaled" number.
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
    JNIEnv *env, jobject thiz, jlong context_ptr) {
    UNUSED(thiz);
    auto rn_context = context_map[(long) context_ptr];
    auto adapters = rn_context->getLoadedLoraAdapters();
    auto output = createWritableArray(env);
    for (auto it = adapters.begin(); it != adapters.end(); ++it) {
        common_lora_adapter_info &adapter = *it;
        auto entry = createWriteableMap(env);
        putString(env, entry, "path", adapter.path.c_str());
        putDouble(env, entry, "scaled", adapter.scale);
        pushMap(env, output, entry);
    }
    return output;
}
|
966
|
+
|
887
967
|
JNIEXPORT void JNICALL
|
888
968
|
Java_com_rnllama_LlamaContext_freeContext(
|
889
969
|
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
890
970
|
UNUSED(env);
|
891
971
|
UNUSED(thiz);
|
892
972
|
auto llama = context_map[(long) context_ptr];
|
893
|
-
if (llama->model) {
|
894
|
-
llama_free_model(llama->model);
|
895
|
-
}
|
896
|
-
if (llama->ctx) {
|
897
|
-
llama_free(llama->ctx);
|
898
|
-
}
|
899
|
-
if (llama->ctx_sampling != nullptr)
|
900
|
-
{
|
901
|
-
common_sampler_free(llama->ctx_sampling);
|
902
|
-
}
|
903
973
|
context_map.erase((long) llama->ctx);
|
974
|
+
delete llama;
|
904
975
|
}
|
905
976
|
|
906
977
|
JNIEXPORT void JNICALL
|