cui-llama.rn 1.2.6 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries, and is provided for informational purposes only.
- package/README.md +3 -2
- package/android/src/main/CMakeLists.txt +20 -5
- package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
- package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
- package/android/src/main/jni.cpp +222 -34
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
- package/cpp/common.cpp +1682 -2114
- package/cpp/common.h +600 -613
- package/cpp/ggml-aarch64.c +129 -3478
- package/cpp/ggml-aarch64.h +19 -39
- package/cpp/ggml-alloc.c +1040 -1040
- package/cpp/ggml-alloc.h +76 -76
- package/cpp/ggml-backend-impl.h +216 -216
- package/cpp/ggml-backend-reg.cpp +195 -0
- package/cpp/ggml-backend.cpp +1997 -2661
- package/cpp/ggml-backend.h +328 -314
- package/cpp/ggml-common.h +1853 -1853
- package/cpp/ggml-cpp.h +38 -38
- package/cpp/ggml-cpu-aarch64.c +3560 -0
- package/cpp/ggml-cpu-aarch64.h +30 -0
- package/cpp/ggml-cpu-impl.h +371 -614
- package/cpp/ggml-cpu-quants.c +10822 -0
- package/cpp/ggml-cpu-quants.h +63 -0
- package/cpp/ggml-cpu.c +13975 -13720
- package/cpp/ggml-cpu.cpp +663 -0
- package/cpp/ggml-cpu.h +177 -150
- package/cpp/ggml-impl.h +550 -296
- package/cpp/ggml-metal.h +66 -66
- package/cpp/ggml-metal.m +4294 -3933
- package/cpp/ggml-quants.c +5247 -15739
- package/cpp/ggml-quants.h +100 -147
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +12 -0
- package/cpp/ggml.c +8180 -8390
- package/cpp/ggml.h +2411 -2441
- package/cpp/llama-grammar.cpp +1138 -1138
- package/cpp/llama-grammar.h +144 -144
- package/cpp/llama-impl.h +181 -181
- package/cpp/llama-sampling.cpp +2348 -2345
- package/cpp/llama-sampling.h +48 -48
- package/cpp/llama-vocab.cpp +1984 -1984
- package/cpp/llama-vocab.h +170 -170
- package/cpp/llama.cpp +22132 -22046
- package/cpp/llama.h +1253 -1255
- package/cpp/log.cpp +401 -401
- package/cpp/log.h +121 -121
- package/cpp/rn-llama.hpp +83 -19
- package/cpp/sampling.cpp +466 -466
- package/cpp/sgemm.cpp +1884 -1276
- package/ios/RNLlama.mm +43 -20
- package/ios/RNLlamaContext.h +9 -3
- package/ios/RNLlamaContext.mm +133 -33
- package/jest/mock.js +0 -1
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +52 -15
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +51 -15
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +29 -5
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +12 -5
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +41 -6
- package/src/index.ts +82 -27
- package/cpp/json-schema-to-grammar.cpp +0 -1045
- package/cpp/json-schema-to-grammar.h +0 -8
- package/cpp/json.hpp +0 -24766
package/README.md
CHANGED
@@ -6,13 +6,14 @@ This fork exists to update llama.cpp on a more frequent basis, plus adding useful
 
 The following features have been added for Android:
 
-- Updated sync for llama.cpp
 - Added stopping prompt processing between batches, vital for mobile devices with very slow prompt processing
 - `vocab_only` mode: utilize the llama.cpp tokenizer
 - tokenizeSync: non-blocking, synchronous tokenizer function
 - Context Shift taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)
 - Retrieving CPU Features to check for i8mm and dotprod flags
 
+There is no IOS implementation for these features.
+
 Original repo README.md below.
 
 # llama.rn
@@ -305,7 +306,7 @@ iOS:
 
 - The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
 - Metal:
-  - We have tested to know some devices is not able to use Metal (
+  - We have tested to know some devices is not able to use Metal (GPU) due to llama.cpp used SIMD-scoped operation, you can check if your device is supported in [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf), Apple7 GPU will be the minimum requirement.
   - It's also not supported in iOS simulator due to [this limitation](https://developer.apple.com/documentation/metal/developing_metal_apps_that_run_in_simulator#3241609), we used constant buffers more than 14.
 
 Android:
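Taken together, the feature list above maps onto a small JS surface. A minimal TypeScript sketch of the tokenizer-only workflow, assuming `initLlama` accepts a `vocab_only` key mirroring the native parameter and that `tokenizeSync` returns the same `{ tokens }` shape as the async tokenizer:

```typescript
import { initLlama } from 'cui-llama.rn'

async function countTokens(modelPath: string, text: string) {
  // vocab_only mirrors the native parameter: only the tokenizer is loaded,
  // so this is cheap compared with a full model load (key name assumed).
  const context = await initLlama({ model: modelPath, vocab_only: true })

  // tokenizeSync is the synchronous tokenizer this fork adds; the
  // { tokens } result shape is assumed to match the async tokenize().
  const { tokens } = context.tokenizeSync(text)

  await context.release()
  return tokens.length
}
```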
package/android/src/main/CMakeLists.txt
CHANGED
@@ -14,22 +14,28 @@ set(
     ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
     ${RNLLAMA_LIB_DIR}/log.cpp
 
+    ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
+    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
+    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
+    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
+    ${RNLLAMA_LIB_DIR}/log.cpp
+
     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/ggml-alloc.c
     ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
+    ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
     ${RNLLAMA_LIB_DIR}/ggml.c
     ${RNLLAMA_LIB_DIR}/ggml-cpu.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
+    ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
     ${RNLLAMA_LIB_DIR}/ggml-quants.c
     ${RNLLAMA_LIB_DIR}/common.cpp
-    ${RNLLAMA_LIB_DIR}/json.hpp
-    ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
     ${RNLLAMA_LIB_DIR}/sampling.cpp
     ${RNLLAMA_LIB_DIR}/unicode-data.cpp
     ${RNLLAMA_LIB_DIR}/unicode.cpp
     ${RNLLAMA_LIB_DIR}/llama.cpp
-    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
-    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
-    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
     ${RNLLAMA_LIB_DIR}/sgemm.cpp
     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/rn-llama.hpp
@@ -72,11 +78,20 @@ build_library("rnllama" "")
 
 if (${ANDROID_ABI} STREQUAL "arm64-v8a")
     # ARM64 targets
+    build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve")
+    build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve")
    build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
    build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
    build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
    build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
    build_library("rnllama_v8" "-march=armv8-a")
+
+    # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
+    # llama.cpp will deal with the cpu features
+    # build_library("rnllama_v8_7" "-march=armv8.7-a")
+    # TODO: Add support runtime check for cpu features
+    # At the moment runtime check is failing.
+
 elseif (${ANDROID_ABI} STREQUAL "x86_64")
     # x86_64 target
     build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt")
package/android/src/main/java/com/rnllama/LlamaContext.java
CHANGED
@@ -93,7 +93,7 @@ public class LlamaContext {
        Log.e(NAME, "Failed to convert to FD!");
      }
    }
-
+    logToAndroid();
    // Check if file has GGUF magic numbers
    this.id = id;
    eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
@@ -102,6 +102,8 @@ public class LlamaContext {
      modelName,
      // boolean embedding,
      params.hasKey("embedding") ? params.getBoolean("embedding") : false,
+      // int embd_normalize,
+      params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1,
      // int n_ctx,
      params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
      // int n_batch,
@@ -110,6 +112,12 @@ public class LlamaContext {
      params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
      // int n_gpu_layers, // TODO: Support this
      params.hasKey("n_gpu_layers") ? params.getInt("n_gpu_layers") : 0,
+      // boolean flash_attn,
+      params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
+      // String cache_type_k,
+      params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
+      // String cache_type_v,
+      params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
      // boolean use_mlock,
      params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
      // boolean use_mmap,
@@ -124,12 +132,22 @@ public class LlamaContext {
      params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
      // float rope_freq_scale
      params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
-
+      // int pooling_type,
+      params.hasKey("pooling_type") ? params.getInt("pooling_type") : -1,
+      // LoadProgressCallback load_progress_callback
+      params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
    );
+    if (this.context == -1) {
+      throw new IllegalStateException("Failed to initialize context");
+    }
    this.modelDetails = loadModelDetails(this.context);
    this.reactContext = reactContext;
  }
 
+  public void interruptLoad() {
+    interruptLoad(this.context);
+  }
+
  public long getContext() {
    return context;
  }
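The readout above fixes the new initialization keys and their native defaults: `embd_normalize` (-1), `flash_attn` (false), `cache_type_k`/`cache_type_v` ("f16"), `pooling_type` (-1), and an opt-in load-progress callback. A hedged sketch of supplying them from TypeScript; the progress-callback argument to `initLlama` is an assumption about the JS wrapper:

```typescript
import { initLlama } from 'cui-llama.rn'

async function createContext() {
  return initLlama(
    {
      model: '/data/local/tmp/model.gguf', // hypothetical path
      n_ctx: 2048,
      flash_attn: true,     // boolean flash_attn, defaults to false
      cache_type_k: 'q8_0', // String cache_type_k, defaults to "f16"
      cache_type_v: 'q8_0', // String cache_type_v, defaults to "f16"
      pooling_type: -1,     // int pooling_type, -1 keeps the model default
      embd_normalize: -1,   // int embd_normalize, -1 keeps the native default
    },
    // assumed wrapper signature: supplying a callback is what sets
    // use_progress_callback on the native side
    (progress: number) => console.log(`load progress: ${progress}%`),
  )
}
```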
@@ -146,6 +164,25 @@ public class LlamaContext {
    return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
  }
 
+  private void emitLoadProgress(int progress) {
+    WritableMap event = Arguments.createMap();
+    event.putInt("contextId", LlamaContext.this.id);
+    event.putInt("progress", progress);
+    eventEmitter.emit("@RNLlama_onInitContextProgress", event);
+  }
+
+  private static class LoadProgressCallback {
+    LlamaContext context;
+
+    public LoadProgressCallback(LlamaContext context) {
+      this.context = context;
+    }
+
+    void onLoadProgress(int progress) {
+      context.emitLoadProgress(progress);
+    }
+  }
+
  private void emitPartialCompletion(WritableMap tokenResult) {
    WritableMap event = Arguments.createMap();
    event.putInt("contextId", LlamaContext.this.id);
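`emitLoadProgress` defines the event contract: `@RNLlama_onInitContextProgress` carrying `{ contextId, progress }`. Listening for it directly with React Native's standard emitter looks like this (the package's own wrapper presumably does the equivalent internally):

```typescript
import { NativeEventEmitter, NativeModules } from 'react-native'

const emitter = new NativeEventEmitter(NativeModules.RNLlama)
const subscription = emitter.addListener(
  '@RNLlama_onInitContextProgress', // event name taken from emitLoadProgress()
  (event: { contextId: number; progress: number }) => {
    console.log(`context ${event.contextId} loading: ${event.progress}%`)
  },
)
// subscription.remove() once loading has finished
```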
@@ -244,10 +281,10 @@ public class LlamaContext {
      params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
      // float min_p,
      params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
-      // float
-      params.hasKey("
-      // float
-      params.hasKey("
+      // float xtc_threshold,
+      params.hasKey("xtc_threshold") ? (float) params.getDouble("xtc_threshold") : 0.00f,
+      // float xtc_probability,
+      params.hasKey("xtc_probability") ? (float) params.getDouble("xtc_probability") : 0.00f,
      // float typical_p,
      params.hasKey("typical_p") ? (float) params.getDouble("typical_p") : 1.00f,
      // int seed,
@@ -258,6 +295,16 @@ public class LlamaContext {
      params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false,
      // double[][] logit_bias,
      logit_bias,
+      // float dry_multiplier,
+      params.hasKey("dry_multiplier") ? (float) params.getDouble("dry_multiplier") : 0.00f,
+      // float dry_base,
+      params.hasKey("dry_base") ? (float) params.getDouble("dry_base") : 1.75f,
+      // int dry_allowed_length,
+      params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2,
+      // int dry_penalty_last_n,
+      params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1,
+      // String[] dry_sequence_breakers, when undef, we use the default definition from common.h
+      params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
      // PartialCompletionCallback partial_completion_callback
      new PartialCompletionCallback(
        this,
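This hunk pins down the new sampler fields and defaults: XTC stays disabled while `xtc_threshold` and `xtc_probability` are 0, and DRY engages once `dry_multiplier` is positive, with `dry_penalty_last_n = -1` meaning the whole context. A sketch of passing them through a completion request; forwarding these keys via `context.completion()` is assumed to work like the existing sampler params:

```typescript
import type { LlamaContext } from 'cui-llama.rn' // assumed type export

async function generate(context: LlamaContext) {
  const result = await context.completion(
    {
      prompt: 'Write a haiku about autumn.',
      n_predict: 128,
      xtc_threshold: 0.1,     // 0.00 (default) leaves XTC disabled
      xtc_probability: 0.5,   // chance of applying the XTC cut
      dry_multiplier: 0.8,    // > 0 enables the DRY repetition penalty
      dry_base: 1.75,         // native default
      dry_allowed_length: 2,  // native default
      dry_penalty_last_n: -1, // -1 = consider the whole context
      // omitting dry_sequence_breakers falls back to ["\n", ":", "\"", "*"]
    },
    (data) => console.log(data.token), // streamed partial tokens
  )
  return result.text
}
```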
@@ -292,11 +339,16 @@ public class LlamaContext {
    return detokenize(this.context, toks);
  }
 
-  public WritableMap getEmbedding(String text) {
+  public WritableMap getEmbedding(String text, ReadableMap params) {
    if (isEmbeddingEnabled(this.context) == false) {
      throw new IllegalStateException("Embedding is not enabled");
    }
-    WritableMap result = embedding(
+    WritableMap result = embedding(
+      this.context,
+      text,
+      // int embd_normalize,
+      params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1
+    );
    if (result.hasKey("error")) {
      throw new IllegalStateException(result.getString("error"));
    }
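`getEmbedding` now accepts a params map whose one recognized key, `embd_normalize`, is forwarded to the three-argument native `embedding()`. A sketch of the corresponding JS call; the value semantics (-1 keeps the default, 2 for Euclidean norm) follow llama.cpp's `embd_normalize` convention and are an assumption here:

```typescript
import type { LlamaContext } from 'cui-llama.rn' // assumed type export

async function embed(context: LlamaContext, text: string) {
  // the context must have been created with embedding: true
  const { embedding } = await context.embedding(text, {
    embd_normalize: 2, // assumed: 2 = L2-normalize, matching llama.cpp
  })
  return embedding
}
```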
@@ -313,17 +365,31 @@ public class LlamaContext {
 
  static {
    Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
-    if (LlamaContext.isArm64V8a()) {
-      String cpuFeatures = LlamaContext.getCpuFeatures();
-      Log.d(NAME, "CPU features: " + cpuFeatures);
-
-      boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
-      boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
-      boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
-      boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
-      boolean hasInt8Matmul = cpuFeatures.contains("i8mm");
 
-
+    String cpuFeatures = LlamaContext.getCpuFeatures();
+    Log.d(NAME, "CPU features: " + cpuFeatures);
+    boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
+    boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
+    boolean hasSve = cpuFeatures.contains("sve");
+    boolean hasI8mm = cpuFeatures.contains("i8mm");
+    boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
+    boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
+    Log.d(NAME, "- hasFp16: " + hasFp16);
+    Log.d(NAME, "- hasDotProd: " + hasDotProd);
+    Log.d(NAME, "- hasSve: " + hasSve);
+    Log.d(NAME, "- hasI8mm: " + hasI8mm);
+    Log.d(NAME, "- isAtLeastArmV82: " + isAtLeastArmV82);
+    Log.d(NAME, "- isAtLeastArmV84: " + isAtLeastArmV84);
+
+    // TODO: Add runtime check for cpu features
+    if (LlamaContext.isArm64V8a()) {
+      if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve");
+      } else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve");
+      } else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) {
      Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
      System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
    } else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
@@ -339,14 +405,16 @@ public class LlamaContext {
      Log.d(NAME, "Loading librnllama_v8.so");
      System.loadLibrary("rnllama_v8");
    }
+      // Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
+      // System.loadLibrary("rnllama_v8_7");
    } else if (LlamaContext.isX86_64()) {
-
-
+      Log.d(NAME, "Loading librnllama_x86_64.so");
+      System.loadLibrary("rnllama_x86_64");
    } else {
-
-
+      Log.d(NAME, "Loading default librnllama.so");
+      System.loadLibrary("rnllama");
    }
-
+    }
 
  public static boolean isArm64V8a() {
    return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
@@ -382,13 +450,21 @@ public class LlamaContext {
    eventEmitter.emit("@RNLlama_onModelProgress", event);
  }
 
+  protected static native WritableMap modelInfo(
+    String model,
+    String[] skip
+  );
  protected static native long initContext(
    String model,
    boolean embedding,
+    int embd_normalize,
    int n_ctx,
    int n_batch,
    int n_threads,
    int n_gpu_layers, // TODO: Support this
+    boolean flash_attn,
+    String cache_type_k,
+    String cache_type_v,
    boolean use_mlock,
    boolean use_mmap,
    boolean vocab_only,
@@ -396,8 +472,10 @@ public class LlamaContext {
    float lora_scaled,
    float rope_freq_base,
    float rope_freq_scale,
-
+    int pooling_type,
+    LoadProgressCallback load_progress_callback
  );
+  protected static native void interruptLoad(long contextPtr);
  protected static native WritableMap loadModelDetails(
    long contextPtr
  );
@@ -434,13 +512,18 @@ public class LlamaContext {
    int top_k,
    float top_p,
    float min_p,
-    float
-    float
+    float xtc_threshold,
+    float xtc_probability,
    float typical_p,
    int seed,
    String[] stop,
    boolean ignore_eos,
    double[][] logit_bias,
+    float dry_multiplier,
+    float dry_base,
+    int dry_allowed_length,
+    int dry_penalty_last_n,
+    String[] dry_sequence_breakers,
    PartialCompletionCallback partial_completion_callback
  );
  protected static native void stopCompletion(long contextPtr);
@@ -448,7 +531,12 @@ public class LlamaContext {
  protected static native WritableArray tokenize(long contextPtr, String text);
  protected static native String detokenize(long contextPtr, int[] tokens);
  protected static native boolean isEmbeddingEnabled(long contextPtr);
-  protected static native WritableMap embedding(
+  protected static native WritableMap embedding(
+    long contextPtr,
+    String text,
+    int embd_normalize
+  );
  protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
  protected static native void freeContext(long contextPtr);
+  protected static native void logToAndroid();
 }
package/android/src/main/java/com/rnllama/RNLlama.java
CHANGED
@@ -42,21 +42,53 @@ public class RNLlama implements LifecycleEventListener {
    promise.resolve(null);
  }
 
-  public void
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    new AsyncTask<Void, Void, WritableMap>() {
+      private Exception exception;
+
+      @Override
+      protected WritableMap doInBackground(Void... voids) {
+        try {
+          String[] skipArray = new String[skip.size()];
+          for (int i = 0; i < skip.size(); i++) {
+            skipArray[i] = skip.getString(i);
+          }
+          return LlamaContext.modelInfo(model, skipArray);
+        } catch (Exception e) {
+          exception = e;
+        }
+        return null;
+      }
+
+      @Override
+      protected void onPostExecute(WritableMap result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+        promise.resolve(result);
+      }
+    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+  }
+
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    final int contextId = (int) id;
    AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
      private Exception exception;
 
      @Override
      protected WritableMap doInBackground(Void... voids) {
        try {
-
-
+          LlamaContext context = contexts.get(contextId);
+          if (context != null) {
+            throw new Exception("Context already exists");
+          }
+          LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
          if (llamaContext.getContext() == 0) {
            throw new Exception("Failed to initialize context");
          }
-          contexts.put(
+          contexts.put(contextId, llamaContext);
          WritableMap result = Arguments.createMap();
-          result.putInt("contextId", id);
          result.putBoolean("gpu", false);
          result.putString("reasonNoGPU", "Currently not supported");
          result.putMap("model", llamaContext.getModelDetails());
@@ -366,7 +398,7 @@ public class RNLlama implements LifecycleEventListener {
    tasks.put(task, "detokenize-" + contextId);
  }
 
-  public void embedding(double id, final String text, final Promise promise) {
+  public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
    final int contextId = (int) id;
    AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
      private Exception exception;
@@ -378,7 +410,7 @@ public class RNLlama implements LifecycleEventListener {
        if (context == null) {
          throw new Exception("Context not found");
        }
-        return context.getEmbedding(text);
+        return context.getEmbedding(text, params);
      } catch (Exception e) {
        exception = e;
      }
@@ -442,6 +474,7 @@ public class RNLlama implements LifecycleEventListener {
        if (context == null) {
          throw new Exception("Context " + id + " not found");
        }
+        context.interruptLoad();
        context.stopCompletion();
        AsyncTask completionTask = null;
        for (AsyncTask task : tasks.keySet()) {