cui-llama.rn 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. package/LICENSE +20 -0
  2. package/README.md +330 -0
  3. package/android/build.gradle +107 -0
  4. package/android/gradle.properties +5 -0
  5. package/android/src/main/AndroidManifest.xml +4 -0
  6. package/android/src/main/CMakeLists.txt +69 -0
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +353 -0
  8. package/android/src/main/java/com/rnllama/RNLlama.java +446 -0
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -0
  10. package/android/src/main/jni.cpp +635 -0
  11. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +94 -0
  12. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +95 -0
  13. package/cpp/README.md +4 -0
  14. package/cpp/common.cpp +3237 -0
  15. package/cpp/common.h +467 -0
  16. package/cpp/ggml-aarch64.c +2193 -0
  17. package/cpp/ggml-aarch64.h +39 -0
  18. package/cpp/ggml-alloc.c +1041 -0
  19. package/cpp/ggml-alloc.h +76 -0
  20. package/cpp/ggml-backend-impl.h +153 -0
  21. package/cpp/ggml-backend.c +2225 -0
  22. package/cpp/ggml-backend.h +236 -0
  23. package/cpp/ggml-common.h +1829 -0
  24. package/cpp/ggml-impl.h +655 -0
  25. package/cpp/ggml-metal.h +65 -0
  26. package/cpp/ggml-metal.m +3273 -0
  27. package/cpp/ggml-quants.c +15022 -0
  28. package/cpp/ggml-quants.h +132 -0
  29. package/cpp/ggml.c +22034 -0
  30. package/cpp/ggml.h +2444 -0
  31. package/cpp/grammar-parser.cpp +536 -0
  32. package/cpp/grammar-parser.h +29 -0
  33. package/cpp/json-schema-to-grammar.cpp +1045 -0
  34. package/cpp/json-schema-to-grammar.h +8 -0
  35. package/cpp/json.hpp +24766 -0
  36. package/cpp/llama.cpp +21789 -0
  37. package/cpp/llama.h +1201 -0
  38. package/cpp/log.h +737 -0
  39. package/cpp/rn-llama.hpp +630 -0
  40. package/cpp/sampling.cpp +460 -0
  41. package/cpp/sampling.h +160 -0
  42. package/cpp/sgemm.cpp +1027 -0
  43. package/cpp/sgemm.h +14 -0
  44. package/cpp/unicode-data.cpp +7032 -0
  45. package/cpp/unicode-data.h +20 -0
  46. package/cpp/unicode.cpp +812 -0
  47. package/cpp/unicode.h +64 -0
  48. package/ios/RNLlama.h +11 -0
  49. package/ios/RNLlama.mm +302 -0
  50. package/ios/RNLlama.xcodeproj/project.pbxproj +278 -0
  51. package/ios/RNLlamaContext.h +39 -0
  52. package/ios/RNLlamaContext.mm +426 -0
  53. package/jest/mock.js +169 -0
  54. package/lib/commonjs/NativeRNLlama.js +10 -0
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -0
  56. package/lib/commonjs/grammar.js +574 -0
  57. package/lib/commonjs/grammar.js.map +1 -0
  58. package/lib/commonjs/index.js +151 -0
  59. package/lib/commonjs/index.js.map +1 -0
  60. package/lib/module/NativeRNLlama.js +3 -0
  61. package/lib/module/NativeRNLlama.js.map +1 -0
  62. package/lib/module/grammar.js +566 -0
  63. package/lib/module/grammar.js.map +1 -0
  64. package/lib/module/index.js +129 -0
  65. package/lib/module/index.js.map +1 -0
  66. package/lib/typescript/NativeRNLlama.d.ts +107 -0
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -0
  68. package/lib/typescript/grammar.d.ts +38 -0
  69. package/lib/typescript/grammar.d.ts.map +1 -0
  70. package/lib/typescript/index.d.ts +46 -0
  71. package/lib/typescript/index.d.ts.map +1 -0
  72. package/llama-rn.podspec +56 -0
  73. package/package.json +230 -0
  74. package/src/NativeRNLlama.ts +132 -0
  75. package/src/grammar.ts +849 -0
  76. package/src/index.ts +182 -0
