cui-llama.rn 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. package/LICENSE +20 -0
  2. package/README.md +330 -0
  3. package/android/build.gradle +107 -0
  4. package/android/gradle.properties +5 -0
  5. package/android/src/main/AndroidManifest.xml +4 -0
  6. package/android/src/main/CMakeLists.txt +69 -0
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +353 -0
  8. package/android/src/main/java/com/rnllama/RNLlama.java +446 -0
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -0
  10. package/android/src/main/jni.cpp +635 -0
  11. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +94 -0
  12. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +95 -0
  13. package/cpp/README.md +4 -0
  14. package/cpp/common.cpp +3237 -0
  15. package/cpp/common.h +467 -0
  16. package/cpp/ggml-aarch64.c +2193 -0
  17. package/cpp/ggml-aarch64.h +39 -0
  18. package/cpp/ggml-alloc.c +1041 -0
  19. package/cpp/ggml-alloc.h +76 -0
  20. package/cpp/ggml-backend-impl.h +153 -0
  21. package/cpp/ggml-backend.c +2225 -0
  22. package/cpp/ggml-backend.h +236 -0
  23. package/cpp/ggml-common.h +1829 -0
  24. package/cpp/ggml-impl.h +655 -0
  25. package/cpp/ggml-metal.h +65 -0
  26. package/cpp/ggml-metal.m +3273 -0
  27. package/cpp/ggml-quants.c +15022 -0
  28. package/cpp/ggml-quants.h +132 -0
  29. package/cpp/ggml.c +22034 -0
  30. package/cpp/ggml.h +2444 -0
  31. package/cpp/grammar-parser.cpp +536 -0
  32. package/cpp/grammar-parser.h +29 -0
  33. package/cpp/json-schema-to-grammar.cpp +1045 -0
  34. package/cpp/json-schema-to-grammar.h +8 -0
  35. package/cpp/json.hpp +24766 -0
  36. package/cpp/llama.cpp +21789 -0
  37. package/cpp/llama.h +1201 -0
  38. package/cpp/log.h +737 -0
  39. package/cpp/rn-llama.hpp +630 -0
  40. package/cpp/sampling.cpp +460 -0
  41. package/cpp/sampling.h +160 -0
  42. package/cpp/sgemm.cpp +1027 -0
  43. package/cpp/sgemm.h +14 -0
  44. package/cpp/unicode-data.cpp +7032 -0
  45. package/cpp/unicode-data.h +20 -0
  46. package/cpp/unicode.cpp +812 -0
  47. package/cpp/unicode.h +64 -0
  48. package/ios/RNLlama.h +11 -0
  49. package/ios/RNLlama.mm +302 -0
  50. package/ios/RNLlama.xcodeproj/project.pbxproj +278 -0
  51. package/ios/RNLlamaContext.h +39 -0
  52. package/ios/RNLlamaContext.mm +426 -0
  53. package/jest/mock.js +169 -0
  54. package/lib/commonjs/NativeRNLlama.js +10 -0
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -0
  56. package/lib/commonjs/grammar.js +574 -0
  57. package/lib/commonjs/grammar.js.map +1 -0
  58. package/lib/commonjs/index.js +151 -0
  59. package/lib/commonjs/index.js.map +1 -0
  60. package/lib/module/NativeRNLlama.js +3 -0
  61. package/lib/module/NativeRNLlama.js.map +1 -0
  62. package/lib/module/grammar.js +566 -0
  63. package/lib/module/grammar.js.map +1 -0
  64. package/lib/module/index.js +129 -0
  65. package/lib/module/index.js.map +1 -0
  66. package/lib/typescript/NativeRNLlama.d.ts +107 -0
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -0
  68. package/lib/typescript/grammar.d.ts +38 -0
  69. package/lib/typescript/grammar.d.ts.map +1 -0
  70. package/lib/typescript/index.d.ts +46 -0
  71. package/lib/typescript/index.d.ts.map +1 -0
  72. package/llama-rn.podspec +56 -0
  73. package/package.json +230 -0
  74. package/src/NativeRNLlama.ts +132 -0
  75. package/src/grammar.ts +849 -0
  76. package/src/index.ts +182 -0
