cui-llama.rn 1.2.6 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. package/README.md +3 -2
  2. package/android/src/main/CMakeLists.txt +26 -6
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +228 -40
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/amx/amx.cpp +196 -0
  9. package/cpp/amx/amx.h +20 -0
  10. package/cpp/amx/common.h +101 -0
  11. package/cpp/amx/mmq.cpp +2524 -0
  12. package/cpp/amx/mmq.h +16 -0
  13. package/cpp/common.cpp +118 -251
  14. package/cpp/common.h +53 -30
  15. package/cpp/ggml-aarch64.c +46 -3395
  16. package/cpp/ggml-aarch64.h +0 -20
  17. package/cpp/ggml-alloc.c +6 -8
  18. package/cpp/ggml-backend-impl.h +33 -11
  19. package/cpp/ggml-backend-reg.cpp +423 -0
  20. package/cpp/ggml-backend.cpp +14 -676
  21. package/cpp/ggml-backend.h +46 -9
  22. package/cpp/ggml-common.h +6 -0
  23. package/cpp/ggml-cpu-aarch64.c +3823 -0
  24. package/cpp/ggml-cpu-aarch64.h +32 -0
  25. package/cpp/ggml-cpu-impl.h +14 -242
  26. package/cpp/ggml-cpu-quants.c +10835 -0
  27. package/cpp/ggml-cpu-quants.h +63 -0
  28. package/cpp/ggml-cpu.c +13971 -13720
  29. package/cpp/ggml-cpu.cpp +715 -0
  30. package/cpp/ggml-cpu.h +65 -63
  31. package/cpp/ggml-impl.h +285 -25
  32. package/cpp/ggml-metal.h +8 -8
  33. package/cpp/ggml-metal.m +1221 -728
  34. package/cpp/ggml-quants.c +189 -10681
  35. package/cpp/ggml-quants.h +78 -125
  36. package/cpp/ggml-threading.cpp +12 -0
  37. package/cpp/ggml-threading.h +12 -0
  38. package/cpp/ggml.c +688 -1460
  39. package/cpp/ggml.h +58 -244
  40. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  41. package/cpp/json.hpp +24766 -24766
  42. package/cpp/llama-sampling.cpp +5 -2
  43. package/cpp/llama.cpp +409 -123
  44. package/cpp/llama.h +8 -4
  45. package/cpp/rn-llama.hpp +89 -25
  46. package/cpp/sampling.cpp +42 -3
  47. package/cpp/sampling.h +22 -1
  48. package/cpp/sgemm.cpp +608 -0
  49. package/cpp/speculative.cpp +270 -0
  50. package/cpp/speculative.h +28 -0
  51. package/cpp/unicode.cpp +11 -0
  52. package/ios/RNLlama.mm +43 -20
  53. package/ios/RNLlamaContext.h +9 -3
  54. package/ios/RNLlamaContext.mm +146 -33
  55. package/jest/mock.js +0 -1
  56. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  57. package/lib/commonjs/grammar.js +4 -2
  58. package/lib/commonjs/grammar.js.map +1 -1
  59. package/lib/commonjs/index.js +52 -15
  60. package/lib/commonjs/index.js.map +1 -1
  61. package/lib/module/NativeRNLlama.js.map +1 -1
  62. package/lib/module/grammar.js +2 -1
  63. package/lib/module/grammar.js.map +1 -1
  64. package/lib/module/index.js +51 -15
  65. package/lib/module/index.js.map +1 -1
  66. package/lib/typescript/NativeRNLlama.d.ts +122 -8
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  68. package/lib/typescript/grammar.d.ts +5 -6
  69. package/lib/typescript/grammar.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +15 -6
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +2 -1
  73. package/src/NativeRNLlama.ts +135 -13
  74. package/src/grammar.ts +10 -8
  75. package/src/index.ts +104 -28
@@ -4,13 +4,15 @@
  #include <android/log.h>
  #include <cstdlib>
  #include <ctime>
+ #include <ctime>
  #include <sys/sysinfo.h>
  #include <string>
  #include <thread>
  #include <unordered_map>
  #include "llama.h"
- #include "rn-llama.hpp"
+ #include "llama-impl.h"
  #include "ggml.h"
+ #include "rn-llama.hpp"

  #define UNUSED(x) (void)(x)
  #define TAG "RNLLAMA_ANDROID_JNI"
@@ -22,6 +24,13 @@ static inline int min(int a, int b) {
  return (a < b) ? a : b;
  }

+ static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
+ if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+ else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+ else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+ else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+ }
+
  extern "C" {

  // Method to create WritableMap
@@ -107,6 +116,15 @@ static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
  env->CallVoidMethod(arr, pushDoubleMethod, value);
  }

+ // Method to push string into WritableArray
+ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
+ jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
+ jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");
+
+ jstring jValue = env->NewStringUTF(value);
+ env->CallVoidMethod(arr, pushStringMethod, jValue);
+ }
+
  // Method to push WritableMap into WritableArray
  static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
  jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
@@ -125,6 +143,77 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
  env->CallVoidMethod(map, putArrayMethod, jKey, value);
  }

+ JNIEXPORT jobject JNICALL
+ Java_com_rnllama_LlamaContext_modelInfo(
+ JNIEnv *env,
+ jobject thiz,
+ jstring model_path_str,
+ jobjectArray skip
+ ) {
+ UNUSED(thiz);
+
+ const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
+
+ std::vector<std::string> skip_vec;
+ int skip_len = env->GetArrayLength(skip);
+ for (int i = 0; i < skip_len; i++) {
+ jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
+ const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
+ skip_vec.push_back(skip_chars);
+ env->ReleaseStringUTFChars(skip_str, skip_chars);
+ }
+
+ struct lm_gguf_init_params params = {
+ /*.no_alloc = */ false,
+ /*.ctx = */ NULL,
+ };
+ struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);
+
+ if (!ctx) {
+ LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
+ return nullptr;
+ }
+
+ auto info = createWriteableMap(env);
+ putInt(env, info, "version", lm_gguf_get_version(ctx));
+ putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
+ putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
+ {
+ const int n_kv = lm_gguf_get_n_kv(ctx);
+
+ for (int i = 0; i < n_kv; ++i) {
+ const char * key = lm_gguf_get_key(ctx, i);
+
+ bool skipped = false;
+ if (skip_len > 0) {
+ for (int j = 0; j < skip_len; j++) {
+ if (skip_vec[j] == key) {
+ skipped = true;
+ break;
+ }
+ }
+ }
+
+ if (skipped) {
+ continue;
+ }
+
+ const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
+ putString(env, info, key, value.c_str());
+ }
+ }
+
+ env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+ lm_gguf_free(ctx);
+
+ return reinterpret_cast<jobject>(info);
+ }
+
+ struct callback_context {
+ JNIEnv *env;
+ rnllama::llama_rn_context *llama;
+ jobject callback;
+ };

  std::unordered_map<long, rnllama::llama_rn_context *> context_map;

@@ -141,10 +230,14 @@ Java_com_rnllama_LlamaContext_initContext(
  jobject thiz,
  jstring model_path_str,
  jboolean embedding,
+ jint embd_normalize,
  jint n_ctx,
  jint n_batch,
  jint n_threads,
  jint n_gpu_layers, // TODO: Support this
+ jboolean flash_attn,
+ jstring cache_type_k,
+ jstring cache_type_v,
  jboolean use_mlock,
  jboolean use_mmap,
  jboolean vocab_only,
@@ -152,7 +245,8 @@ Java_com_rnllama_LlamaContext_initContext(
  jfloat lora_scaled,
  jfloat rope_freq_base,
  jfloat rope_freq_scale,
- jobject javaLlamaContext
+ jint pooling_type,
+ jobject load_progress_callback
  ) {
  UNUSED(thiz);

@@ -165,65 +259,110 @@

  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
  defaultParams.model = model_path_chars;
-
- defaultParams.embedding = embedding;
-
+
  defaultParams.n_ctx = n_ctx;
  defaultParams.n_batch = n_batch;

+ if (pooling_type != -1) {
+ defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
+ }
+
+ defaultParams.embedding = embedding;
+ if (embd_normalize != -1) {
+ defaultParams.embd_normalize = embd_normalize;
+ }
+ if (embedding) {
+ // For non-causal models, batch size must be equal to ubatch size
+ defaultParams.n_ubatch = defaultParams.n_batch;
+ }
+
  int max_threads = std::thread::hardware_concurrency();
  // Use 2 threads by default on 4-core devices, 4 threads on more cores
  int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
  defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

- defaultParams.n_gpu_layers = n_gpu_layers;
-
+ // defaultParams.n_gpu_layers = n_gpu_layers;
+ defaultParams.flash_attn = flash_attn;
+
+ const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
+ const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
+ defaultParams.cache_type_k = cache_type_k_chars;
+ defaultParams.cache_type_v = cache_type_v_chars;
+
  defaultParams.use_mlock = use_mlock;
  defaultParams.use_mmap = use_mmap;

  const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
  if (lora_chars != nullptr && lora_chars[0] != '\0') {
  defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
- defaultParams.use_mmap = false;
  }

  defaultParams.rope_freq_base = rope_freq_base;
  defaultParams.rope_freq_scale = rope_freq_scale;

- // progress callback when loading
- jclass llamaContextClass = env->GetObjectClass(javaLlamaContext);
- jmethodID sendProgressMethod = env->GetMethodID(llamaContextClass, "emitModelProgressUpdate", "(I)V");
-
- CallbackContext callbackctx = {env, javaLlamaContext, sendProgressMethod, 0};
-
- defaultParams.progress_callback_user_data = &callbackctx;
- defaultParams.progress_callback = [](float progress, void * ctx) {
- unsigned percentage = (unsigned) (100 * progress);
- CallbackContext * cbctx = static_cast<CallbackContext*>(ctx);
- // reduce call frequency by only calling method when value changes
- if (percentage <= cbctx->current) return true;
- cbctx->current = percentage;
- cbctx->env->CallVoidMethod(cbctx->thiz, cbctx->sendProgressMethod, percentage);
- return true;
- };
-
-
  auto llama = new rnllama::llama_rn_context();
+ llama->is_load_interrupted = false;
+ llama->loading_progress = 0;
+
+ if (load_progress_callback != nullptr) {
+ defaultParams.progress_callback = [](float progress, void * user_data) {
+ callback_context *cb_ctx = (callback_context *)user_data;
+ JNIEnv *env = cb_ctx->env;
+ auto llama = cb_ctx->llama;
+ jobject callback = cb_ctx->callback;
+ int percentage = (int) (100 * progress);
+ if (percentage > llama->loading_progress) {
+ llama->loading_progress = percentage;
+ jclass callback_class = env->GetObjectClass(callback);
+ jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
+ env->CallVoidMethod(callback, onLoadProgress, percentage);
+ }
+ return !llama->is_load_interrupted;
+ };
+
+ callback_context *cb_ctx = new callback_context;
+ cb_ctx->env = env;
+ cb_ctx->llama = llama;
+ cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
+ defaultParams.progress_callback_user_data = cb_ctx;
+ }
+
  bool is_model_loaded = llama->loadModel(defaultParams);

+ env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+ env->ReleaseStringUTFChars(lora_str, lora_chars);
+ env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
+ env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
+
  LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
  if (is_model_loaded) {
+ if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
+ LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
+ llama_free(llama->ctx);
+ return -1;
+ }
  context_map[(long) llama->ctx] = llama;
  } else {
  llama_free(llama->ctx);
  }

- env->ReleaseStringUTFChars(model_path_str, model_path_chars);
- env->ReleaseStringUTFChars(lora_str, lora_chars);
-
  return reinterpret_cast<jlong>(llama->ctx);
  }

+
+ JNIEXPORT void JNICALL
+ Java_com_rnllama_LlamaContext_interruptLoad(
+ JNIEnv *env,
+ jobject thiz,
+ jlong context_ptr
+ ) {
+ UNUSED(thiz);
+ auto llama = context_map[(long) context_ptr];
+ if (llama) {
+ llama->is_load_interrupted = true;
+ }
+ }
+
  JNIEXPORT jobject JNICALL
  Java_com_rnllama_LlamaContext_loadModelDetails(
  JNIEnv *env,
@@ -397,13 +536,18 @@ Java_com_rnllama_LlamaContext_doCompletion(
  jint top_k,
  jfloat top_p,
  jfloat min_p,
- jfloat xtc_t,
- jfloat xtc_p,
+ jfloat xtc_threshold,
+ jfloat xtc_probability,
  jfloat typical_p,
  jint seed,
  jobjectArray stop,
  jboolean ignore_eos,
  jobjectArray logit_bias,
+ jfloat dry_multiplier,
+ jfloat dry_base,
+ jint dry_allowed_length,
+ jint dry_penalty_last_n,
+ jobjectArray dry_sequence_breakers,
  jobject partial_completion_callback
  ) {
  UNUSED(thiz);
@@ -414,7 +558,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  //llama_reset_timings(llama->ctx);

  llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
- llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed;
+ llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;

  int max_threads = std::thread::hardware_concurrency();
  // Use 2 threads by default on 4-core devices, 4 threads on more cores
@@ -422,9 +566,9 @@
  llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

  llama->params.n_predict = n_predict;
- llama->params.sparams.ignore_eos = ignore_eos;
+ llama->params.sampling.ignore_eos = ignore_eos;

- auto & sparams = llama->params.sparams;
+ auto & sparams = llama->params.sampling;
  sparams.temp = temperature;
  sparams.penalty_last_n = penalty_last_n;
  sparams.penalty_repeat = penalty_repeat;
@@ -440,14 +584,34 @@
  sparams.typ_p = typical_p;
  sparams.n_probs = n_probs;
  sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
- sparams.xtc_threshold = xtc_t;
- sparams.xtc_probability = xtc_p;
+ sparams.xtc_threshold = xtc_threshold;
+ sparams.xtc_probability = xtc_probability;
+ sparams.dry_multiplier = dry_multiplier;
+ sparams.dry_base = dry_base;
+ sparams.dry_allowed_length = dry_allowed_length;
+ sparams.dry_penalty_last_n = dry_penalty_last_n;

  sparams.logit_bias.clear();
  if (ignore_eos) {
  sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
  }

+ // dry break seq
+
+ jint size = env->GetArrayLength(dry_sequence_breakers);
+ std::vector<std::string> dry_sequence_breakers_vector;
+
+ for (jint i = 0; i < size; i++) {
+ jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
+ const char *nativeString = env->GetStringUTFChars(javaString, 0);
+ dry_sequence_breakers_vector.push_back(std::string(nativeString));
+ env->ReleaseStringUTFChars(javaString, nativeString);
+ env->DeleteLocalRef(javaString);
+ }
+
+ sparams.dry_sequence_breakers = dry_sequence_breakers_vector;
+
+ // logit bias
  const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
  jsize logit_bias_len = env->GetArrayLength(logit_bias);

@@ -529,7 +693,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  auto tokenResult = createWriteableMap(env);
  putString(env, tokenResult, "token", to_send.c_str());

- if (llama->params.sparams.n_probs > 0) {
+ if (llama->params.sampling.n_probs > 0) {
  const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
  size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
  size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
@@ -651,16 +815,27 @@ Java_com_rnllama_LlamaContext_isEmbeddingEnabled(

  JNIEXPORT jobject JNICALL
  Java_com_rnllama_LlamaContext_embedding(
- JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
+ JNIEnv *env, jobject thiz,
+ jlong context_ptr,
+ jstring text,
+ jint embd_normalize
+ ) {
  UNUSED(thiz);
  auto llama = context_map[(long) context_ptr];

+ common_params embdParams;
+ embdParams.embedding = true;
+ embdParams.embd_normalize = llama->params.embd_normalize;
+ if (embd_normalize != -1) {
+ embdParams.embd_normalize = embd_normalize;
+ }
+
  const char *text_chars = env->GetStringUTFChars(text, nullptr);

  llama->rewind();

  llama_perf_context_reset(llama->ctx);
-
+
  llama->params.prompt = text_chars;

@@ -675,7 +850,7 @@ Java_com_rnllama_LlamaContext_embedding(
  llama->loadPrompt();
  llama->doCompletion();

- std::vector<float> embedding = llama->getEmbedding();
+ std::vector<float> embedding = llama->getEmbedding(embdParams);

  auto embeddings = createWritableArray(env);
  for (const auto &val : embedding) {
@@ -683,6 +858,12 @@ Java_com_rnllama_LlamaContext_embedding(
  }
  putArray(env, result, "embedding", embeddings);

+ auto promptTokens = createWritableArray(env);
+ for (const auto &tok : llama->embd) {
+ pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
+ }
+ putArray(env, result, "prompt_tokens", promptTokens);
+
  env->ReleaseStringUTFChars(text, text_chars);
  return result;
  }
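A note on the `embd_normalize` value threaded through `initContext` and `embedding` above: `-1` means "leave the context default untouched", which is why both JNI entry points only override the parameter when it is not `-1`. As a point of reference, here is a minimal sketch of Euclidean (L2) normalisation applied to an embedding vector; the helper name is made up for illustration, and the meaning of the other flag values should be checked against the bundled `cpp/common.h` rather than taken from this diff.

```cpp
// Sketch only: L2-normalises an embedding vector, the kind of post-processing
// an embd_normalize setting selects. Not code from the package.
#include <cmath>
#include <vector>

std::vector<float> l2_normalize(const std::vector<float> & embd) {
    double sum_sq = 0.0;
    for (float v : embd) sum_sq += (double) v * v;   // accumulate squared magnitudes
    const double norm = std::sqrt(sum_sq);
    std::vector<float> out(embd.size());
    for (size_t i = 0; i < embd.size(); ++i) {
        out[i] = norm > 0.0 ? (float) (embd[i] / norm) : 0.0f;  // guard the zero vector
    }
    return out;
}
```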
@@ -722,4 +903,11 @@ Java_com_rnllama_LlamaContext_freeContext(
  context_map.erase((long) llama->ctx);
  }

+ JNIEXPORT void JNICALL
+ Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
+ UNUSED(env);
+ UNUSED(thiz);
+ llama_log_set(log_callback, NULL);
+ }
+
  } // extern "C"
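For context on the loading hooks added above: the model-load progress callback doubles as a cancellation point, so the lambda wired up in `initContext` returning `!llama->is_load_interrupted` is what lets `interruptLoad` abort an in-flight load. Below is a minimal sketch of that contract outside any JNI plumbing; the `progress_callback` / `progress_callback_user_data` field names come from the diff above, while the include path and everything else is illustrative.

```cpp
// Sketch: the progress callback as a cancellation hook. Returning false from
// the callback makes the loader stop early, which is exactly what flipping
// is_load_interrupted achieves in the JNI code above.
#include <atomic>
#include "common.h"   // assumed location of common_params in the bundled sources

static std::atomic<bool> g_cancel_load{false};

void configure_progress(common_params & params) {
    params.progress_callback_user_data = &g_cancel_load;
    params.progress_callback = [](float progress, void * user_data) -> bool {
        (void) progress;  // progress is in [0, 1]; report it however the host app prefers
        auto * cancel = static_cast<std::atomic<bool> *>(user_data);
        return !cancel->load();   // false => abort the model load
    };
}
```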
@@ -39,8 +39,13 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
  }

  @ReactMethod
- public void initContext(final ReadableMap params, final Promise promise) {
- rnllama.initContext(params, promise);
+ public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+ rnllama.modelInfo(model, skip, promise);
+ }
+
+ @ReactMethod
+ public void initContext(double id, final ReadableMap params, final Promise promise) {
+ rnllama.initContext(id, params, promise);
  }

  @ReactMethod
@@ -89,8 +94,8 @@
  }

  @ReactMethod
- public void embedding(double id, final String text, final Promise promise) {
- rnllama.embedding(id, text, promise);
+ public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
+ rnllama.embedding(id, text, params, promise);
  }

  @ReactMethod
@@ -40,8 +40,13 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
  }

  @ReactMethod
- public void initContext(final ReadableMap params, final Promise promise) {
- rnllama.initContext(params, promise);
+ public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+ rnllama.modelInfo(model, skip, promise);
+ }
+
+ @ReactMethod
+ public void initContext(double id, final ReadableMap params, final Promise promise) {
+ rnllama.initContext(id, params, promise);
  }

  @ReactMethod
@@ -90,8 +95,8 @@
  }

  @ReactMethod
- public void embedding(double id, final String text, final Promise promise) {
- rnllama.embedding(id, text, promise);
+ public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
+ rnllama.embedding(id, text, params, promise);
  }

  @ReactMethod
@@ -0,0 +1,196 @@
+ #include "amx.h"
+ #include "common.h"
+ #include "mmq.h"
+ #include "ggml-backend-impl.h"
+ #include "ggml-backend.h"
+ #include "ggml-impl.h"
+ #include "ggml-cpu.h"
+
+ #if defined(__gnu_linux__)
+ #include <sys/syscall.h>
+ #include <unistd.h>
+ #endif
+
+ #include <cstdlib>
+ #include <cstring>
+ #include <memory>
+
+ #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+
+ // AMX buffer interface
+ static void lm_ggml_backend_amx_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) {
+ free(buffer->context);
+ }
+
+ static void * lm_ggml_backend_amx_buffer_get_base(lm_ggml_backend_buffer_t buffer) {
+ return (void *)(buffer->context);
+ }
+
+ static void lm_ggml_backend_amx_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ memset((char *)tensor->data + offset, value, size);
+
+ LM_GGML_UNUSED(buffer);
+ }
+
+ static void lm_ggml_backend_amx_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ if (qtype_has_amx_kernels(tensor->type)) {
+ lm_ggml_backend_amx_convert_weight(tensor, data, offset, size);
+ } else {
+ memcpy((char *)tensor->data + offset, data, size);
+ }
+
+ LM_GGML_UNUSED(buffer);
+ }
+
+ static void lm_ggml_backend_amx_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ LM_GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ LM_GGML_UNUSED(buffer);
+ }
+
+ static bool lm_ggml_backend_amx_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) {
+ if (lm_ggml_backend_buffer_is_host(src->buffer)) {
+ if (qtype_has_amx_kernels(src->type)) {
+ lm_ggml_backend_amx_convert_weight(dst, src->data, 0, lm_ggml_nbytes(dst));
+ } else {
+ memcpy(dst->data, src->data, lm_ggml_nbytes(src));
+ }
+ return true;
+ }
+ return false;
+
+ LM_GGML_UNUSED(buffer);
+ }
+
+ static void lm_ggml_backend_amx_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) {
+ memset(buffer->context, value, buffer->size);
+ }
+
+ static lm_ggml_backend_buffer_i lm_ggml_backend_amx_buffer_interface = {
+ /* .free_buffer = */ lm_ggml_backend_amx_buffer_free_buffer,
+ /* .get_base = */ lm_ggml_backend_amx_buffer_get_base,
+ /* .init_tensor = */ NULL, // no initialization required
+ /* .memset_tensor = */ lm_ggml_backend_amx_buffer_memset_tensor,
+ /* .set_tensor = */ lm_ggml_backend_amx_buffer_set_tensor,
+ /* .get_tensor = */ lm_ggml_backend_amx_buffer_get_tensor,
+ /* .cpy_tensor = */ lm_ggml_backend_amx_buffer_cpy_tensor,
+ /* .clear = */ lm_ggml_backend_amx_buffer_clear,
+ /* .reset = */ NULL,
+ };
+
+ static const char * lm_ggml_backend_amx_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
+ return "AMX";
+
+ LM_GGML_UNUSED(buft);
+ }
+
+ static lm_ggml_backend_buffer_t lm_ggml_backend_amx_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
+ void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
+ if (data == NULL) {
+ fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+ return NULL;
+ }
+
+ return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_amx_buffer_interface, data, size);
+ }
+
+ static size_t lm_ggml_backend_amx_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
+ return TENSOR_ALIGNMENT;
+
+ LM_GGML_UNUSED(buft);
+ }
+
+ static size_t lm_ggml_backend_amx_buffer_type_get_alloc_size(lm_ggml_backend_buffer_type_t buft, const lm_ggml_tensor* tensor) {
+ return lm_ggml_backend_amx_get_alloc_size(tensor);
+
+ LM_GGML_UNUSED(buft);
+ }
+
+ static bool lm_ggml_backend_amx_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) {
+ return false;
+
+ LM_GGML_UNUSED(buft);
+ }
+
+ #define ARCH_GET_XCOMP_PERM 0x1022
+ #define ARCH_REQ_XCOMP_PERM 0x1023
+ #define XFEATURE_XTILECFG 17
+ #define XFEATURE_XTILEDATA 18
+
+ static bool lm_ggml_amx_init() {
+ #if defined(__gnu_linux__)
+ if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
+ fprintf(stderr, "AMX is not ready to be used!\n");
+ return false;
+ }
+ return true;
+ #elif defined(_WIN32)
+ return true;
+ #endif
+ }
+ lm_ggml_backend_buffer_type_t lm_ggml_backend_amx_buffer_type() {
+ static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_type_amx = {
+ /* .iface = */ {
+ /* .get_name = */ lm_ggml_backend_amx_buffer_type_get_name,
+ /* .alloc_buffer = */ lm_ggml_backend_amx_buffer_type_alloc_buffer,
+ /* .get_alignment = */ lm_ggml_backend_amx_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ lm_ggml_backend_amx_buffer_type_get_alloc_size,
+ /* .is_host = */ lm_ggml_backend_amx_buffer_type_is_host,
+ },
+ /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0),
+ /* .context = */ NULL,
+ };
+
+ if (!lm_ggml_amx_init()) {
+ return NULL;
+ }
+
+ return &lm_ggml_backend_buffer_type_amx;
+ }
+
+ bool lm_ggml_backend_amx_buft_is_amx(lm_ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == lm_ggml_backend_amx_buffer_type_get_name;
+ }
+
+ bool lm_ggml_backend_amx_device_supports_op(const struct lm_ggml_tensor * op) {
+ // handle only 2d gemm for now
+ auto is_contiguous_2d = [](const struct lm_ggml_tensor * t) {
+ return lm_ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+ };
+
+ switch (op->op) {
+ case LM_GGML_OP_NONE:
+ case LM_GGML_OP_RESHAPE:
+ case LM_GGML_OP_VIEW:
+ case LM_GGML_OP_PERMUTE:
+ case LM_GGML_OP_TRANSPOSE:
+ return true;
+
+ case LM_GGML_OP_MUL_MAT: {
+ const struct lm_ggml_tensor * src0 = op->src[0];
+ const struct lm_ggml_tensor * src1 = op->src[1];
+
+ const enum lm_ggml_type type = src0->type;
+ const int64_t ne0 = op->ne[0];
+
+ // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
+ // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
+ bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == LM_GGML_TYPE_F16);
+
+ bool can_use_amx =
+ is_contiguous_2d(src0) && // src0 must be contiguous
+ is_contiguous_2d(src1) && // src1 must be contiguous
+ src1->type == LM_GGML_TYPE_F32 && // src1 must be float32
+ has_amx_kernels && // with amx kernel impls
+ ne0 % (TILE_N * 2) == 0; // out_features is 32x
+
+ return can_use_amx;
+ }
+ default:
+ return false;
+ }
+ }
+
+ #endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
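One practical note on the new AMX path above: `lm_ggml_backend_amx_buffer_type()` returns `NULL` whenever `lm_ggml_amx_init()` cannot obtain XTILEDATA permission from the kernel, so that return value effectively serves as the runtime availability check. A hedged sketch follows; the `amx_available` helper and the `"amx.h"` include path are assumptions for illustration, not part of the package.

```cpp
// Sketch: caller-side AMX availability probe, derived from the behaviour above.
// The whole translation unit is compiled out unless __AMX_INT8__ and
// __AVX512VNNI__ are defined, so the same compile-time guard is repeated here.
#include "amx.h"

bool amx_available() {
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
    // NULL means the kernel refused ARCH_REQ_XCOMP_PERM for XTILEDATA.
    return lm_ggml_backend_amx_buffer_type() != NULL;
#else
    return false;
#endif
}
```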