cui-llama.rn 1.2.6 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. package/README.md +3 -2
  2. package/android/src/main/CMakeLists.txt +20 -5
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +222 -34
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/common.cpp +1682 -2114
  9. package/cpp/common.h +600 -613
  10. package/cpp/ggml-aarch64.c +129 -3478
  11. package/cpp/ggml-aarch64.h +19 -39
  12. package/cpp/ggml-alloc.c +1040 -1040
  13. package/cpp/ggml-alloc.h +76 -76
  14. package/cpp/ggml-backend-impl.h +216 -216
  15. package/cpp/ggml-backend-reg.cpp +195 -0
  16. package/cpp/ggml-backend.cpp +1997 -2661
  17. package/cpp/ggml-backend.h +328 -314
  18. package/cpp/ggml-common.h +1853 -1853
  19. package/cpp/ggml-cpp.h +38 -38
  20. package/cpp/ggml-cpu-aarch64.c +3560 -0
  21. package/cpp/ggml-cpu-aarch64.h +30 -0
  22. package/cpp/ggml-cpu-impl.h +371 -614
  23. package/cpp/ggml-cpu-quants.c +10822 -0
  24. package/cpp/ggml-cpu-quants.h +63 -0
  25. package/cpp/ggml-cpu.c +13975 -13720
  26. package/cpp/ggml-cpu.cpp +663 -0
  27. package/cpp/ggml-cpu.h +177 -150
  28. package/cpp/ggml-impl.h +550 -296
  29. package/cpp/ggml-metal.h +66 -66
  30. package/cpp/ggml-metal.m +4294 -3933
  31. package/cpp/ggml-quants.c +5247 -15739
  32. package/cpp/ggml-quants.h +100 -147
  33. package/cpp/ggml-threading.cpp +12 -0
  34. package/cpp/ggml-threading.h +12 -0
  35. package/cpp/ggml.c +8180 -8390
  36. package/cpp/ggml.h +2411 -2441
  37. package/cpp/llama-grammar.cpp +1138 -1138
  38. package/cpp/llama-grammar.h +144 -144
  39. package/cpp/llama-impl.h +181 -181
  40. package/cpp/llama-sampling.cpp +2348 -2345
  41. package/cpp/llama-sampling.h +48 -48
  42. package/cpp/llama-vocab.cpp +1984 -1984
  43. package/cpp/llama-vocab.h +170 -170
  44. package/cpp/llama.cpp +22132 -22046
  45. package/cpp/llama.h +1253 -1255
  46. package/cpp/log.cpp +401 -401
  47. package/cpp/log.h +121 -121
  48. package/cpp/rn-llama.hpp +83 -19
  49. package/cpp/sampling.cpp +466 -466
  50. package/cpp/sgemm.cpp +1884 -1276
  51. package/ios/RNLlama.mm +43 -20
  52. package/ios/RNLlamaContext.h +9 -3
  53. package/ios/RNLlamaContext.mm +133 -33
  54. package/jest/mock.js +0 -1
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  56. package/lib/commonjs/index.js +52 -15
  57. package/lib/commonjs/index.js.map +1 -1
  58. package/lib/module/NativeRNLlama.js.map +1 -1
  59. package/lib/module/index.js +51 -15
  60. package/lib/module/index.js.map +1 -1
  61. package/lib/typescript/NativeRNLlama.d.ts +29 -5
  62. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  63. package/lib/typescript/index.d.ts +12 -5
  64. package/lib/typescript/index.d.ts.map +1 -1
  65. package/package.json +1 -1
  66. package/src/NativeRNLlama.ts +41 -6
  67. package/src/index.ts +82 -27
  68. package/cpp/json-schema-to-grammar.cpp +0 -1045
  69. package/cpp/json-schema-to-grammar.h +0 -8
  70. package/cpp/json.hpp +0 -24766
package/README.md CHANGED
@@ -6,13 +6,14 @@ This fork exists to update llama.cpp on a more frequent basis, plus adding useful
 
 The following features have been added for Android:
 
-- Updated sync for llama.cpp
 - Added stopping prompt processing between batches, vital for mobile devices with very slow prompt processing
 - `vocab_only` mode: utilize the llama.cpp tokenizer
 - tokenizeSync: non-blocking, synchronous tokenizer function
 - Context Shift taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)
 - Retrieving CPU Features to check for i8mm and dotprod flags
 
+There is no IOS implementation for these features.
+
 Original repo README.md below.
 
 # llama.rn
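A minimal usage sketch of the tokenizer-related features listed above, not part of the diff: it assumes the `initLlama` entry point that llama.rn exports and a `tokenizeSync` method named after the feature in this list, so treat the exact signatures and the model path as assumptions.

```ts
import { initLlama } from 'cui-llama.rn'

// vocab_only: load just the tokenizer, skipping the model weights.
const context = await initLlama({
  model: '/data/local/tmp/model.gguf', // hypothetical path
  vocab_only: true,
})

// tokenizeSync: tokenize without a round trip through the async native queue.
const { tokens } = context.tokenizeSync('Hello world')
console.log(tokens.length)
```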
@@ -305,7 +306,7 @@ iOS:
 
 - The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
 - Metal:
-  - We have tested to know some devices is not able to use Metal ('params.n_gpu_layers > 0') due to llama.cpp used SIMD-scoped operation, you can check if your device is supported in [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf), Apple7 GPU will be the minimum requirement.
+  - We have tested to know some devices is not able to use Metal (GPU) due to llama.cpp used SIMD-scoped operation, you can check if your device is supported in [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf), Apple7 GPU will be the minimum requirement.
   - It's also not supported in iOS simulator due to [this limitation](https://developer.apple.com/documentation/metal/developing_metal_apps_that_run_in_simulator#3241609), we used constant buffers more than 14.
 
 Android:
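For the Metal caveat above, the init result reports whether the GPU path was actually taken. A sketch, assuming the `gpu` / `reasonNoGPU` fields that the Android implementation later in this diff also populates:

```ts
import { initLlama } from 'cui-llama.rn'

const ctx = await initLlama({
  model: 'model.gguf',
  n_gpu_layers: 99, // request Metal; ignored on pre-Apple7 GPUs and the simulator
})
if (!ctx.gpu) {
  console.log('Falling back to CPU:', ctx.reasonNoGPU)
}
```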
package/android/src/main/CMakeLists.txt CHANGED
@@ -14,22 +14,28 @@ set(
     ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
     ${RNLLAMA_LIB_DIR}/log.cpp
 
+    ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
+    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
+    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
+    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
+    ${RNLLAMA_LIB_DIR}/log.cpp
+
     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/ggml-alloc.c
     ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
+    ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
     ${RNLLAMA_LIB_DIR}/ggml.c
     ${RNLLAMA_LIB_DIR}/ggml-cpu.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c
+    ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
+    ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
     ${RNLLAMA_LIB_DIR}/ggml-quants.c
     ${RNLLAMA_LIB_DIR}/common.cpp
-    ${RNLLAMA_LIB_DIR}/json.hpp
-    ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
     ${RNLLAMA_LIB_DIR}/sampling.cpp
     ${RNLLAMA_LIB_DIR}/unicode-data.cpp
     ${RNLLAMA_LIB_DIR}/unicode.cpp
     ${RNLLAMA_LIB_DIR}/llama.cpp
-    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
-    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
-    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
     ${RNLLAMA_LIB_DIR}/sgemm.cpp
     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/rn-llama.hpp
@@ -72,11 +78,20 @@ build_library("rnllama" "")
 
 if (${ANDROID_ABI} STREQUAL "arm64-v8a")
     # ARM64 targets
+    build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve")
+    build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve")
     build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
     build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
     build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
     build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
     build_library("rnllama_v8" "-march=armv8-a")
+
+    # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
+    # llama.cpp will deal with the cpu features
+    # build_library("rnllama_v8_7" "-march=armv8.7-a")
+    # TODO: Add support runtime check for cpu features
+    # At the moment runtime check is failing.
+
 elseif (${ANDROID_ABI} STREQUAL "x86_64")
     # x86_64 target
     build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt")
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -93,7 +93,7 @@ public class LlamaContext {
         Log.e(NAME, "Failed to convert to FD!");
       }
     }
-
+    logToAndroid();
     // Check if file has GGUF magic numbers
     this.id = id;
     eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
@@ -102,6 +102,8 @@
       modelName,
       // boolean embedding,
       params.hasKey("embedding") ? params.getBoolean("embedding") : false,
+      // int embd_normalize,
+      params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1,
       // int n_ctx,
       params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
       // int n_batch,
@@ -110,6 +112,12 @@
       params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
       // int n_gpu_layers, // TODO: Support this
       params.hasKey("n_gpu_layers") ? params.getInt("n_gpu_layers") : 0,
+      // boolean flash_attn,
+      params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
+      // String cache_type_k,
+      params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
+      // String cache_type_v,
+      params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
       // boolean use_mlock,
       params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
       // boolean use_mmap,
@@ -124,12 +132,22 @@
       params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
       // float rope_freq_scale
       params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
-      this
+      // int pooling_type,
+      params.hasKey("pooling_type") ? params.getInt("pooling_type") : -1,
+      // LoadProgressCallback load_progress_callback
+      params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
     );
+    if (this.context == -1) {
+      throw new IllegalStateException("Failed to initialize context");
+    }
     this.modelDetails = loadModelDetails(this.context);
     this.reactContext = reactContext;
   }
 
+  public void interruptLoad() {
+    interruptLoad(this.context);
+  }
+
   public long getContext() {
     return context;
   }
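The constructor changes above surface several new context options (embd_normalize, flash_attn, cache_type_k/v, pooling_type) plus an optional load-progress callback. A sketch of how a caller might pass them: the JS keys mirror the Java parameter names in this diff, while the `(params, onProgress)` calling convention and the `q8_0` cache type values are assumptions based on the `use_progress_callback` flag wired here and llama.cpp's supported KV-cache types.

```ts
import { initLlama } from 'cui-llama.rn'

const ctx = await initLlama(
  {
    model: 'model.gguf',
    n_ctx: 2048,
    flash_attn: true,      // boolean flash_attn
    cache_type_k: 'q8_0',  // String cache_type_k, default "f16"
    cache_type_v: 'q8_0',  // String cache_type_v, default "f16"
    pooling_type: -1,      // int pooling_type, -1 = model default
    embd_normalize: -1,    // int embd_normalize, -1 = default
  },
  // Surfaced natively via the @RNLlama_onInitContextProgress event.
  (progress) => console.log(`load: ${progress}%`)
)
```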
@@ -146,6 +164,25 @@
     return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
   }
 
+  private void emitLoadProgress(int progress) {
+    WritableMap event = Arguments.createMap();
+    event.putInt("contextId", LlamaContext.this.id);
+    event.putInt("progress", progress);
+    eventEmitter.emit("@RNLlama_onInitContextProgress", event);
+  }
+
+  private static class LoadProgressCallback {
+    LlamaContext context;
+
+    public LoadProgressCallback(LlamaContext context) {
+      this.context = context;
+    }
+
+    void onLoadProgress(int progress) {
+      context.emitLoadProgress(progress);
+    }
+  }
+
   private void emitPartialCompletion(WritableMap tokenResult) {
     WritableMap event = Arguments.createMap();
     event.putInt("contextId", LlamaContext.this.id);
@@ -244,10 +281,10 @@
       params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
       // float min_p,
       params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
-      // float xtc_t,
-      params.hasKey("xtc_t") ? (float) params.getDouble("xtc_t") : 0.00f,
-      // float xtc_p,
-      params.hasKey("xtc_p") ? (float) params.getDouble("xtc_p") : 0.00f,
+      // float xtc_threshold,
+      params.hasKey("xtc_threshold") ? (float) params.getDouble("xtc_threshold") : 0.00f,
+      // float xtc_probability,
+      params.hasKey("xtc_probability") ? (float) params.getDouble("xtc_probability") : 0.00f,
       // float typical_p,
       params.hasKey("typical_p") ? (float) params.getDouble("typical_p") : 1.00f,
       // int seed,
@@ -258,6 +295,16 @@
       params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false,
       // double[][] logit_bias,
       logit_bias,
+      // float dry_multiplier,
+      params.hasKey("dry_multiplier") ? (float) params.getDouble("dry_multiplier") : 0.00f,
+      // float dry_base,
+      params.hasKey("dry_base") ? (float) params.getDouble("dry_base") : 1.75f,
+      // int dry_allowed_length,
+      params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2,
+      // int dry_penalty_last_n,
+      params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1,
+      // String[] dry_sequence_breakers, when undef, we use the default definition from common.h
+      params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
       // PartialCompletionCallback partial_completion_callback
       new PartialCompletionCallback(
         this,
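Two sampler-facing changes land in the hunks above: the XTC fields are renamed (`xtc_t`/`xtc_p` become `xtc_threshold`/`xtc_probability`) and the DRY repetition sampler gains its full parameter set. A sketch continuing the init example above; `ctx.completion` is the llama.rn call, the option keys mirror the Java parameter names here, and the specific values are illustrative only.

```ts
const result = await ctx.completion({
  prompt: 'Once upon a time',
  n_predict: 128,
  xtc_threshold: 0.1,     // renamed from xtc_t
  xtc_probability: 0.5,   // renamed from xtc_p
  dry_multiplier: 0.8,    // 0.0 disables the DRY sampler
  dry_base: 1.75,
  dry_allowed_length: 2,
  dry_penalty_last_n: -1, // -1 = scan the whole context
  dry_sequence_breakers: ['\n', ':', '"', '*'], // defaults from common.h
})
console.log(result.text)
```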
@@ -292,11 +339,16 @@
     return detokenize(this.context, toks);
   }
 
-  public WritableMap getEmbedding(String text) {
+  public WritableMap getEmbedding(String text, ReadableMap params) {
     if (isEmbeddingEnabled(this.context) == false) {
       throw new IllegalStateException("Embedding is not enabled");
     }
-    WritableMap result = embedding(this.context, text);
+    WritableMap result = embedding(
+      this.context,
+      text,
+      // int embd_normalize,
+      params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1
+    );
     if (result.hasKey("error")) {
       throw new IllegalStateException(result.getString("error"));
     }
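Embeddings can now be normalized per call rather than only at context init. A sketch continuing the example above, matching the new `getEmbedding(text, params)` shape; in llama.cpp's convention, `embd_normalize = 2` selects the Euclidean (L2) norm and `-1` leaves the vector unnormalized.

```ts
// Per-call normalization; the options object is new in this release.
const { embedding } = await ctx.embedding('some text', { embd_normalize: 2 })
console.log(embedding.length)
```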
@@ -313,17 +365,31 @@
 
   static {
     Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
-    if (LlamaContext.isArm64V8a()) {
-      String cpuFeatures = LlamaContext.getCpuFeatures();
-      Log.d(NAME, "CPU features: " + cpuFeatures);
-
-      boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
-      boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
-      boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
-      boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
-      boolean hasInt8Matmul = cpuFeatures.contains("i8mm");
 
-      if (isAtLeastArmV84 && hasFp16 && hasDotProd && hasInt8Matmul) {
+    String cpuFeatures = LlamaContext.getCpuFeatures();
+    Log.d(NAME, "CPU features: " + cpuFeatures);
+    boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
+    boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
+    boolean hasSve = cpuFeatures.contains("sve");
+    boolean hasI8mm = cpuFeatures.contains("i8mm");
+    boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
+    boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
+    Log.d(NAME, "- hasFp16: " + hasFp16);
+    Log.d(NAME, "- hasDotProd: " + hasDotProd);
+    Log.d(NAME, "- hasSve: " + hasSve);
+    Log.d(NAME, "- hasI8mm: " + hasI8mm);
+    Log.d(NAME, "- isAtLeastArmV82: " + isAtLeastArmV82);
+    Log.d(NAME, "- isAtLeastArmV84: " + isAtLeastArmV84);
+
+    // TODO: Add runtime check for cpu features
+    if (LlamaContext.isArm64V8a()) {
+      if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve");
+      } else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve");
+      } else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) {
         Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
         System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
       } else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
@@ -339,14 +405,16 @@
         Log.d(NAME, "Loading librnllama_v8.so");
         System.loadLibrary("rnllama_v8");
       }
+      // Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
+      // System.loadLibrary("rnllama_v8_7");
     } else if (LlamaContext.isX86_64()) {
-      Log.d(NAME, "Loading librnllama_x86_64.so");
-      System.loadLibrary("rnllama_x86_64");
+      Log.d(NAME, "Loading librnllama_x86_64.so");
+      System.loadLibrary("rnllama_x86_64");
     } else {
-      Log.d(NAME, "Loading default librnllama.so");
-      System.loadLibrary("rnllama");
+      Log.d(NAME, "Loading default librnllama.so");
+      System.loadLibrary("rnllama");
     }
-  }
+  }
 
   public static boolean isArm64V8a() {
     return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
@@ -382,13 +450,21 @@
     eventEmitter.emit("@RNLlama_onModelProgress", event);
   }
 
+  protected static native WritableMap modelInfo(
+    String model,
+    String[] skip
+  );
   protected static native long initContext(
     String model,
     boolean embedding,
+    int embd_normalize,
     int n_ctx,
     int n_batch,
     int n_threads,
     int n_gpu_layers, // TODO: Support this
+    boolean flash_attn,
+    String cache_type_k,
+    String cache_type_v,
     boolean use_mlock,
     boolean use_mmap,
     boolean vocab_only,
@@ -396,8 +472,10 @@
     float lora_scaled,
     float rope_freq_base,
     float rope_freq_scale,
-    LlamaContext javaLlamaContext
+    int pooling_type,
+    LoadProgressCallback load_progress_callback
   );
+  protected static native void interruptLoad(long contextPtr);
   protected static native WritableMap loadModelDetails(
     long contextPtr
   );
@@ -434,13 +512,18 @@
     int top_k,
     float top_p,
     float min_p,
-    float xtc_t,
-    float xtc_p,
+    float xtc_threshold,
+    float xtc_probability,
     float typical_p,
     int seed,
     String[] stop,
     boolean ignore_eos,
     double[][] logit_bias,
+    float dry_multiplier,
+    float dry_base,
+    int dry_allowed_length,
+    int dry_penalty_last_n,
+    String[] dry_sequence_breakers,
     PartialCompletionCallback partial_completion_callback
   );
   protected static native void stopCompletion(long contextPtr);
@@ -448,7 +531,12 @@
   protected static native WritableArray tokenize(long contextPtr, String text);
   protected static native String detokenize(long contextPtr, int[] tokens);
   protected static native boolean isEmbeddingEnabled(long contextPtr);
-  protected static native WritableMap embedding(long contextPtr, String text);
+  protected static native WritableMap embedding(
+    long contextPtr,
+    String text,
+    int embd_normalize
+  );
   protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
   protected static native void freeContext(long contextPtr);
+  protected static native void logToAndroid();
 }
package/android/src/main/java/com/rnllama/RNLlama.java CHANGED
@@ -42,21 +42,53 @@ public class RNLlama implements LifecycleEventListener {
     promise.resolve(null);
   }
 
-  public void initContext(final ReadableMap params, final Promise promise) {
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    new AsyncTask<Void, Void, WritableMap>() {
+      private Exception exception;
+
+      @Override
+      protected WritableMap doInBackground(Void... voids) {
+        try {
+          String[] skipArray = new String[skip.size()];
+          for (int i = 0; i < skip.size(); i++) {
+            skipArray[i] = skip.getString(i);
+          }
+          return LlamaContext.modelInfo(model, skipArray);
+        } catch (Exception e) {
+          exception = e;
+        }
+        return null;
+      }
+
+      @Override
+      protected void onPostExecute(WritableMap result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+        promise.resolve(result);
+      }
+    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+  }
+
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
       private Exception exception;
 
       @Override
       protected WritableMap doInBackground(Void... voids) {
         try {
-          int id = Math.abs(new Random().nextInt());
-          LlamaContext llamaContext = new LlamaContext(id, reactContext, params);
+          LlamaContext context = contexts.get(contextId);
+          if (context != null) {
+            throw new Exception("Context already exists");
+          }
+          LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
           if (llamaContext.getContext() == 0) {
             throw new Exception("Failed to initialize context");
           }
-          contexts.put(id, llamaContext);
+          contexts.put(contextId, llamaContext);
           WritableMap result = Arguments.createMap();
-          result.putInt("contextId", id);
           result.putBoolean("gpu", false);
           result.putString("reasonNoGPU", "Currently not supported");
           result.putMap("model", llamaContext.getModelDetails());
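Two behavioral changes land in the hunk above: the context id is now allocated on the JS side and passed into `initContext` (instead of a random int generated in Java), and a new `modelInfo` native method reads GGUF metadata without creating a context, skipping any keys listed in `skip`. A sketch of the metadata call; the `loadLlamaModelInfo` wrapper name is an assumption borrowed from upstream llama.rn, and the metadata key shown is illustrative.

```ts
// Name assumed from upstream llama.rn; backed by LlamaContext.modelInfo(model, skip).
import { loadLlamaModelInfo } from 'cui-llama.rn'

const info = await loadLlamaModelInfo('model.gguf')
console.log(info['general.architecture']) // GGUF metadata, no context allocated
```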
@@ -366,7 +398,7 @@
     tasks.put(task, "detokenize-" + contextId);
   }
 
-  public void embedding(double id, final String text, final Promise promise) {
+  public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
     final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
       private Exception exception;
@@ -378,7 +410,7 @@
           if (context == null) {
             throw new Exception("Context not found");
           }
-          return context.getEmbedding(text);
+          return context.getEmbedding(text, params);
         } catch (Exception e) {
           exception = e;
         }
@@ -442,6 +474,7 @@
         if (context == null) {
           throw new Exception("Context " + id + " not found");
         }
+        context.interruptLoad();
        context.stopCompletion();
        AsyncTask completionTask = null;
        for (AsyncTask task : tasks.keySet()) {
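The release path above now calls `interruptLoad()` before `stopCompletion()`, so a context stuck in a slow model load can be torn down mid-load. Because the JS layer now owns the context id, the id of a still-loading context is known before `initLlama` resolves. A sketch under those assumptions; the `releaseContext` export name is assumed to mirror the native method.

```ts
// Assumed export mirroring the native releaseContext(id) method.
import { initLlama, releaseContext } from 'cui-llama.rn'

const loading = initLlama({ model: 'model.gguf' }) // may take a long time
// ... user backs out of the screen ...
await releaseContext(0) // hypothetical id 0; runs interruptLoad() + stopCompletion() natively
await loading.catch(() => {}) // init rejects once the load is interrupted
```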