@@ -0,0 +1,353 @@
1
+ package com.rnllama;
2
+
3
+ import com.facebook.react.bridge.Arguments;
4
+ import com.facebook.react.bridge.WritableArray;
5
+ import com.facebook.react.bridge.WritableMap;
6
+ import com.facebook.react.bridge.ReadableMap;
7
+ import com.facebook.react.bridge.ReadableArray;
8
+ import com.facebook.react.bridge.ReactApplicationContext;
9
+ import com.facebook.react.modules.core.DeviceEventManagerModule;
10
+
11
+ import android.util.Log;
12
+ import android.os.Build;
13
+ import android.content.res.AssetManager;
14
+
15
+ import java.lang.StringBuilder;
16
+ import java.io.BufferedReader;
17
+ import java.io.FileReader;
18
+ import java.io.File;
19
+ import java.io.IOException;
20
+
21
+ public class LlamaContext {
22
+ public static final String NAME = "RNLlamaContext";
23
+
24
+ private int id;
25
+ private ReactApplicationContext reactContext;
26
+ private long context;
27
+ private WritableMap modelDetails;
28
+ private int jobId = -1;
29
+ private DeviceEventManagerModule.RCTDeviceEventEmitter eventEmitter;
30
+
31
+ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap params) {
32
+ if (LlamaContext.isArm64V8a() == false && LlamaContext.isX86_64() == false) {
33
+ throw new IllegalStateException("Only 64-bit architectures are supported");
34
+ }
35
+ if (!params.hasKey("model")) {
36
+ throw new IllegalArgumentException("Missing required parameter: model");
37
+ }
38
+ this.id = id;
39
+ this.context = initContext(
40
+ // String model,
41
+ params.getString("model"),
42
+ // boolean embedding,
43
+ params.hasKey("embedding") ? params.getBoolean("embedding") : false,
44
+ // int n_ctx,
45
+ params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
46
+ // int n_batch,
47
+ params.hasKey("n_batch") ? params.getInt("n_batch") : 512,
48
+ // int n_threads,
49
+ params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
50
+ // int n_gpu_layers, // TODO: Support this
51
+ params.hasKey("n_gpu_layers") ? params.getInt("n_gpu_layers") : 0,
52
+ // boolean use_mlock,
53
+ params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
54
+ // boolean use_mmap,
55
+ params.hasKey("use_mmap") ? params.getBoolean("use_mmap") : true,
56
+ // String lora,
57
+ params.hasKey("lora") ? params.getString("lora") : "",
58
+ // float lora_scaled,
59
+ params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
60
+ // String lora_base,
61
+ params.hasKey("lora_base") ? params.getString("lora_base") : "",
62
+ // float rope_freq_base,
63
+ params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
64
+ // float rope_freq_scale
65
+ params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f
66
+ );
67
+ this.modelDetails = loadModelDetails(this.context);
68
+ this.reactContext = reactContext;
69
+ eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
70
+ }
71
+
72
+ public long getContext() {
73
+ return context;
74
+ }
75
+
76
+ public WritableMap getModelDetails() {
77
+ return modelDetails;
78
+ }
79
+
80
+ private void emitPartialCompletion(WritableMap tokenResult) {
81
+ WritableMap event = Arguments.createMap();
82
+ event.putInt("contextId", LlamaContext.this.id);
83
+ event.putMap("tokenResult", tokenResult);
84
+ eventEmitter.emit("@RNLlama_onToken", event);
85
+ }
86
+
87
+ private static class PartialCompletionCallback {
88
+ LlamaContext context;
89
+ boolean emitNeeded;
90
+
91
+ public PartialCompletionCallback(LlamaContext context, boolean emitNeeded) {
92
+ this.context = context;
93
+ this.emitNeeded = emitNeeded;
94
+ }
95
+
96
+ void onPartialCompletion(WritableMap tokenResult) {
97
+ if (!emitNeeded) return;
98
+ context.emitPartialCompletion(tokenResult);
99
+ }
100
+ }
101
+
102
+ public WritableMap loadSession(String path) {
103
+ if (path == null || path.isEmpty()) {
104
+ throw new IllegalArgumentException("File path is empty");
105
+ }
106
+ File file = new File(path);
107
+ if (!file.exists()) {
108
+ throw new IllegalArgumentException("File does not exist: " + path);
109
+ }
110
+ WritableMap result = loadSession(this.context, path);
111
+ if (result.hasKey("error")) {
112
+ throw new IllegalStateException(result.getString("error"));
113
+ }
114
+ return result;
115
+ }
116
+
117
+ public int saveSession(String path, int size) {
118
+ if (path == null || path.isEmpty()) {
119
+ throw new IllegalArgumentException("File path is empty");
120
+ }
121
+ return saveSession(this.context, path, size);
122
+ }
123
+
124
+ public WritableMap completion(ReadableMap params) {
125
+ if (!params.hasKey("prompt")) {
126
+ throw new IllegalArgumentException("Missing required parameter: prompt");
127
+ }
128
+
129
+ double[][] logit_bias = new double[0][0];
130
+ if (params.hasKey("logit_bias")) {
131
+ ReadableArray logit_bias_array = params.getArray("logit_bias");
132
+ logit_bias = new double[logit_bias_array.size()][];
133
+ for (int i = 0; i < logit_bias_array.size(); i++) {
134
+ ReadableArray logit_bias_row = logit_bias_array.getArray(i);
135
+ logit_bias[i] = new double[logit_bias_row.size()];
136
+ for (int j = 0; j < logit_bias_row.size(); j++) {
137
+ logit_bias[i][j] = logit_bias_row.getDouble(j);
138
+ }
139
+ }
140
+ }
141
+
142
+ return doCompletion(
143
+ this.context,
144
+ // String prompt,
145
+ params.getString("prompt"),
146
+ // String grammar,
147
+ params.hasKey("grammar") ? params.getString("grammar") : "",
148
+ // float temperature,
149
+ params.hasKey("temperature") ? (float) params.getDouble("temperature") : 0.7f,
150
+ // int n_threads,
151
+ params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
152
+ // int n_predict,
153
+ params.hasKey("n_predict") ? params.getInt("n_predict") : -1,
154
+ // int n_probs,
155
+ params.hasKey("n_probs") ? params.getInt("n_probs") : 0,
156
+ // int penalty_last_n,
157
+ params.hasKey("penalty_last_n") ? params.getInt("penalty_last_n") : 64,
158
+ // float penalty_repeat,
159
+ params.hasKey("penalty_repeat") ? (float) params.getDouble("penalty_repeat") : 1.00f,
160
+ // float penalty_freq,
161
+ params.hasKey("penalty_freq") ? (float) params.getDouble("penalty_freq") : 0.00f,
162
+ // float penalty_present,
163
+ params.hasKey("penalty_present") ? (float) params.getDouble("penalty_present") : 0.00f,
164
+ // float mirostat,
165
+ params.hasKey("mirostat") ? (float) params.getDouble("mirostat") : 0.00f,
166
+ // float mirostat_tau,
167
+ params.hasKey("mirostat_tau") ? (float) params.getDouble("mirostat_tau") : 5.00f,
168
+ // float mirostat_eta,
169
+ params.hasKey("mirostat_eta") ? (float) params.getDouble("mirostat_eta") : 0.10f,
170
+ // boolean penalize_nl,
171
+ params.hasKey("penalize_nl") ? params.getBoolean("penalize_nl") : false,
172
+ // int top_k,
173
+ params.hasKey("top_k") ? params.getInt("top_k") : 40,
174
+ // float top_p,
175
+ params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
176
+ // float min_p,
177
+ params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
178
+ // float tfs_z,
179
+ params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f,
180
+ // float typical_p,
181
+ params.hasKey("typical_p") ? (float) params.getDouble("typical_p") : 1.00f,
182
+ // int seed,
183
+ params.hasKey("seed") ? params.getInt("seed") : -1,
184
+ // String[] stop,
185
+ params.hasKey("stop") ? params.getArray("stop").toArrayList().toArray(new String[0]) : new String[0],
186
+ // boolean ignore_eos,
187
+ params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false,
188
+ // double[][] logit_bias,
189
+ logit_bias,
190
+ // PartialCompletionCallback partial_completion_callback
191
+ new PartialCompletionCallback(
192
+ this,
193
+ params.hasKey("emit_partial_completion") ? params.getBoolean("emit_partial_completion") : false
194
+ )
195
+ );
196
+ }
197
+
198
+ public void stopCompletion() {
199
+ stopCompletion(this.context);
200
+ }
201
+
202
+ public boolean isPredicting() {
203
+ return isPredicting(this.context);
204
+ }
205
+
206
+ public WritableMap tokenize(String text) {
207
+ WritableMap result = Arguments.createMap();
208
+ result.putArray("tokens", tokenize(this.context, text));
209
+ return result;
210
+ }
211
+
212
+ public String detokenize(ReadableArray tokens) {
213
+ int[] toks = new int[tokens.size()];
214
+ for (int i = 0; i < tokens.size(); i++) {
215
+ toks[i] = (int) tokens.getDouble(i);
216
+ }
217
+ return detokenize(this.context, toks);
218
+ }
219
+
220
+ public WritableMap embedding(String text) {
221
+ if (isEmbeddingEnabled(this.context) == false) {
222
+ throw new IllegalStateException("Embedding is not enabled");
223
+ }
224
+ WritableMap result = Arguments.createMap();
225
+ result.putArray("embedding", embedding(this.context, text));
226
+ return result;
227
+ }
228
+
229
+ public String bench(int pp, int tg, int pl, int nr) {
230
+ return bench(this.context, pp, tg, pl, nr);
231
+ }
232
+
233
+ public void release() {
234
+ freeContext(context);
235
+ }
236
+
237
+ static {
238
+ Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
239
+ if (LlamaContext.isArm64V8a()) {
240
+ boolean loadV8fp16 = false;
241
+ if (LlamaContext.isArm64V8a()) {
242
+ // ARMv8.2a needs runtime detection support
243
+ String cpuInfo = LlamaContext.cpuInfo();
244
+ if (cpuInfo != null) {
245
+ Log.d(NAME, "CPU info: " + cpuInfo);
246
+ if (cpuInfo.contains("fphp")) {
247
+ Log.d(NAME, "CPU supports fp16 arithmetic");
248
+ loadV8fp16 = true;
249
+ }
250
+ }
251
+ }
252
+
253
+ if (loadV8fp16) {
254
+ Log.d(NAME, "Loading librnllama_v8fp16_va.so");
255
+ System.loadLibrary("rnllama_v8fp16_va");
256
+ } else {
257
+ Log.d(NAME, "Loading librnllama.so");
258
+ System.loadLibrary("rnllama");
259
+ }
260
+ } else if (LlamaContext.isX86_64()) {
261
+ Log.d(NAME, "Loading librnllama.so");
262
+ System.loadLibrary("rnllama");
263
+ }
264
+ }
265
+
266
+ private static boolean isArm64V8a() {
267
+ return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
268
+ }
269
+
270
+ private static boolean isX86_64() {
271
+ return Build.SUPPORTED_ABIS[0].equals("x86_64");
272
+ }
273
+
274
+ private static String cpuInfo() {
275
+ File file = new File("/proc/cpuinfo");
276
+ StringBuilder stringBuilder = new StringBuilder();
277
+ try {
278
+ BufferedReader bufferedReader = new BufferedReader(new FileReader(file));
279
+ String line;
280
+ while ((line = bufferedReader.readLine()) != null) {
281
+ stringBuilder.append(line);
282
+ }
283
+ bufferedReader.close();
284
+ return stringBuilder.toString();
285
+ } catch (IOException e) {
286
+ Log.w(NAME, "Couldn't read /proc/cpuinfo", e);
287
+ return null;
288
+ }
289
+ }
290
+
291
+ protected static native long initContext(
292
+ String model,
293
+ boolean embedding,
294
+ int n_ctx,
295
+ int n_batch,
296
+ int n_threads,
297
+ int n_gpu_layers, // TODO: Support this
298
+ boolean use_mlock,
299
+ boolean use_mmap,
300
+ String lora,
301
+ float lora_scaled,
302
+ String lora_base,
303
+ float rope_freq_base,
304
+ float rope_freq_scale
305
+ );
306
+ protected static native WritableMap loadModelDetails(
307
+ long contextPtr
308
+ );
309
+ protected static native WritableMap loadSession(
310
+ long contextPtr,
311
+ String path
312
+ );
313
+ protected static native int saveSession(
314
+ long contextPtr,
315
+ String path,
316
+ int size
317
+ );
318
+ protected static native WritableMap doCompletion(
319
+ long context_ptr,
320
+ String prompt,
321
+ String grammar,
322
+ float temperature,
323
+ int n_threads,
324
+ int n_predict,
325
+ int n_probs,
326
+ int penalty_last_n,
327
+ float penalty_repeat,
328
+ float penalty_freq,
329
+ float penalty_present,
330
+ float mirostat,
331
+ float mirostat_tau,
332
+ float mirostat_eta,
333
+ boolean penalize_nl,
334
+ int top_k,
335
+ float top_p,
336
+ float min_p,
337
+ float tfs_z,
338
+ float typical_p,
339
+ int seed,
340
+ String[] stop,
341
+ boolean ignore_eos,
342
+ double[][] logit_bias,
343
+ PartialCompletionCallback partial_completion_callback
344
+ );
345
+ protected static native void stopCompletion(long contextPtr);
346
+ protected static native boolean isPredicting(long contextPtr);
347
+ protected static native WritableArray tokenize(long contextPtr, String text);
348
+ protected static native String detokenize(long contextPtr, int[] tokens);
349
+ protected static native boolean isEmbeddingEnabled(long contextPtr);
350
+ protected static native WritableArray embedding(long contextPtr, String text);
351
+ protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
352
+ protected static native void freeContext(long contextPtr);
353
+ }