@@ -0,0 +1,446 @@
1
+ package com.rnllama;
2
+
3
+ import androidx.annotation.NonNull;
4
+ import android.util.Log;
5
+ import android.os.Build;
6
+ import android.os.Handler;
7
+ import android.os.AsyncTask;
8
+
9
+ import com.facebook.react.bridge.Promise;
10
+ import com.facebook.react.bridge.ReactApplicationContext;
11
+ import com.facebook.react.bridge.ReactMethod;
12
+ import com.facebook.react.bridge.LifecycleEventListener;
13
+ import com.facebook.react.bridge.ReadableMap;
14
+ import com.facebook.react.bridge.ReadableArray;
15
+ import com.facebook.react.bridge.WritableMap;
16
+ import com.facebook.react.bridge.Arguments;
17
+
18
+ import java.util.HashMap;
19
+ import java.util.Random;
20
+ import java.io.File;
21
+ import java.io.FileInputStream;
22
+ import java.io.PushbackInputStream;
23
+
24
+ public class RNLlama implements LifecycleEventListener {
25
+ public static final String NAME = "RNLlama";
26
+
27
+ private ReactApplicationContext reactContext;
28
+
29
+ public RNLlama(ReactApplicationContext reactContext) {
30
+ reactContext.addLifecycleEventListener(this);
31
+ this.reactContext = reactContext;
32
+ }
33
+
34
+ private HashMap<AsyncTask, String> tasks = new HashMap<>();
35
+
36
+ private HashMap<Integer, LlamaContext> contexts = new HashMap<>();
37
+
38
+ private int llamaContextLimit = 1;
39
+
40
+ public void setContextLimit(double limit, Promise promise) {
41
+ llamaContextLimit = (int) limit;
42
+ promise.resolve(null);
43
+ }
44
+
45
+ public void initContext(final ReadableMap params, final Promise promise) {
46
+ AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
47
+ private Exception exception;
48
+
49
+ @Override
50
+ protected WritableMap doInBackground(Void... voids) {
51
+ try {
52
+ int id = Math.abs(new Random().nextInt());
53
+ LlamaContext llamaContext = new LlamaContext(id, reactContext, params);
54
+ if (llamaContext.getContext() == 0) {
55
+ throw new Exception("Failed to initialize context");
56
+ }
57
+ contexts.put(id, llamaContext);
58
+ WritableMap result = Arguments.createMap();
59
+ result.putInt("contextId", id);
60
+ result.putBoolean("gpu", false);
61
+ result.putString("reasonNoGPU", "Currently not supported");
62
+ result.putMap("model", llamaContext.getModelDetails());
63
+ return result;
64
+ } catch (Exception e) {
65
+ exception = e;
66
+ return null;
67
+ }
68
+ }
69
+
70
+ @Override
71
+ protected void onPostExecute(WritableMap result) {
72
+ if (exception != null) {
73
+ promise.reject(exception);
74
+ return;
75
+ }
76
+ promise.resolve(result);
77
+ tasks.remove(this);
78
+ }
79
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
80
+ tasks.put(task, "initContext");
81
+ }
82
+
83
+ public void loadSession(double id, final String path, Promise promise) {
84
+ final int contextId = (int) id;
85
+ AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
86
+ private Exception exception;
87
+
88
+ @Override
89
+ protected WritableMap doInBackground(Void... voids) {
90
+ try {
91
+ LlamaContext context = contexts.get(contextId);
92
+ if (context == null) {
93
+ throw new Exception("Context not found");
94
+ }
95
+ WritableMap result = context.loadSession(path);
96
+ return result;
97
+ } catch (Exception e) {
98
+ exception = e;
99
+ }
100
+ return null;
101
+ }
102
+
103
+ @Override
104
+ protected void onPostExecute(WritableMap result) {
105
+ if (exception != null) {
106
+ promise.reject(exception);
107
+ return;
108
+ }
109
+ promise.resolve(result);
110
+ tasks.remove(this);
111
+ }
112
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
113
+ tasks.put(task, "loadSession-" + contextId);
114
+ }
115
+
116
+ public void saveSession(double id, final String path, double size, Promise promise) {
117
+ final int contextId = (int) id;
118
+ AsyncTask task = new AsyncTask<Void, Void, Integer>() {
119
+ private Exception exception;
120
+
121
+ @Override
122
+ protected Integer doInBackground(Void... voids) {
123
+ try {
124
+ LlamaContext context = contexts.get(contextId);
125
+ if (context == null) {
126
+ throw new Exception("Context not found");
127
+ }
128
+ Integer count = context.saveSession(path, (int) size);
129
+ return count;
130
+ } catch (Exception e) {
131
+ exception = e;
132
+ }
133
+ return -1;
134
+ }
135
+
136
+ @Override
137
+ protected void onPostExecute(Integer result) {
138
+ if (exception != null) {
139
+ promise.reject(exception);
140
+ return;
141
+ }
142
+ promise.resolve(result);
143
+ tasks.remove(this);
144
+ }
145
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
146
+ tasks.put(task, "saveSession-" + contextId);
147
+ }
148
+
149
+ public void completion(double id, final ReadableMap params, final Promise promise) {
150
+ final int contextId = (int) id;
151
+ AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
152
+ private Exception exception;
153
+
154
+ @Override
155
+ protected WritableMap doInBackground(Void... voids) {
156
+ try {
157
+ LlamaContext context = contexts.get(contextId);
158
+ if (context == null) {
159
+ throw new Exception("Context not found");
160
+ }
161
+ if (context.isPredicting()) {
162
+ throw new Exception("Context is busy");
163
+ }
164
+ WritableMap result = context.completion(params);
165
+ return result;
166
+ } catch (Exception e) {
167
+ exception = e;
168
+ }
169
+ return null;
170
+ }
171
+
172
+ @Override
173
+ protected void onPostExecute(WritableMap result) {
174
+ if (exception != null) {
175
+ promise.reject(exception);
176
+ return;
177
+ }
178
+ promise.resolve(result);
179
+ tasks.remove(this);
180
+ }
181
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
182
+ tasks.put(task, "completion-" + contextId);
183
+ }
184
+
185
+ public void stopCompletion(double id, final Promise promise) {
186
+ final int contextId = (int) id;
187
+ AsyncTask task = new AsyncTask<Void, Void, Void>() {
188
+ private Exception exception;
189
+
190
+ @Override
191
+ protected Void doInBackground(Void... voids) {
192
+ try {
193
+ LlamaContext context = contexts.get(contextId);
194
+ if (context == null) {
195
+ throw new Exception("Context not found");
196
+ }
197
+ context.stopCompletion();
198
+ AsyncTask completionTask = null;
199
+ for (AsyncTask task : tasks.keySet()) {
200
+ if (tasks.get(task).equals("completion-" + contextId)) {
201
+ task.get();
202
+ break;
203
+ }
204
+ }
205
+ } catch (Exception e) {
206
+ exception = e;
207
+ }
208
+ return null;
209
+ }
210
+
211
+ @Override
212
+ protected void onPostExecute(Void result) {
213
+ if (exception != null) {
214
+ promise.reject(exception);
215
+ return;
216
+ }
217
+ promise.resolve(result);
218
+ tasks.remove(this);
219
+ }
220
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
221
+ tasks.put(task, "stopCompletion-" + contextId);
222
+ }
223
+
224
+ public void tokenize(double id, final String text, final Promise promise) {
225
+ final int contextId = (int) id;
226
+ AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
227
+ private Exception exception;
228
+
229
+ @Override
230
+ protected WritableMap doInBackground(Void... voids) {
231
+ try {
232
+ LlamaContext context = contexts.get(contextId);
233
+ if (context == null) {
234
+ throw new Exception("Context not found");
235
+ }
236
+ return context.tokenize(text);
237
+ } catch (Exception e) {
238
+ exception = e;
239
+ }
240
+ return null;
241
+ }
242
+
243
+ @Override
244
+ protected void onPostExecute(WritableMap result) {
245
+ if (exception != null) {
246
+ promise.reject(exception);
247
+ return;
248
+ }
249
+ promise.resolve(result);
250
+ tasks.remove(this);
251
+ }
252
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
253
+ tasks.put(task, "tokenize-" + contextId);
254
+ }
255
+
256
+ public void detokenize(double id, final ReadableArray tokens, final Promise promise) {
257
+ final int contextId = (int) id;
258
+ AsyncTask task = new AsyncTask<Void, Void, String>() {
259
+ private Exception exception;
260
+
261
+ @Override
262
+ protected String doInBackground(Void... voids) {
263
+ try {
264
+ LlamaContext context = contexts.get(contextId);
265
+ if (context == null) {
266
+ throw new Exception("Context not found");
267
+ }
268
+ return context.detokenize(tokens);
269
+ } catch (Exception e) {
270
+ exception = e;
271
+ }
272
+ return null;
273
+ }
274
+
275
+ @Override
276
+ protected void onPostExecute(String result) {
277
+ if (exception != null) {
278
+ promise.reject(exception);
279
+ return;
280
+ }
281
+ promise.resolve(result);
282
+ tasks.remove(this);
283
+ }
284
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
285
+ tasks.put(task, "detokenize-" + contextId);
286
+ }
287
+
288
+ public void embedding(double id, final String text, final Promise promise) {
289
+ final int contextId = (int) id;
290
+ AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
291
+ private Exception exception;
292
+
293
+ @Override
294
+ protected WritableMap doInBackground(Void... voids) {
295
+ try {
296
+ LlamaContext context = contexts.get(contextId);
297
+ if (context == null) {
298
+ throw new Exception("Context not found");
299
+ }
300
+ return context.embedding(text);
301
+ } catch (Exception e) {
302
+ exception = e;
303
+ }
304
+ return null;
305
+ }
306
+
307
+ @Override
308
+ protected void onPostExecute(WritableMap result) {
309
+ if (exception != null) {
310
+ promise.reject(exception);
311
+ return;
312
+ }
313
+ promise.resolve(result);
314
+ tasks.remove(this);
315
+ }
316
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
317
+ tasks.put(task, "embedding-" + contextId);
318
+ }
319
+
320
+ public void bench(double id, final double pp, final double tg, final double pl, final double nr, final Promise promise) {
321
+ final int contextId = (int) id;
322
+ AsyncTask task = new AsyncTask<Void, Void, String>() {
323
+ private Exception exception;
324
+
325
+ @Override
326
+ protected String doInBackground(Void... voids) {
327
+ try {
328
+ LlamaContext context = contexts.get(contextId);
329
+ if (context == null) {
330
+ throw new Exception("Context not found");
331
+ }
332
+ return context.bench((int) pp, (int) tg, (int) pl, (int) nr);
333
+ } catch (Exception e) {
334
+ exception = e;
335
+ }
336
+ return null;
337
+ }
338
+
339
+ @Override
340
+ protected void onPostExecute(String result) {
341
+ if (exception != null) {
342
+ promise.reject(exception);
343
+ return;
344
+ }
345
+ promise.resolve(result);
346
+ tasks.remove(this);
347
+ }
348
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
349
+ tasks.put(task, "bench-" + contextId);
350
+ }
351
+
352
+ public void releaseContext(double id, Promise promise) {
353
+ final int contextId = (int) id;
354
+ AsyncTask task = new AsyncTask<Void, Void, Void>() {
355
+ private Exception exception;
356
+
357
+ @Override
358
+ protected Void doInBackground(Void... voids) {
359
+ try {
360
+ LlamaContext context = contexts.get(contextId);
361
+ if (context == null) {
362
+ throw new Exception("Context " + id + " not found");
363
+ }
364
+ context.stopCompletion();
365
+ AsyncTask completionTask = null;
366
+ for (AsyncTask task : tasks.keySet()) {
367
+ if (tasks.get(task).equals("completion-" + contextId)) {
368
+ task.get();
369
+ break;
370
+ }
371
+ }
372
+ context.release();
373
+ contexts.remove(contextId);
374
+ } catch (Exception e) {
375
+ exception = e;
376
+ }
377
+ return null;
378
+ }
379
+
380
+ @Override
381
+ protected void onPostExecute(Void result) {
382
+ if (exception != null) {
383
+ promise.reject(exception);
384
+ return;
385
+ }
386
+ promise.resolve(null);
387
+ tasks.remove(this);
388
+ }
389
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
390
+ tasks.put(task, "releaseContext-" + contextId);
391
+ }
392
+
393
+ public void releaseAllContexts(Promise promise) {
394
+ AsyncTask task = new AsyncTask<Void, Void, Void>() {
395
+ private Exception exception;
396
+
397
+ @Override
398
+ protected Void doInBackground(Void... voids) {
399
+ try {
400
+ onHostDestroy();
401
+ } catch (Exception e) {
402
+ exception = e;
403
+ }
404
+ return null;
405
+ }
406
+
407
+ @Override
408
+ protected void onPostExecute(Void result) {
409
+ if (exception != null) {
410
+ promise.reject(exception);
411
+ return;
412
+ }
413
+ promise.resolve(null);
414
+ tasks.remove(this);
415
+ }
416
+ }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
417
+ tasks.put(task, "releaseAllContexts");
418
+ }
419
+
420
+ @Override
421
+ public void onHostResume() {
422
+ }
423
+
424
+ @Override
425
+ public void onHostPause() {
426
+ }
427
+
428
+ @Override
429
+ public void onHostDestroy() {
430
+ for (LlamaContext context : contexts.values()) {
431
+ context.stopCompletion();
432
+ }
433
+ for (AsyncTask task : tasks.keySet()) {
434
+ try {
435
+ task.get();
436
+ } catch (Exception e) {
437
+ Log.e(NAME, "Failed to wait for task", e);
438
+ }
439
+ }
440
+ tasks.clear();
441
+ for (LlamaContext context : contexts.values()) {
442
+ context.release();
443
+ }
444
+ contexts.clear();
445
+ }
446
+ }
@@ -0,0 +1,48 @@
1
+ package com.rnllama;
2
+
3
+ import androidx.annotation.NonNull;
4
+ import androidx.annotation.Nullable;
5
+
6
+ import com.facebook.react.bridge.NativeModule;
7
+ import com.facebook.react.bridge.ReactApplicationContext;
8
+ import com.facebook.react.module.model.ReactModuleInfo;
9
+ import com.facebook.react.module.model.ReactModuleInfoProvider;
10
+ import com.facebook.react.TurboReactPackage;
11
+
12
+ import java.util.List;
13
+ import java.util.HashMap;
14
+ import java.util.Map;
15
+
16
+ public class RNLlamaPackage extends TurboReactPackage {
17
+
18
+ @Nullable
19
+ @Override
20
+ public NativeModule getModule(String name, ReactApplicationContext reactContext) {
21
+ if (name.equals(RNLlamaModule.NAME)) {
22
+ return new com.rnllama.RNLlamaModule(reactContext);
23
+ } else {
24
+ return null;
25
+ }
26
+ }
27
+
28
+ @Override
29
+ public ReactModuleInfoProvider getReactModuleInfoProvider() {
30
+ return () -> {
31
+ final Map<String, ReactModuleInfo> moduleInfos = new HashMap<>();
32
+ boolean isTurboModule = BuildConfig.IS_NEW_ARCHITECTURE_ENABLED;
33
+ moduleInfos.put(
34
+ RNLlamaModule.NAME,
35
+ new ReactModuleInfo(
36
+ RNLlamaModule.NAME,
37
+ RNLlamaModule.NAME,
38
+ false, // canOverrideExistingModule
39
+ false, // needsEagerInit
40
+ true, // hasConstants
41
+ false, // isCxxModule
42
+ isTurboModule // isTurboModule
43
+ )
44
+ );
45
+ return moduleInfos;
46
+ };
47
+ }
48
+ }