cui-llama.rn 1.1.7 → 1.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. package/README.md +2 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +13 -5
  3. package/android/src/main/java/com/rnllama/RNLlama.java +39 -0
  4. package/android/src/main/jni.cpp +28 -2
  5. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  6. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/cpp/common.cpp +3 -0
  8. package/cpp/common.h +2 -0
  9. package/cpp/ggml-aarch64.c +1794 -1368
  10. package/cpp/ggml-alloc.c +6 -0
  11. package/cpp/ggml-backend-impl.h +10 -9
  12. package/cpp/ggml-backend.c +25 -0
  13. package/cpp/ggml-backend.h +2 -1
  14. package/cpp/ggml-cpu-impl.h +614 -0
  15. package/cpp/ggml-impl.h +13 -609
  16. package/cpp/ggml-metal.m +1 -0
  17. package/cpp/ggml-quants.c +1 -0
  18. package/cpp/ggml.c +457 -144
  19. package/cpp/ggml.h +37 -8
  20. package/cpp/llama-impl.h +2 -0
  21. package/cpp/llama-sampling.cpp +7 -5
  22. package/cpp/llama-vocab.cpp +1 -5
  23. package/cpp/llama-vocab.h +9 -5
  24. package/cpp/llama.cpp +202 -30
  25. package/cpp/llama.h +2 -0
  26. package/cpp/log.cpp +1 -1
  27. package/cpp/log.h +2 -0
  28. package/cpp/sampling.cpp +9 -1
  29. package/cpp/sgemm.cpp +1 -0
  30. package/cpp/unicode.cpp +1 -0
  31. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  32. package/lib/commonjs/index.js +12 -1
  33. package/lib/commonjs/index.js.map +1 -1
  34. package/lib/module/NativeRNLlama.js.map +1 -1
  35. package/lib/module/index.js +11 -1
  36. package/lib/module/index.js.map +1 -1
  37. package/lib/typescript/NativeRNLlama.d.ts +6 -0
  38. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  39. package/lib/typescript/index.d.ts +3 -2
  40. package/lib/typescript/index.d.ts.map +1 -1
  41. package/package.json +1 -1
  42. package/src/NativeRNLlama.ts +7 -0
  43. package/src/index.ts +23 -4
package/README.md CHANGED
@@ -12,6 +12,8 @@ The following features have been added for Android:
 - tokenizeSync: non-blocking, synchronous tokenizer function
 - Context Shift taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)
 - XTC sampling
+- Progress callback
+- Retrieving CPU Features to check for i8mm and dotprod flags
 
 Original repo README.md below.
 
package/android/src/main/java/com/rnllama/LlamaContext.java CHANGED
@@ -71,6 +71,7 @@ public class LlamaContext {
     }
 
     this.id = id;
+    eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
     this.context = initContext(
       // String model,
       params.getString("model"),
@@ -97,11 +98,11 @@ public class LlamaContext {
       // float rope_freq_base,
       params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
       // float rope_freq_scale
-      params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f
+      params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
+      this
     );
     this.modelDetails = loadModelDetails(this.context);
     this.reactContext = reactContext;
-    eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
   }
 
   public long getContext() {
@@ -324,7 +325,7 @@ public class LlamaContext {
     }
   }
 
-  private static boolean isArm64V8a() {
+  public static boolean isArm64V8a() {
     return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
   }
 
@@ -332,7 +333,7 @@ public class LlamaContext {
     return Build.SUPPORTED_ABIS[0].equals("x86_64");
   }
 
-  private static String getCpuFeatures() {
+  public static String getCpuFeatures() {
     File file = new File("/proc/cpuinfo");
     StringBuilder stringBuilder = new StringBuilder();
     try {
@@ -352,6 +353,12 @@ public class LlamaContext {
     }
   }
 
+  public void emitModelProgressUpdate(int progress) {
+    WritableMap event = Arguments.createMap();
+    event.putInt("progress", progress);
+    eventEmitter.emit("@RNLlama_onModelProgress", event);
+  }
+
   protected static native long initContext(
     String model,
     boolean embedding,
@@ -365,7 +372,8 @@ public class LlamaContext {
     String lora,
     float lora_scaled,
     float rope_freq_base,
-    float rope_freq_scale
+    float rope_freq_scale,
+    LlamaContext javaLlamaContext
   );
   protected static native WritableMap loadModelDetails(
     long contextPtr
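
The new `emitModelProgressUpdate` method surfaces model-loading progress to JavaScript as a `@RNLlama_onModelProgress` device event whose payload carries an integer `progress` field (0-100). How the package's `index.ts` wraps this event is not part of the hunks shown here, so the sketch below subscribes to the raw event directly; treat it as illustrative rather than the library's documented API.

```ts
// Sketch only: listens for the raw Android event emitted by
// LlamaContext.emitModelProgressUpdate(). The event name and the
// { progress } payload come from the Java diff above; the wrapper in
// index.ts may expose this differently (e.g. as an onProgress callback).
import { DeviceEventEmitter, EmitterSubscription } from 'react-native';

export function onModelLoadProgress(
  handler: (percent: number) => void
): EmitterSubscription {
  return DeviceEventEmitter.addListener(
    '@RNLlama_onModelProgress',
    (event: { progress: number }) => handler(event.progress) // integer 0-100
  );
}

// Usage: const sub = onModelLoadProgress((p) => console.log(`loading: ${p}%`));
// Remember to call sub.remove() once the context has finished initializing.
```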
package/android/src/main/java/com/rnllama/RNLlama.java CHANGED
@@ -294,7 +294,46 @@ public class RNLlama implements LifecycleEventListener {
     return context.tokenize(text);
   }
 
+  public void getCpuFeatures(Promise promise) {
+    AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
+      private Exception exception;
+      @Override
+      protected WritableMap doInBackground(Void... voids) {
+        try {
+          WritableMap result = Arguments.createMap();
+          boolean isV8 = LlamaContext.isArm64V8a();
+          result.putBoolean("armv8", isV8);
+
+          if(isV8) {
+            String cpuFeatures = LlamaContext.getCpuFeatures();
+            boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
+            boolean hasInt8Matmul = cpuFeatures.contains("i8mm");
+            result.putBoolean("i8mm", hasInt8Matmul);
+            result.putBoolean("dotprod", hasDotProd);
+          } else {
+            result.putBoolean("i8mm", false);
+            result.putBoolean("dotprod", false);
+          }
+          return result;
+        } catch (Exception e) {
+          exception = e;
+          return null;
+        }
+      }
 
+      @Override
+      protected void onPostExecute(WritableMap result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+        promise.resolve(result);
+        tasks.remove(this);
+      }
+    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+    tasks.put(task, "getCPUFeatures");
+  }
+
   public void detokenize(double id, final ReadableArray tokens, final Promise promise) {
     final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, String>() {
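
`getCpuFeatures` resolves a map with `armv8`, `i8mm`, and `dotprod` booleans derived from `/proc/cpuinfo` (with `asimddp` accepted as an alias for dotprod). Below is a hedged TypeScript sketch of how an app might consume it; it assumes the native module is registered as `RNLlama`, and the quantization mapping is only a heuristic, not something this diff prescribes.

```ts
// Sketch, not the package's documented API: the field names match the
// WritableMap built in RNLlama.getCpuFeatures(); the module name "RNLlama"
// and the direct NativeModules access are assumptions for illustration.
import { NativeModules } from 'react-native';

interface CpuFeatures {
  armv8: boolean;
  i8mm: boolean;    // int8 matrix-multiply extension
  dotprod: boolean; // sdot/udot (reported as "dotprod" or "asimddp")
}

export async function suggestQuantization(): Promise<string> {
  const features: CpuFeatures = await NativeModules.RNLlama.getCpuFeatures();
  // Heuristic: aarch64-repacked Q4_0 variants in ggml benefit from i8mm and
  // dotprod; fall back to plain Q4_0 otherwise.
  if (features.i8mm && features.dotprod) return 'Q4_0_4_8';
  if (features.dotprod) return 'Q4_0_4_4';
  return 'Q4_0';
}
```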
package/android/src/main/jni.cpp CHANGED
@@ -128,6 +128,13 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
 
 std::unordered_map<long, rnllama::llama_rn_context *> context_map;
 
+struct CallbackContext {
+    JNIEnv * env;
+    jobject thiz;
+    jmethodID sendProgressMethod;
+    unsigned current;
+};
+
 JNIEXPORT jlong JNICALL
 Java_com_rnllama_LlamaContext_initContext(
     JNIEnv *env,
@@ -144,7 +151,8 @@ Java_com_rnllama_LlamaContext_initContext(
     jstring lora_str,
     jfloat lora_scaled,
     jfloat rope_freq_base,
-    jfloat rope_freq_scale
+    jfloat rope_freq_scale,
+    jobject javaLlamaContext
 ) {
     UNUSED(thiz);
 
@@ -169,7 +177,7 @@ Java_com_rnllama_LlamaContext_initContext(
     defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
     defaultParams.n_gpu_layers = n_gpu_layers;
-
+
     defaultParams.use_mlock = use_mlock;
     defaultParams.use_mmap = use_mmap;
 
@@ -182,6 +190,24 @@ Java_com_rnllama_LlamaContext_initContext(
     defaultParams.rope_freq_base = rope_freq_base;
     defaultParams.rope_freq_scale = rope_freq_scale;
 
+    // progress callback when loading
+    jclass llamaContextClass = env->GetObjectClass(javaLlamaContext);
+    jmethodID sendProgressMethod = env->GetMethodID(llamaContextClass, "emitModelProgressUpdate", "(I)V");
+
+    CallbackContext callbackctx = {env, javaLlamaContext, sendProgressMethod, 0};
+
+    defaultParams.progress_callback_user_data = &callbackctx;
+    defaultParams.progress_callback = [](float progress, void * ctx) {
+        unsigned percentage = (unsigned) (100 * progress);
+        CallbackContext * cbctx = static_cast<CallbackContext*>(ctx);
+        // reduce call frequency by only calling method when value changes
+        if (percentage <= cbctx->current) return true;
+        cbctx->current = percentage;
+        cbctx->env->CallVoidMethod(cbctx->thiz, cbctx->sendProgressMethod, percentage);
+        return true;
+    };
+
+
     auto llama = new rnllama::llama_rn_context();
     bool is_model_loaded = llama->loadModel(defaultParams);
 
package/android/src/newarch/java/com/rnllama/RNLlamaModule.java CHANGED
@@ -78,6 +78,11 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
     return rnllama.tokenizeSync(id, text);
   }
 
+  @ReactMethod
+  public void getCpuFeatures(final Promise promise) {
+    rnllama.getCpuFeatures(promise);
+  }
+
   @ReactMethod
   public void detokenize(double id, final ReadableArray tokens, final Promise promise) {
     rnllama.detokenize(id, tokens, promise);
package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java CHANGED
@@ -79,6 +79,11 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
     return rnllama.tokenizeSync(id, text);
  }
 
+  @ReactMethod
+  public void getCpuFeatures(final Promise promise) {
+    rnllama.getCpuFeatures(promise);
+  }
+
   @ReactMethod
   public void detokenize(double id, final ReadableArray tokens, final Promise promise) {
     rnllama.detokenize(id, tokens, promise);
package/cpp/common.cpp CHANGED
@@ -954,6 +954,9 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
+
+    mparams.progress_callback_user_data = params.progress_callback_user_data;
+    mparams.progress_callback = params.progress_callback;
     mparams.vocab_only = params.vocab_only;
     mparams.rpc_servers = params.rpc_servers.c_str();
     mparams.main_gpu = params.main_gpu;
package/cpp/common.h CHANGED
@@ -158,6 +158,8 @@ struct gpt_sampler_params {
 
 struct gpt_params {
 
+    void * progress_callback_user_data = nullptr;
+    llama_progress_callback progress_callback = nullptr;
     bool vocab_only = false;
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 0; // context size