cui-llama.rn 1.1.6 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -71,6 +71,7 @@ public class LlamaContext {
  }

  this.id = id;
+ eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
  this.context = initContext(
  // String model,
  params.getString("model"),
@@ -97,11 +98,11 @@ public class LlamaContext {
  // float rope_freq_base,
  params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
  // float rope_freq_scale
- params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f
+ params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
+ this
  );
  this.modelDetails = loadModelDetails(this.context);
  this.reactContext = reactContext;
- eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
  }

  public long getContext() {
@@ -352,6 +353,12 @@ public class LlamaContext {
  }
  }

+ public void emitModelProgressUpdate(int progress) {
+ WritableMap event = Arguments.createMap();
+ event.putInt("progress", progress);
+ eventEmitter.emit("@RNLlama_onModelProgress", event);
+ }
+
  protected static native long initContext(
  String model,
  boolean embedding,
@@ -365,7 +372,8 @@ public class LlamaContext {
  String lora,
  float lora_scaled,
  float rope_freq_base,
- float rope_freq_scale
+ float rope_freq_scale,
+ LlamaContext javaLlamaContext
  );
  protected static native WritableMap loadModelDetails(
  long contextPtr
@@ -128,6 +128,13 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v

  std::unordered_map<long, rnllama::llama_rn_context *> context_map;

+ struct CallbackContext {
+ JNIEnv * env;
+ jobject thiz;
+ jmethodID sendProgressMethod;
+ unsigned current;
+ };
+
  JNIEXPORT jlong JNICALL
  Java_com_rnllama_LlamaContext_initContext(
  JNIEnv *env,
@@ -144,7 +151,8 @@ Java_com_rnllama_LlamaContext_initContext(
  jstring lora_str,
  jfloat lora_scaled,
  jfloat rope_freq_base,
- jfloat rope_freq_scale
+ jfloat rope_freq_scale,
+ jobject javaLlamaContext
  ) {
  UNUSED(thiz);

@@ -169,7 +177,7 @@ Java_com_rnllama_LlamaContext_initContext(
  defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

  defaultParams.n_gpu_layers = n_gpu_layers;
-
+
  defaultParams.use_mlock = use_mlock;
  defaultParams.use_mmap = use_mmap;

@@ -182,6 +190,24 @@ Java_com_rnllama_LlamaContext_initContext(
  defaultParams.rope_freq_base = rope_freq_base;
  defaultParams.rope_freq_scale = rope_freq_scale;

+ // progress callback when loading
+ jclass llamaContextClass = env->GetObjectClass(javaLlamaContext);
+ jmethodID sendProgressMethod = env->GetMethodID(llamaContextClass, "emitModelProgressUpdate", "(I)V");
+
+ CallbackContext callbackctx = {env, javaLlamaContext, sendProgressMethod, 0};
+
+ defaultParams.progress_callback_user_data = &callbackctx;
+ defaultParams.progress_callback = [](float progress, void * ctx) {
+ unsigned percentage = (unsigned) (100 * progress);
+ CallbackContext * cbctx = static_cast<CallbackContext*>(ctx);
+ // reduce call frequency by only calling method when value changes
+ if (percentage <= cbctx->current) return true;
+ cbctx->current = percentage;
+ cbctx->env->CallVoidMethod(cbctx->thiz, cbctx->sendProgressMethod, percentage);
+ return true;
+ };
+
+
  auto llama = new rnllama::llama_rn_context();
  bool is_model_loaded = llama->loadModel(defaultParams);

@@ -636,9 +662,7 @@ Java_com_rnllama_LlamaContext_embedding(
  llama->rewind();

  llama_perf_context_reset(llama->ctx);
- gpt_sampler_reset(llama->ctx_sampling);

-
  llama->params.prompt = text_chars;

  llama->params.n_predict = 0;
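
For reference, the progress callback registered in initContext above follows llama.cpp's llama_progress_callback contract: the loader invokes it synchronously on the calling thread with a value between 0.0 and 1.0, and returning false aborts the load (the code above always returns true). That is also why the stack-allocated CallbackContext is safe to pass as user data: it only needs to outlive the loadModel call made in the same scope. Below is a minimal standalone sketch of the same contract, independent of this package; the model path and printf reporting are illustrative, and the API names assume the llama.h C API bundled with llama.cpp of this vintage.

    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_backend_init();

        unsigned last = 0;                                  // last percentage reported
        llama_model_params mparams = llama_model_default_params();
        mparams.progress_callback_user_data = &last;
        mparams.progress_callback = [](float progress, void * ud) {
            unsigned pct = (unsigned) (100 * progress);     // progress runs from 0.0 to 1.0
            unsigned * last = static_cast<unsigned *>(ud);
            if (pct > *last) {                              // only report when the value changes
                *last = pct;
                std::printf("loading: %u%%\n", pct);
            }
            return true;                                    // returning false aborts the load
        };

        llama_model * model = llama_load_model_from_file("model.gguf", mparams); // illustrative path
        if (model) { llama_free_model(model); }

        llama_backend_free();
        return 0;
    }
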
package/cpp/common.cpp CHANGED
@@ -954,6 +954,9 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
  if (params.n_gpu_layers != -1) {
  mparams.n_gpu_layers = params.n_gpu_layers;
  }
+
+ mparams.progress_callback_user_data = params.progress_callback_user_data;
+ mparams.progress_callback = params.progress_callback;
  mparams.vocab_only = params.vocab_only;
  mparams.rpc_servers = params.rpc_servers.c_str();
  mparams.main_gpu = params.main_gpu;
package/cpp/common.h CHANGED
@@ -158,6 +158,8 @@ struct gpt_sampler_params {

  struct gpt_params {

+ void * progress_callback_user_data = nullptr;
+ llama_progress_callback progress_callback = nullptr;
  bool vocab_only = false;
  int32_t n_predict = -1; // new tokens to predict
  int32_t n_ctx = 0; // context size
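
Taken together, the two new gpt_params fields above and the forwarding added in common.cpp let callers attach a load-progress callback through gpt_params instead of touching llama_model_params directly; this is the path the JNI code earlier in the diff relies on, since it sets the callback on the gpt_params it passes to loadModel. A rough sketch of that flow, under the same assumptions as the earlier sketch (illustrative model path, error handling and backend teardown omitted):

    #include "common.h"
    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_backend_init();

        gpt_params params;
        params.model = "model.gguf";                        // illustrative path

        // The two fields introduced in common.h:
        params.progress_callback_user_data = nullptr;
        params.progress_callback = [](float progress, void * /*ud*/) {
            std::printf("load progress: %.0f%%\n", 100 * progress);
            return true;                                    // false would abort loading
        };

        // common.cpp now copies both fields into the llama_model_params it builds.
        llama_model_params mparams = llama_model_params_from_gpt_params(params);
        llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
        if (model) { llama_free_model(model); }
        return 0;
    }
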