cui-llama.rn 1.2.4 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -4
- package/android/src/main/CMakeLists.txt +21 -5
- package/android/src/main/java/com/rnllama/LlamaContext.java +115 -30
- package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
- package/android/src/main/jni.cpp +222 -36
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/cpp/common.cpp +1682 -2122
- package/cpp/common.h +600 -594
- package/cpp/ggml-aarch64.c +129 -3209
- package/cpp/ggml-aarch64.h +19 -39
- package/cpp/ggml-alloc.c +1040 -1040
- package/cpp/ggml-alloc.h +76 -76
- package/cpp/ggml-backend-impl.h +216 -227
- package/cpp/ggml-backend-reg.cpp +195 -0
- package/cpp/ggml-backend.cpp +1997 -2625
- package/cpp/ggml-backend.h +328 -326
- package/cpp/ggml-common.h +1853 -1853
- package/cpp/ggml-cpp.h +38 -0
- package/cpp/ggml-cpu-aarch64.c +3560 -0
- package/cpp/ggml-cpu-aarch64.h +30 -0
- package/cpp/ggml-cpu-impl.h +371 -614
- package/cpp/ggml-cpu-quants.c +10822 -0
- package/cpp/ggml-cpu-quants.h +63 -0
- package/cpp/ggml-cpu.c +13975 -0
- package/cpp/ggml-cpu.cpp +663 -0
- package/cpp/ggml-cpu.h +177 -0
- package/cpp/ggml-impl.h +550 -209
- package/cpp/ggml-metal.h +66 -66
- package/cpp/ggml-metal.m +4294 -3819
- package/cpp/ggml-quants.c +5247 -15752
- package/cpp/ggml-quants.h +100 -147
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +12 -0
- package/cpp/ggml.c +8180 -23464
- package/cpp/ggml.h +2411 -2562
- package/cpp/llama-grammar.cpp +1138 -1138
- package/cpp/llama-grammar.h +144 -144
- package/cpp/llama-impl.h +181 -181
- package/cpp/llama-sampling.cpp +2348 -2194
- package/cpp/llama-sampling.h +48 -30
- package/cpp/llama-vocab.cpp +1984 -1968
- package/cpp/llama-vocab.h +170 -165
- package/cpp/llama.cpp +22132 -21969
- package/cpp/llama.h +1253 -1253
- package/cpp/log.cpp +401 -401
- package/cpp/log.h +121 -121
- package/cpp/rn-llama.hpp +83 -19
- package/cpp/sampling.cpp +466 -458
- package/cpp/sgemm.cpp +1884 -1219
- package/ios/RNLlama.mm +43 -20
- package/ios/RNLlamaContext.h +9 -3
- package/ios/RNLlamaContext.mm +133 -33
- package/jest/mock.js +0 -1
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +52 -15
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +51 -15
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +29 -6
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +12 -5
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +41 -7
- package/src/index.ts +82 -27
- package/cpp/json-schema-to-grammar.cpp +0 -1045
- package/cpp/json-schema-to-grammar.h +0 -8
- package/cpp/json.hpp +0 -24766
package/README.md
CHANGED
```diff
@@ -6,15 +6,14 @@ This fork exists to update llama.cpp on a more frequent basis, plus adding usefu
 
 The following features have been added for Android:
 
-- Updated sync for llama.cpp
 - Added stopping prompt processing between batches, vital for mobile devices with very slow prompt processing
 - `vocab_only` mode: utilize the llama.cpp tokenizer
 - tokenizeSync: non-blocking, synchronous tokenizer function
 - Context Shift taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)
-- XTC sampling
-- Progress callback
 - Retrieving CPU Features to check for i8mm and dotprod flags
 
+There is no IOS implementation for these features.
+
 Original repo README.md below.
 
 # llama.rn
@@ -307,7 +306,7 @@ iOS:
 
 - The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
 - Metal:
-  - We have tested to know some devices is not able to use Metal (
+  - We have tested to know some devices is not able to use Metal (GPU) due to llama.cpp used SIMD-scoped operation, you can check if your device is supported in [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf), Apple7 GPU will be the minimum requirement.
   - It's also not supported in iOS simulator due to [this limitation](https://developer.apple.com/documentation/metal/developing_metal_apps_that_run_in_simulator#3241609), we used constant buffers more than 14.
 
 Android:
```
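As a quick orientation before the remaining file-by-file diffs, here is a minimal sketch of how the fork's Android-only features might be driven from JavaScript. Only the native signatures further below are confirmed by this diff; `initLlama`, `tokenizeSync`, and the result shape are assumptions based on the llama.rn API this package extends.

```typescript
// Sketch only, not confirmed API: assumes the JS wrapper in package/src/index.ts
// mirrors the native bindings shown in this diff.
import { initLlama } from 'cui-llama.rn'

async function tokenizeOnly(modelPath: string) {
  // vocab_only loads just the tokenizer, skipping the model weights
  // (the "utilize the llama.cpp tokenizer" feature from the list above).
  const context = await initLlama({ model: modelPath, vocab_only: true })

  // tokenizeSync: the synchronous tokenizer added by this fork (Android-only).
  const { tokens } = context.tokenizeSync('Hello world')
  console.log(`token count: ${tokens.length}`)

  await context.release()
}
```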
package/android/src/main/CMakeLists.txt
CHANGED
```diff
@@ -14,21 +14,28 @@ set(
     ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
     ${RNLLAMA_LIB_DIR}/log.cpp
 
+    ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
+    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
+    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
+    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
+    ${RNLLAMA_LIB_DIR}/log.cpp
+
     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/ggml-alloc.c
     ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
+    ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
     ${RNLLAMA_LIB_DIR}/ggml.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
+    ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
     ${RNLLAMA_LIB_DIR}/ggml-quants.c
     ${RNLLAMA_LIB_DIR}/common.cpp
-    ${RNLLAMA_LIB_DIR}/json.hpp
-    ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
     ${RNLLAMA_LIB_DIR}/sampling.cpp
     ${RNLLAMA_LIB_DIR}/unicode-data.cpp
     ${RNLLAMA_LIB_DIR}/unicode.cpp
     ${RNLLAMA_LIB_DIR}/llama.cpp
-    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
-    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
-    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
     ${RNLLAMA_LIB_DIR}/sgemm.cpp
     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/rn-llama.hpp
@@ -71,11 +78,20 @@ build_library("rnllama" "")
 
 if (${ANDROID_ABI} STREQUAL "arm64-v8a")
     # ARM64 targets
+    build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve")
+    build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve")
    build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
     build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
     build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
     build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
     build_library("rnllama_v8" "-march=armv8-a")
+
+    # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
+    # llama.cpp will deal with the cpu features
+    # build_library("rnllama_v8_7" "-march=armv8.7-a")
+    # TODO: Add support runtime check for cpu features
+    # At the moment runtime check is failing.
+
 elseif (${ANDROID_ABI} STREQUAL "x86_64")
     # x86_64 target
     build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt")
```
package/android/src/main/java/com/rnllama/LlamaContext.java
CHANGED
```diff
@@ -93,7 +93,7 @@ public class LlamaContext {
         Log.e(NAME, "Failed to convert to FD!");
       }
     }
-
+    logToAndroid();
     // Check if file has GGUF magic numbers
     this.id = id;
     eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
@@ -102,6 +102,8 @@ public class LlamaContext {
       modelName,
       // boolean embedding,
       params.hasKey("embedding") ? params.getBoolean("embedding") : false,
+      // int embd_normalize,
+      params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1,
       // int n_ctx,
       params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
       // int n_batch,
@@ -110,6 +112,12 @@ public class LlamaContext {
       params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
       // int n_gpu_layers, // TODO: Support this
       params.hasKey("n_gpu_layers") ? params.getInt("n_gpu_layers") : 0,
+      // boolean flash_attn,
+      params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
+      // String cache_type_k,
+      params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
+      // String cache_type_v,
+      params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
       // boolean use_mlock,
       params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
       // boolean use_mmap,
@@ -124,12 +132,22 @@ public class LlamaContext {
       params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
       // float rope_freq_scale
       params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
-
+      // int pooling_type,
+      params.hasKey("pooling_type") ? params.getInt("pooling_type") : -1,
+      // LoadProgressCallback load_progress_callback
+      params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
     );
+    if (this.context == -1) {
+      throw new IllegalStateException("Failed to initialize context");
+    }
     this.modelDetails = loadModelDetails(this.context);
     this.reactContext = reactContext;
   }
 
+  public void interruptLoad() {
+    interruptLoad(this.context);
+  }
+
   public long getContext() {
     return context;
   }
@@ -146,6 +164,25 @@ public class LlamaContext {
     return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
   }
 
+  private void emitLoadProgress(int progress) {
+    WritableMap event = Arguments.createMap();
+    event.putInt("contextId", LlamaContext.this.id);
+    event.putInt("progress", progress);
+    eventEmitter.emit("@RNLlama_onInitContextProgress", event);
+  }
+
+  private static class LoadProgressCallback {
+    LlamaContext context;
+
+    public LoadProgressCallback(LlamaContext context) {
+      this.context = context;
+    }
+
+    void onLoadProgress(int progress) {
+      context.emitLoadProgress(progress);
+    }
+  }
+
   private void emitPartialCompletion(WritableMap tokenResult) {
     WritableMap event = Arguments.createMap();
     event.putInt("contextId", LlamaContext.this.id);
```
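The hunks above thread a native `LoadProgressCallback` through `initContext` and surface it as a `@RNLlama_onInitContextProgress` event keyed by `contextId`. A hedged sketch of the JS side, assuming the wrapper subscribes to that event and exposes it as a callback argument (the second-argument form and the `use_progress_callback` pass-through are assumptions, not confirmed API):

```typescript
// Sketch: consuming the new load-progress path added in this hunk.
import { initLlama } from 'cui-llama.rn'

const context = await initLlama(
  {
    model: '/path/to/model.gguf',
    n_ctx: 2048,
    use_progress_callback: true, // makes the native side emit progress events
  },
  (progress: number) => console.log(`model load: ${progress}%`), // assumed wrapper API
)
```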
```diff
@@ -244,12 +281,10 @@ public class LlamaContext {
       params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
       // float min_p,
       params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
-      // float
-      params.hasKey("
-      // float
-      params.hasKey("
-      // float tfs_z,
-      params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f,
+      // float xtc_threshold,
+      params.hasKey("xtc_threshold") ? (float) params.getDouble("xtc_threshold") : 0.00f,
+      // float xtc_probability,
+      params.hasKey("xtc_probability") ? (float) params.getDouble("xtc_probability") : 0.00f,
       // float typical_p,
       params.hasKey("typical_p") ? (float) params.getDouble("typical_p") : 1.00f,
       // int seed,
@@ -260,6 +295,16 @@ public class LlamaContext {
       params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false,
       // double[][] logit_bias,
       logit_bias,
+      // float dry_multiplier,
+      params.hasKey("dry_multiplier") ? (float) params.getDouble("dry_multiplier") : 0.00f,
+      // float dry_base,
+      params.hasKey("dry_base") ? (float) params.getDouble("dry_base") : 1.75f,
+      // int dry_allowed_length,
+      params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2,
+      // int dry_penalty_last_n,
+      params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1,
+      // String[] dry_sequence_breakers, when undef, we use the default definition from common.h
+      params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
       // PartialCompletionCallback partial_completion_callback
       new PartialCompletionCallback(
         this,
```
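The defaults above also document how the new samplers are switched off: `xtc_threshold`, `xtc_probability`, and `dry_multiplier` all default to `0.00f`, so XTC and DRY are inert unless explicitly enabled. A sketch of a completion call that turns both on, continuing from the `context` created in the earlier sketch (JS-to-native parameter pass-through is assumed):

```typescript
// Sketch: enabling the XTC and DRY samplers, using the native defaults
// above as a reference.
const result = await context.completion({
  prompt: 'Once upon a time',
  n_predict: 128,
  // XTC: inert while xtc_threshold/xtc_probability are 0.0 (the defaults)
  xtc_threshold: 0.1,
  xtc_probability: 0.5,
  // DRY: inert while dry_multiplier is 0.0 (the default)
  dry_multiplier: 0.8,
  dry_base: 1.75,
  dry_allowed_length: 2,
  dry_penalty_last_n: -1, // -1: consider the whole context
  dry_sequence_breakers: ['\n', ':', '"', '*'], // the common.h defaults
})
console.log(result.text)
```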
```diff
@@ -294,11 +339,16 @@ public class LlamaContext {
     return detokenize(this.context, toks);
   }
 
-  public WritableMap getEmbedding(String text) {
+  public WritableMap getEmbedding(String text, ReadableMap params) {
     if (isEmbeddingEnabled(this.context) == false) {
       throw new IllegalStateException("Embedding is not enabled");
     }
-    WritableMap result = embedding(
+    WritableMap result = embedding(
+      this.context,
+      text,
+      // int embd_normalize,
+      params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1
+    );
     if (result.hasKey("error")) {
       throw new IllegalStateException(result.getString("error"));
     }
```
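`getEmbedding` now accepts a params map so normalization can be chosen per call; `-1` defers to the context-level `embd_normalize` set at init. A short sketch, continuing the same `context` (the value convention, e.g. `2` for L2/Euclidean normalization, follows llama.cpp's `embd_normalize` and is an assumption rather than something this diff confirms):

```typescript
// Sketch: per-call embedding normalization via the new params argument.
const { embedding } = await context.embedding('The quick brown fox', {
  embd_normalize: 2, // assumed llama.cpp convention: 2 = L2 (Euclidean) norm
})
console.log(embedding.length) // embedding dimension
```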
```diff
@@ -315,17 +365,31 @@ public class LlamaContext {
 
   static {
     Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
-    if (LlamaContext.isArm64V8a()) {
-      String cpuFeatures = LlamaContext.getCpuFeatures();
-      Log.d(NAME, "CPU features: " + cpuFeatures);
-
-      boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
-      boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
-      boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
-      boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
-      boolean hasInt8Matmul = cpuFeatures.contains("i8mm");
 
-
+    String cpuFeatures = LlamaContext.getCpuFeatures();
+    Log.d(NAME, "CPU features: " + cpuFeatures);
+    boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
+    boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
+    boolean hasSve = cpuFeatures.contains("sve");
+    boolean hasI8mm = cpuFeatures.contains("i8mm");
+    boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
+    boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
+    Log.d(NAME, "- hasFp16: " + hasFp16);
+    Log.d(NAME, "- hasDotProd: " + hasDotProd);
+    Log.d(NAME, "- hasSve: " + hasSve);
+    Log.d(NAME, "- hasI8mm: " + hasI8mm);
+    Log.d(NAME, "- isAtLeastArmV82: " + isAtLeastArmV82);
+    Log.d(NAME, "- isAtLeastArmV84: " + isAtLeastArmV84);
+
+    // TODO: Add runtime check for cpu features
+    if (LlamaContext.isArm64V8a()) {
+      if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve");
+      } else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve");
+      } else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) {
         Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
         System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
       } else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
@@ -341,14 +405,16 @@ public class LlamaContext {
         Log.d(NAME, "Loading librnllama_v8.so");
         System.loadLibrary("rnllama_v8");
       }
+      // Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
+      // System.loadLibrary("rnllama_v8_7");
     } else if (LlamaContext.isX86_64()) {
-
-
+      Log.d(NAME, "Loading librnllama_x86_64.so");
+      System.loadLibrary("rnllama_x86_64");
     } else {
-
-
+      Log.d(NAME, "Loading default librnllama.so");
+      System.loadLibrary("rnllama");
     }
-
+  }
 
   public static boolean isArm64V8a() {
     return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
@@ -384,13 +450,21 @@ public class LlamaContext {
     eventEmitter.emit("@RNLlama_onModelProgress", event);
   }
 
+  protected static native WritableMap modelInfo(
+    String model,
+    String[] skip
+  );
   protected static native long initContext(
     String model,
     boolean embedding,
+    int embd_normalize,
     int n_ctx,
     int n_batch,
     int n_threads,
     int n_gpu_layers, // TODO: Support this
+    boolean flash_attn,
+    String cache_type_k,
+    String cache_type_v,
     boolean use_mlock,
     boolean use_mmap,
     boolean vocab_only,
@@ -398,8 +472,10 @@ public class LlamaContext {
     float lora_scaled,
     float rope_freq_base,
     float rope_freq_scale,
-
+    int pooling_type,
+    LoadProgressCallback load_progress_callback
   );
+  protected static native void interruptLoad(long contextPtr);
   protected static native WritableMap loadModelDetails(
     long contextPtr
   );
@@ -436,14 +512,18 @@ public class LlamaContext {
     int top_k,
     float top_p,
     float min_p,
-    float
-    float
-    float tfs_z,
+    float xtc_threshold,
+    float xtc_probability,
     float typical_p,
     int seed,
     String[] stop,
     boolean ignore_eos,
     double[][] logit_bias,
+    float dry_multiplier,
+    float dry_base,
+    int dry_allowed_length,
+    int dry_penalty_last_n,
+    String[] dry_sequence_breakers,
     PartialCompletionCallback partial_completion_callback
   );
   protected static native void stopCompletion(long contextPtr);
@@ -451,7 +531,12 @@ public class LlamaContext {
   protected static native WritableArray tokenize(long contextPtr, String text);
   protected static native String detokenize(long contextPtr, int[] tokens);
   protected static native boolean isEmbeddingEnabled(long contextPtr);
-  protected static native WritableMap embedding(
+  protected static native WritableMap embedding(
+    long contextPtr,
+    String text,
+    int embd_normalize
+  );
  protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
   protected static native void freeContext(long contextPtr);
+  protected static native void logToAndroid();
 }
```
package/android/src/main/java/com/rnllama/RNLlama.java
CHANGED
```diff
@@ -42,21 +42,53 @@ public class RNLlama implements LifecycleEventListener {
     promise.resolve(null);
   }
 
-  public void
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    new AsyncTask<Void, Void, WritableMap>() {
+      private Exception exception;
+
+      @Override
+      protected WritableMap doInBackground(Void... voids) {
+        try {
+          String[] skipArray = new String[skip.size()];
+          for (int i = 0; i < skip.size(); i++) {
+            skipArray[i] = skip.getString(i);
+          }
+          return LlamaContext.modelInfo(model, skipArray);
+        } catch (Exception e) {
+          exception = e;
+        }
+        return null;
+      }
+
+      @Override
+      protected void onPostExecute(WritableMap result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+        promise.resolve(result);
+      }
+    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+  }
+
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
       private Exception exception;
 
       @Override
       protected WritableMap doInBackground(Void... voids) {
         try {
-
-
+          LlamaContext context = contexts.get(contextId);
+          if (context != null) {
+            throw new Exception("Context already exists");
+          }
+          LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
           if (llamaContext.getContext() == 0) {
             throw new Exception("Failed to initialize context");
           }
-          contexts.put(
+          contexts.put(contextId, llamaContext);
           WritableMap result = Arguments.createMap();
-          result.putInt("contextId", id);
           result.putBoolean("gpu", false);
           result.putString("reasonNoGPU", "Currently not supported");
           result.putMap("model", llamaContext.getModelDetails());
@@ -366,7 +398,7 @@ public class RNLlama implements LifecycleEventListener {
     tasks.put(task, "detokenize-" + contextId);
   }
 
-  public void embedding(double id, final String text, final Promise promise) {
+  public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
     final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
       private Exception exception;
@@ -378,7 +410,7 @@ public class RNLlama implements LifecycleEventListener {
           if (context == null) {
             throw new Exception("Context not found");
           }
-          return context.getEmbedding(text);
+          return context.getEmbedding(text, params);
         } catch (Exception e) {
           exception = e;
         }
@@ -442,6 +474,7 @@ public class RNLlama implements LifecycleEventListener {
         if (context == null) {
           throw new Exception("Context " + id + " not found");
         }
+        context.interruptLoad();
         context.stopCompletion();
         AsyncTask completionTask = null;
         for (AsyncTask task : tasks.keySet()) {
```