cui-llama.rn 1.2.4 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. package/README.md +3 -4
  2. package/android/src/main/CMakeLists.txt +21 -5
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -30
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +222 -36
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/common.cpp +1682 -2122
  9. package/cpp/common.h +600 -594
  10. package/cpp/ggml-aarch64.c +129 -3209
  11. package/cpp/ggml-aarch64.h +19 -39
  12. package/cpp/ggml-alloc.c +1040 -1040
  13. package/cpp/ggml-alloc.h +76 -76
  14. package/cpp/ggml-backend-impl.h +216 -227
  15. package/cpp/ggml-backend-reg.cpp +195 -0
  16. package/cpp/ggml-backend.cpp +1997 -2625
  17. package/cpp/ggml-backend.h +328 -326
  18. package/cpp/ggml-common.h +1853 -1853
  19. package/cpp/ggml-cpp.h +38 -0
  20. package/cpp/ggml-cpu-aarch64.c +3560 -0
  21. package/cpp/ggml-cpu-aarch64.h +30 -0
  22. package/cpp/ggml-cpu-impl.h +371 -614
  23. package/cpp/ggml-cpu-quants.c +10822 -0
  24. package/cpp/ggml-cpu-quants.h +63 -0
  25. package/cpp/ggml-cpu.c +13975 -0
  26. package/cpp/ggml-cpu.cpp +663 -0
  27. package/cpp/ggml-cpu.h +177 -0
  28. package/cpp/ggml-impl.h +550 -209
  29. package/cpp/ggml-metal.h +66 -66
  30. package/cpp/ggml-metal.m +4294 -3819
  31. package/cpp/ggml-quants.c +5247 -15752
  32. package/cpp/ggml-quants.h +100 -147
  33. package/cpp/ggml-threading.cpp +12 -0
  34. package/cpp/ggml-threading.h +12 -0
  35. package/cpp/ggml.c +8180 -23464
  36. package/cpp/ggml.h +2411 -2562
  37. package/cpp/llama-grammar.cpp +1138 -1138
  38. package/cpp/llama-grammar.h +144 -144
  39. package/cpp/llama-impl.h +181 -181
  40. package/cpp/llama-sampling.cpp +2348 -2194
  41. package/cpp/llama-sampling.h +48 -30
  42. package/cpp/llama-vocab.cpp +1984 -1968
  43. package/cpp/llama-vocab.h +170 -165
  44. package/cpp/llama.cpp +22132 -21969
  45. package/cpp/llama.h +1253 -1253
  46. package/cpp/log.cpp +401 -401
  47. package/cpp/log.h +121 -121
  48. package/cpp/rn-llama.hpp +83 -19
  49. package/cpp/sampling.cpp +466 -458
  50. package/cpp/sgemm.cpp +1884 -1219
  51. package/ios/RNLlama.mm +43 -20
  52. package/ios/RNLlamaContext.h +9 -3
  53. package/ios/RNLlamaContext.mm +133 -33
  54. package/jest/mock.js +0 -1
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  56. package/lib/commonjs/index.js +52 -15
  57. package/lib/commonjs/index.js.map +1 -1
  58. package/lib/module/NativeRNLlama.js.map +1 -1
  59. package/lib/module/index.js +51 -15
  60. package/lib/module/index.js.map +1 -1
  61. package/lib/typescript/NativeRNLlama.d.ts +29 -6
  62. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  63. package/lib/typescript/index.d.ts +12 -5
  64. package/lib/typescript/index.d.ts.map +1 -1
  65. package/package.json +1 -1
  66. package/src/NativeRNLlama.ts +41 -7
  67. package/src/index.ts +82 -27
  68. package/cpp/json-schema-to-grammar.cpp +0 -1045
  69. package/cpp/json-schema-to-grammar.h +0 -8
  70. package/cpp/json.hpp +0 -24766
package/README.md CHANGED
@@ -6,15 +6,14 @@ This fork exists to update llama.cpp on a more frequent basis, plus adding usefu
  
  The following features have been added for Android:
  
- - Updated sync for llama.cpp
  - Added stopping prompt processing between batches, vital for mobile devices with very slow prompt processing
  - `vocab_only` mode: utilize the llama.cpp tokenizer
  - tokenizeSync: non-blocking, synchronous tokenizer function
  - Context Shift taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)
- - XTC sampling
- - Progress callback
  - Retrieving CPU Features to check for i8mm and dotprod flags
  
+ There is no IOS implementation for these features.
+ 
  Original repo README.md below.
  
  # llama.rn
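For orientation, a minimal sketch of how two of the fork-specific features above might look from the JS side, assuming the fork keeps llama.rn's initLlama export and that tokenizeSync mirrors the async tokenize result shape (neither is shown in this diff; the model path is hypothetical):

    import { initLlama } from 'cui-llama.rn'

    // vocab_only: load only the tokenizer, skipping the weights
    const context = await initLlama({
      model: '/path/to/model.gguf', // hypothetical path
      vocab_only: true,
    })

    // tokenizeSync: tokenize without an async bridge round-trip (assumed signature)
    const { tokens } = context.tokenizeSync('Hello world')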
@@ -307,7 +306,7 @@ iOS:
  
  - The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
  - Metal:
- - We have tested to know some devices is not able to use Metal ('params.n_gpu_layers > 0') due to llama.cpp used SIMD-scoped operation, you can check if your device is supported in [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf), Apple7 GPU will be the minimum requirement.
+ - We have tested to know some devices is not able to use Metal (GPU) due to llama.cpp used SIMD-scoped operation, you can check if your device is supported in [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf), Apple7 GPU will be the minimum requirement.
  - It's also not supported in iOS simulator due to [this limitation](https://developer.apple.com/documentation/metal/developing_metal_apps_that_run_in_simulator#3241609), we used constant buffers more than 14.
  
  Android:
package/android/src/main/CMakeLists.txt CHANGED
@@ -14,21 +14,28 @@ set(
      ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
      ${RNLLAMA_LIB_DIR}/log.cpp
  
+     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
+     ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
+     ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
+     ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
+     ${RNLLAMA_LIB_DIR}/log.cpp
+ 
      ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
      ${RNLLAMA_LIB_DIR}/ggml-alloc.c
      ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
+     ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
      ${RNLLAMA_LIB_DIR}/ggml.c
+     ${RNLLAMA_LIB_DIR}/ggml-cpu.c
+     ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
+     ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c
+     ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
+     ${RNLLAMA_LIB_DIR}/ggml-threading.cpp
      ${RNLLAMA_LIB_DIR}/ggml-quants.c
      ${RNLLAMA_LIB_DIR}/common.cpp
-     ${RNLLAMA_LIB_DIR}/json.hpp
-     ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
      ${RNLLAMA_LIB_DIR}/sampling.cpp
      ${RNLLAMA_LIB_DIR}/unicode-data.cpp
      ${RNLLAMA_LIB_DIR}/unicode.cpp
      ${RNLLAMA_LIB_DIR}/llama.cpp
-     ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
-     ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
-     ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
      ${RNLLAMA_LIB_DIR}/sgemm.cpp
      ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
      ${RNLLAMA_LIB_DIR}/rn-llama.hpp
@@ -71,11 +78,20 @@ build_library("rnllama" "")
  
  if (${ANDROID_ABI} STREQUAL "arm64-v8a")
      # ARM64 targets
+     build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve")
+     build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve")
      build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
      build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
      build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
      build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
      build_library("rnllama_v8" "-march=armv8-a")
+ 
+     # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
+     # llama.cpp will deal with the cpu features
+     # build_library("rnllama_v8_7" "-march=armv8.7-a")
+     # TODO: Add support runtime check for cpu features
+     # At the moment runtime check is failing.
+ 
  elseif (${ANDROID_ABI} STREQUAL "x86_64")
      # x86_64 target
      build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt")
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -93,7 +93,7 @@ public class LlamaContext {
          Log.e(NAME, "Failed to convert to FD!");
        }
      }
- 
+     logToAndroid();
      // Check if file has GGUF magic numbers
      this.id = id;
      eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
@@ -102,6 +102,8 @@ public class LlamaContext {
        modelName,
        // boolean embedding,
        params.hasKey("embedding") ? params.getBoolean("embedding") : false,
+       // int embd_normalize,
+       params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1,
        // int n_ctx,
        params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
        // int n_batch,
@@ -110,6 +112,12 @@ public class LlamaContext {
        params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
        // int n_gpu_layers, // TODO: Support this
        params.hasKey("n_gpu_layers") ? params.getInt("n_gpu_layers") : 0,
+       // boolean flash_attn,
+       params.hasKey("flash_attn") ? params.getBoolean("flash_attn") : false,
+       // String cache_type_k,
+       params.hasKey("cache_type_k") ? params.getString("cache_type_k") : "f16",
+       // String cache_type_v,
+       params.hasKey("cache_type_v") ? params.getString("cache_type_v") : "f16",
        // boolean use_mlock,
        params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
        // boolean use_mmap,
@@ -124,12 +132,22 @@ public class LlamaContext {
        params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
        // float rope_freq_scale
        params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
-       this
+       // int pooling_type,
+       params.hasKey("pooling_type") ? params.getInt("pooling_type") : -1,
+       // LoadProgressCallback load_progress_callback
+       params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
      );
+     if (this.context == -1) {
+       throw new IllegalStateException("Failed to initialize context");
+     }
      this.modelDetails = loadModelDetails(this.context);
      this.reactContext = reactContext;
    }
  
+   public void interruptLoad() {
+     interruptLoad(this.context);
+   }
+ 
    public long getContext() {
      return context;
    }
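On the JS side these map to new context-init options. A sketch, assuming the TypeScript option names match the native parameter names one-to-one (the usual llama.rn convention, not directly visible in this excerpt); the model path is hypothetical:

    import { initLlama } from 'cui-llama.rn'

    const context = await initLlama({
      model: '/path/to/model.gguf',
      n_ctx: 2048,
      flash_attn: true,      // new in 1.3.0
      cache_type_k: 'q8_0',  // new: quantized KV cache for keys (default 'f16')
      cache_type_v: 'q8_0',  // new: same for values
      pooling_type: -1,      // new: -1 keeps the model default
      embedding: true,
      embd_normalize: 2,     // new: 2 = L2 norm in llama.cpp's convention, -1 = none
    })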
@@ -146,6 +164,25 @@ public class LlamaContext {
      return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
    }
  
+   private void emitLoadProgress(int progress) {
+     WritableMap event = Arguments.createMap();
+     event.putInt("contextId", LlamaContext.this.id);
+     event.putInt("progress", progress);
+     eventEmitter.emit("@RNLlama_onInitContextProgress", event);
+   }
+ 
+   private static class LoadProgressCallback {
+     LlamaContext context;
+ 
+     public LoadProgressCallback(LlamaContext context) {
+       this.context = context;
+     }
+ 
+     void onLoadProgress(int progress) {
+       context.emitLoadProgress(progress);
+     }
+   }
+ 
    private void emitPartialCompletion(WritableMap tokenResult) {
      WritableMap event = Arguments.createMap();
      event.putInt("contextId", LlamaContext.this.id);
@@ -244,12 +281,10 @@ public class LlamaContext {
        params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
        // float min_p,
        params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
-       // float xtc_t,
-       params.hasKey("xtc_t") ? (float) params.getDouble("xtc_t") : 0.00f,
-       // float xtc_p,
-       params.hasKey("xtc_p") ? (float) params.getDouble("xtc_p") : 0.00f,
-       // float tfs_z,
-       params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f,
+       // float xtc_threshold,
+       params.hasKey("xtc_threshold") ? (float) params.getDouble("xtc_threshold") : 0.00f,
+       // float xtc_probability,
+       params.hasKey("xtc_probability") ? (float) params.getDouble("xtc_probability") : 0.00f,
        // float typical_p,
        params.hasKey("typical_p") ? (float) params.getDouble("typical_p") : 1.00f,
        // int seed,
@@ -260,6 +295,16 @@ public class LlamaContext {
        params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false,
        // double[][] logit_bias,
        logit_bias,
+       // float dry_multiplier,
+       params.hasKey("dry_multiplier") ? (float) params.getDouble("dry_multiplier") : 0.00f,
+       // float dry_base,
+       params.hasKey("dry_base") ? (float) params.getDouble("dry_base") : 1.75f,
+       // int dry_allowed_length,
+       params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2,
+       // int dry_penalty_last_n,
+       params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1,
+       // String[] dry_sequence_breakers, when undef, we use the default definition from common.h
+       params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
        // PartialCompletionCallback partial_completion_callback
        new PartialCompletionCallback(
          this,
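Taken together, the two hunks above rename the XTC parameters (xtc_t/xtc_p become xtc_threshold/xtc_probability, and tfs_z is dropped) and add the DRY repetition penalty. A hedged completion-call sketch using the defaults shown in the bindings; whether every key is plumbed through identically in TypeScript is not visible in this excerpt:

    const result = await context.completion({
      prompt: 'Once upon a time',
      // XTC sampler (renamed from xtc_t / xtc_p in 1.2.4):
      xtc_threshold: 0.1,
      xtc_probability: 0.5,
      // DRY repetition penalty (multiplier 0 disables it):
      dry_multiplier: 0.8,
      dry_base: 1.75,
      dry_allowed_length: 2,
      dry_penalty_last_n: -1, // -1 = scan the whole context
      dry_sequence_breakers: ['\n', ':', '"', '*'], // default from common.h
    })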
@@ -294,11 +339,16 @@ public class LlamaContext {
      return detokenize(this.context, toks);
    }
  
-   public WritableMap getEmbedding(String text) {
+   public WritableMap getEmbedding(String text, ReadableMap params) {
      if (isEmbeddingEnabled(this.context) == false) {
        throw new IllegalStateException("Embedding is not enabled");
      }
-     WritableMap result = embedding(this.context, text);
+     WritableMap result = embedding(
+       this.context,
+       text,
+       // int embd_normalize,
+       params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1
+     );
      if (result.hasKey("error")) {
        throw new IllegalStateException(result.getString("error"));
      }
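getEmbedding now takes a params map, so normalization can be chosen per call rather than only at init. A sketch of the corresponding JS call, assuming context.embedding() gained a matching options argument (the TypeScript change itself is in src/index.ts, not shown here):

    const { embedding } = await context.embedding('some text', {
      embd_normalize: 2, // per-call override; -1 skips normalization
    })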
@@ -315,17 +365,31 @@ public class LlamaContext {
  
    static {
      Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
-     if (LlamaContext.isArm64V8a()) {
-       String cpuFeatures = LlamaContext.getCpuFeatures();
-       Log.d(NAME, "CPU features: " + cpuFeatures);
- 
-       boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
-       boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
-       boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
-       boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
-       boolean hasInt8Matmul = cpuFeatures.contains("i8mm");
  
-       if (isAtLeastArmV84 && hasFp16 && hasDotProd && hasInt8Matmul) {
+     String cpuFeatures = LlamaContext.getCpuFeatures();
+     Log.d(NAME, "CPU features: " + cpuFeatures);
+     boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
+     boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
+     boolean hasSve = cpuFeatures.contains("sve");
+     boolean hasI8mm = cpuFeatures.contains("i8mm");
+     boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
+     boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
+     Log.d(NAME, "- hasFp16: " + hasFp16);
+     Log.d(NAME, "- hasDotProd: " + hasDotProd);
+     Log.d(NAME, "- hasSve: " + hasSve);
+     Log.d(NAME, "- hasI8mm: " + hasI8mm);
+     Log.d(NAME, "- isAtLeastArmV82: " + isAtLeastArmV82);
+     Log.d(NAME, "- isAtLeastArmV84: " + isAtLeastArmV84);
+ 
+     // TODO: Add runtime check for cpu features
+     if (LlamaContext.isArm64V8a()) {
+       if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) {
+         Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so");
+         System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve");
+       } else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) {
+         Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so");
+         System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve");
+       } else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) {
          Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
          System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
        } else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
@@ -341,14 +405,16 @@ public class LlamaContext {
          Log.d(NAME, "Loading librnllama_v8.so");
          System.loadLibrary("rnllama_v8");
        }
+       // Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
+       // System.loadLibrary("rnllama_v8_7");
      } else if (LlamaContext.isX86_64()) {
-         Log.d(NAME, "Loading librnllama_x86_64.so");
-         System.loadLibrary("rnllama_x86_64");
+       Log.d(NAME, "Loading librnllama_x86_64.so");
+       System.loadLibrary("rnllama_x86_64");
      } else {
-         Log.d(NAME, "Loading default librnllama.so");
-         System.loadLibrary("rnllama");
+       Log.d(NAME, "Loading default librnllama.so");
+       System.loadLibrary("rnllama");
      }
-     }
+   }
  
    public static boolean isArm64V8a() {
      return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
@@ -384,13 +450,21 @@ public class LlamaContext {
      eventEmitter.emit("@RNLlama_onModelProgress", event);
    }
  
+   protected static native WritableMap modelInfo(
+     String model,
+     String[] skip
+   );
    protected static native long initContext(
      String model,
      boolean embedding,
+     int embd_normalize,
      int n_ctx,
      int n_batch,
      int n_threads,
      int n_gpu_layers, // TODO: Support this
+     boolean flash_attn,
+     String cache_type_k,
+     String cache_type_v,
      boolean use_mlock,
      boolean use_mmap,
      boolean vocab_only,
@@ -398,8 +472,10 @@ public class LlamaContext {
      float lora_scaled,
      float rope_freq_base,
      float rope_freq_scale,
-     LlamaContext javaLlamaContext
+     int pooling_type,
+     LoadProgressCallback load_progress_callback
    );
+   protected static native void interruptLoad(long contextPtr);
    protected static native WritableMap loadModelDetails(
      long contextPtr
    );
@@ -436,14 +512,18 @@ public class LlamaContext {
      int top_k,
      float top_p,
      float min_p,
-     float xtc_t,
-     float xtc_p,
-     float tfs_z,
+     float xtc_threshold,
+     float xtc_probability,
      float typical_p,
      int seed,
      String[] stop,
      boolean ignore_eos,
      double[][] logit_bias,
+     float dry_multiplier,
+     float dry_base,
+     int dry_allowed_length,
+     int dry_penalty_last_n,
+     String[] dry_sequence_breakers,
      PartialCompletionCallback partial_completion_callback
    );
    protected static native void stopCompletion(long contextPtr);
@@ -451,7 +531,12 @@ public class LlamaContext {
    protected static native WritableArray tokenize(long contextPtr, String text);
    protected static native String detokenize(long contextPtr, int[] tokens);
    protected static native boolean isEmbeddingEnabled(long contextPtr);
-   protected static native WritableMap embedding(long contextPtr, String text);
+   protected static native WritableMap embedding(
+     long contextPtr,
+     String text,
+     int embd_normalize
+   );
    protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
    protected static native void freeContext(long contextPtr);
+   protected static native void logToAndroid();
  }
package/android/src/main/java/com/rnllama/RNLlama.java CHANGED
@@ -42,21 +42,53 @@ public class RNLlama implements LifecycleEventListener {
      promise.resolve(null);
    }
  
-   public void initContext(final ReadableMap params, final Promise promise) {
+   public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+     new AsyncTask<Void, Void, WritableMap>() {
+       private Exception exception;
+ 
+       @Override
+       protected WritableMap doInBackground(Void... voids) {
+         try {
+           String[] skipArray = new String[skip.size()];
+           for (int i = 0; i < skip.size(); i++) {
+             skipArray[i] = skip.getString(i);
+           }
+           return LlamaContext.modelInfo(model, skipArray);
+         } catch (Exception e) {
+           exception = e;
+         }
+         return null;
+       }
+ 
+       @Override
+       protected void onPostExecute(WritableMap result) {
+         if (exception != null) {
+           promise.reject(exception);
+           return;
+         }
+         promise.resolve(result);
+       }
+     }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+   }
+ 
+   public void initContext(double id, final ReadableMap params, final Promise promise) {
+     final int contextId = (int) id;
      AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
        private Exception exception;
  
        @Override
        protected WritableMap doInBackground(Void... voids) {
          try {
-           int id = Math.abs(new Random().nextInt());
-           LlamaContext llamaContext = new LlamaContext(id, reactContext, params);
+           LlamaContext context = contexts.get(contextId);
+           if (context != null) {
+             throw new Exception("Context already exists");
+           }
+           LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
            if (llamaContext.getContext() == 0) {
              throw new Exception("Failed to initialize context");
            }
-           contexts.put(id, llamaContext);
+           contexts.put(contextId, llamaContext);
            WritableMap result = Arguments.createMap();
-           result.putInt("contextId", id);
            result.putBoolean("gpu", false);
            result.putString("reasonNoGPU", "Currently not supported");
            result.putMap("model", llamaContext.getModelDetails());
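Two JS-visible consequences of this hunk: the context id is now chosen by the caller (hence the "Context already exists" guard) instead of a random int on the Java side, and GGUF metadata can be read without creating a context. A sketch of the metadata path, assuming the wrapper exposes it as loadLlamaModelInfo the way upstream llama.rn does:

    import { loadLlamaModelInfo } from 'cui-llama.rn'

    // Reads GGUF header fields via the new modelInfo native method;
    // the skip list of keys is assumed to be handled inside the wrapper.
    const info = await loadLlamaModelInfo('/path/to/model.gguf') // hypothetical path
    console.log(info)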
@@ -366,7 +398,7 @@ public class RNLlama implements LifecycleEventListener {
      tasks.put(task, "detokenize-" + contextId);
    }
  
-   public void embedding(double id, final String text, final Promise promise) {
+   public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
      final int contextId = (int) id;
      AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
        private Exception exception;
@@ -378,7 +410,7 @@ public class RNLlama implements LifecycleEventListener {
            if (context == null) {
              throw new Exception("Context not found");
            }
-           return context.getEmbedding(text);
+           return context.getEmbedding(text, params);
          } catch (Exception e) {
            exception = e;
          }
@@ -442,6 +474,7 @@ public class RNLlama implements LifecycleEventListener {
          if (context == null) {
            throw new Exception("Context " + id + " not found");
          }
+         context.interruptLoad();
          context.stopCompletion();
          AsyncTask completionTask = null;
          for (AsyncTask task : tasks.keySet()) {
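The one-line addition above means releasing a context now interrupts an in-flight model load (via interruptLoad) before stopping any completion. A sketch of the effect from JS, assuming the context object keeps llama.rn's release() method; how the wrapper exposes early cancellation, if at all, is not shown in this excerpt:

    // Releasing a context now also aborts a load that is still running natively.
    const context = await initLlama({ model: '/path/to/model.gguf' }) // hypothetical path
    await context.release() // triggers interruptLoad() + stopCompletion() on the Java side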