cui-llama.rn 1.4.3 → 1.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. package/README.md +93 -114
  2. package/android/src/main/CMakeLists.txt +5 -0
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +91 -17
  4. package/android/src/main/java/com/rnllama/RNLlama.java +37 -4
  5. package/android/src/main/jni-utils.h +6 -0
  6. package/android/src/main/jni.cpp +289 -31
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +7 -2
  16. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +7 -2
  17. package/cpp/chat-template.hpp +529 -0
  18. package/cpp/chat.cpp +1779 -0
  19. package/cpp/chat.h +135 -0
  20. package/cpp/common.cpp +2064 -1873
  21. package/cpp/common.h +700 -699
  22. package/cpp/ggml-alloc.c +1039 -1042
  23. package/cpp/ggml-alloc.h +1 -1
  24. package/cpp/ggml-backend-impl.h +255 -255
  25. package/cpp/ggml-backend-reg.cpp +586 -582
  26. package/cpp/ggml-backend.cpp +2004 -2002
  27. package/cpp/ggml-backend.h +354 -354
  28. package/cpp/ggml-common.h +1851 -1853
  29. package/cpp/ggml-cpp.h +39 -39
  30. package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
  31. package/cpp/ggml-cpu-aarch64.h +8 -8
  32. package/cpp/ggml-cpu-impl.h +531 -386
  33. package/cpp/ggml-cpu-quants.c +12527 -10920
  34. package/cpp/ggml-cpu-traits.cpp +36 -36
  35. package/cpp/ggml-cpu-traits.h +38 -38
  36. package/cpp/ggml-cpu.c +15766 -14391
  37. package/cpp/ggml-cpu.cpp +655 -635
  38. package/cpp/ggml-cpu.h +138 -135
  39. package/cpp/ggml-impl.h +567 -567
  40. package/cpp/ggml-metal-impl.h +235 -0
  41. package/cpp/ggml-metal.h +1 -1
  42. package/cpp/ggml-metal.m +5146 -4884
  43. package/cpp/ggml-opt.cpp +854 -854
  44. package/cpp/ggml-opt.h +216 -216
  45. package/cpp/ggml-quants.c +5238 -5238
  46. package/cpp/ggml-threading.h +14 -14
  47. package/cpp/ggml.c +6529 -6514
  48. package/cpp/ggml.h +2198 -2194
  49. package/cpp/gguf.cpp +1329 -1329
  50. package/cpp/gguf.h +202 -202
  51. package/cpp/json-schema-to-grammar.cpp +1024 -1045
  52. package/cpp/json-schema-to-grammar.h +21 -8
  53. package/cpp/json.hpp +24766 -24766
  54. package/cpp/llama-adapter.cpp +347 -347
  55. package/cpp/llama-adapter.h +74 -74
  56. package/cpp/llama-arch.cpp +1513 -1487
  57. package/cpp/llama-arch.h +403 -400
  58. package/cpp/llama-batch.cpp +368 -368
  59. package/cpp/llama-batch.h +88 -88
  60. package/cpp/llama-chat.cpp +588 -578
  61. package/cpp/llama-chat.h +53 -52
  62. package/cpp/llama-context.cpp +1775 -1775
  63. package/cpp/llama-context.h +128 -128
  64. package/cpp/llama-cparams.cpp +1 -1
  65. package/cpp/llama-cparams.h +37 -37
  66. package/cpp/llama-cpp.h +30 -30
  67. package/cpp/llama-grammar.cpp +1219 -1139
  68. package/cpp/llama-grammar.h +173 -143
  69. package/cpp/llama-hparams.cpp +71 -71
  70. package/cpp/llama-hparams.h +139 -139
  71. package/cpp/llama-impl.cpp +167 -167
  72. package/cpp/llama-impl.h +61 -61
  73. package/cpp/llama-kv-cache.cpp +718 -718
  74. package/cpp/llama-kv-cache.h +219 -218
  75. package/cpp/llama-mmap.cpp +600 -590
  76. package/cpp/llama-mmap.h +68 -67
  77. package/cpp/llama-model-loader.cpp +1124 -1124
  78. package/cpp/llama-model-loader.h +167 -167
  79. package/cpp/llama-model.cpp +4087 -3997
  80. package/cpp/llama-model.h +370 -370
  81. package/cpp/llama-sampling.cpp +2558 -2408
  82. package/cpp/llama-sampling.h +32 -32
  83. package/cpp/llama-vocab.cpp +3264 -3247
  84. package/cpp/llama-vocab.h +125 -125
  85. package/cpp/llama.cpp +10284 -10077
  86. package/cpp/llama.h +1354 -1323
  87. package/cpp/log.cpp +393 -401
  88. package/cpp/log.h +132 -121
  89. package/cpp/minja/chat-template.hpp +529 -0
  90. package/cpp/minja/minja.hpp +2915 -0
  91. package/cpp/minja.hpp +2915 -0
  92. package/cpp/rn-llama.cpp +66 -6
  93. package/cpp/rn-llama.h +26 -1
  94. package/cpp/sampling.cpp +570 -505
  95. package/cpp/sampling.h +3 -0
  96. package/cpp/sgemm.cpp +2598 -2597
  97. package/cpp/sgemm.h +14 -14
  98. package/cpp/speculative.cpp +278 -277
  99. package/cpp/speculative.h +28 -28
  100. package/cpp/unicode.cpp +9 -2
  101. package/ios/CMakeLists.txt +6 -0
  102. package/ios/RNLlama.h +0 -8
  103. package/ios/RNLlama.mm +27 -3
  104. package/ios/RNLlamaContext.h +10 -1
  105. package/ios/RNLlamaContext.mm +269 -57
  106. package/jest/mock.js +21 -2
  107. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  108. package/lib/commonjs/grammar.js +3 -0
  109. package/lib/commonjs/grammar.js.map +1 -1
  110. package/lib/commonjs/index.js +87 -13
  111. package/lib/commonjs/index.js.map +1 -1
  112. package/lib/module/NativeRNLlama.js.map +1 -1
  113. package/lib/module/grammar.js +3 -0
  114. package/lib/module/grammar.js.map +1 -1
  115. package/lib/module/index.js +86 -13
  116. package/lib/module/index.js.map +1 -1
  117. package/lib/typescript/NativeRNLlama.d.ts +107 -2
  118. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  119. package/lib/typescript/grammar.d.ts.map +1 -1
  120. package/lib/typescript/index.d.ts +32 -7
  121. package/lib/typescript/index.d.ts.map +1 -1
  122. package/llama-rn.podspec +1 -1
  123. package/package.json +3 -2
  124. package/src/NativeRNLlama.ts +115 -3
  125. package/src/grammar.ts +3 -0
  126. package/src/index.ts +138 -21
  127. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  128. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  129. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  130. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  131. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  132. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  133. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -55
  134. package/cpp/rn-llama.hpp +0 -913
package/android/src/main/jni.cpp:
@@ -9,12 +9,13 @@
  #include <string>
  #include <thread>
  #include <unordered_map>
+ #include "json-schema-to-grammar.h"
  #include "llama.h"
+ #include "chat.h"
  #include "llama-impl.h"
  #include "ggml.h"
  #include "rn-llama.h"
  #include "jni-utils.h"
-
  #define UNUSED(x) (void)(x)
  #define TAG "RNLLAMA_ANDROID_JNI"

@@ -25,7 +26,7 @@ static inline int min(int a, int b) {
  return (a < b) ? a : b;
  }

- static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
+ static void rnllama_log_callback_default(lm_ggml_log_level level, const char * fmt, void * data) {
  if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
  else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
  else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
@@ -230,6 +231,8 @@ Java_com_rnllama_LlamaContext_initContext(
  JNIEnv *env,
  jobject thiz,
  jstring model_path_str,
+ jstring chat_template,
+ jstring reasoning_format,
  jboolean embedding,
  jint embd_normalize,
  jint n_ctx,
@@ -262,7 +265,17 @@ Java_com_rnllama_LlamaContext_initContext(

  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
  defaultParams.model = model_path_chars;
-
+
+ const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
+ defaultParams.chat_template = chat_template_chars;
+
+ const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
+ if (strcmp(reasoning_format_chars, "deepseek") == 0) {
+ defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+ } else {
+ defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
+ }
+
  defaultParams.n_ctx = n_ctx;
  defaultParams.n_batch = n_batch;
  defaultParams.n_ubatch = n_ubatch;
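The two new initContext arguments are forwarded into llama.cpp's common_params: chat_template overrides the model's built-in template, and reasoning_format selects reasoning extraction ("deepseek" maps to COMMON_REASONING_FORMAT_DEEPSEEK, anything else to COMMON_REASONING_FORMAT_NONE). A minimal usage sketch, assuming the fork keeps llama.rn's initLlama entry point and that the JS option names mirror the JNI parameters (the real ContextParams type is in package/src/index.ts):

```typescript
// Hypothetical sketch; option names are assumptions mirroring the JNI parameters.
import { initLlama } from 'cui-llama.rn'

const context = await initLlama({
  model: '/path/to/model.gguf',
  chat_template: '',            // optional override for the GGUF's built-in chat template
  reasoning_format: 'deepseek', // 'deepseek' -> COMMON_REASONING_FORMAT_DEEPSEEK, else NONE
})
```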
@@ -329,6 +342,8 @@ Java_com_rnllama_LlamaContext_initContext(
  bool is_model_loaded = llama->loadModel(defaultParams);

  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+ env->ReleaseStringUTFChars(chat_template, chat_template_chars);
+ env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
  env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
  env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);

@@ -418,52 +433,137 @@ Java_com_rnllama_LlamaContext_loadModelDetails(

  char desc[1024];
  llama_model_desc(llama->model, desc, sizeof(desc));
+
  putString(env, result, "desc", desc);
  putDouble(env, result, "size", llama_model_size(llama->model));
  putDouble(env, result, "nEmbd", llama_model_n_embd(llama->model));
  putDouble(env, result, "nParams", llama_model_n_params(llama->model));
- putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate());
+ auto chat_templates = createWriteableMap(env);
+ putBoolean(env, chat_templates, "llamaChat", llama->validateModelChatTemplate(false, nullptr));
+
+ auto minja = createWriteableMap(env);
+ putBoolean(env, minja, "default", llama->validateModelChatTemplate(true, nullptr));
+
+ auto default_caps = createWriteableMap(env);
+
+ auto default_tmpl = llama -> templates -> template_default.get();
+ auto default_tmpl_caps = default_tmpl->original_caps();
+ putBoolean(env, default_caps, "tools", default_tmpl_caps.supports_tools);
+ putBoolean(env, default_caps, "toolCalls", default_tmpl_caps.supports_tool_calls);
+ putBoolean(env, default_caps, "parallelToolCalls", default_tmpl_caps.supports_parallel_tool_calls);
+ putBoolean(env, default_caps, "toolResponses", default_tmpl_caps.supports_tool_responses);
+ putBoolean(env, default_caps, "systemRole", default_tmpl_caps.supports_system_role);
+ putBoolean(env, default_caps, "toolCallId", default_tmpl_caps.supports_tool_call_id);
+ putMap(env, minja, "defaultCaps", default_caps);
+
+ putBoolean(env, minja, "toolUse", llama->validateModelChatTemplate(true, "tool_use"));
+ auto tool_use_tmpl = llama-> templates -> template_tool_use.get();
+ if (tool_use_tmpl != nullptr) {
+ auto tool_use_caps = createWriteableMap(env);
+ auto tool_use_tmpl_caps = tool_use_tmpl->original_caps();
+ putBoolean(env, tool_use_caps, "tools", tool_use_tmpl_caps.supports_tools);
+ putBoolean(env, tool_use_caps, "toolCalls", tool_use_tmpl_caps.supports_tool_calls);
+ putBoolean(env, tool_use_caps, "parallelToolCalls", tool_use_tmpl_caps.supports_parallel_tool_calls);
+ putBoolean(env, tool_use_caps, "systemRole", tool_use_tmpl_caps.supports_system_role);
+ putBoolean(env, tool_use_caps, "toolResponses", tool_use_tmpl_caps.supports_tool_responses);
+ putBoolean(env, tool_use_caps, "toolCallId", tool_use_tmpl_caps.supports_tool_call_id);
+ putMap(env, minja, "toolUseCaps", tool_use_caps);
+ }
+
+ putMap(env, chat_templates, "minja", minja);
  putMap(env, result, "metadata", meta);
+ putMap(env, result, "chatTemplates", chat_templates);
+
+ // deprecated
+ putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate(false, nullptr));

  return reinterpret_cast<jobject>(result);
  }

  JNIEXPORT jobject JNICALL
- Java_com_rnllama_LlamaContext_getFormattedChat(
+ Java_com_rnllama_LlamaContext_getFormattedChatWithJinja(
  JNIEnv *env,
  jobject thiz,
  jlong context_ptr,
- jobjectArray messages,
- jstring chat_template
+ jstring messages,
+ jstring chat_template,
+ jstring json_schema,
+ jstring tools,
+ jboolean parallel_tool_calls,
+ jstring tool_choice
  ) {
  UNUSED(thiz);
  auto llama = context_map[(long) context_ptr];

- std::vector<common_chat_msg> chat;
-
- int messages_len = env->GetArrayLength(messages);
- for (int i = 0; i < messages_len; i++) {
- jobject msg = env->GetObjectArrayElement(messages, i);
- jclass msgClass = env->GetObjectClass(msg);
-
- jmethodID getRoleMethod = env->GetMethodID(msgClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
- jstring roleKey = env->NewStringUTF("role");
- jstring contentKey = env->NewStringUTF("content");
+ const char *messages_chars = env->GetStringUTFChars(messages, nullptr);
+ const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
+ const char *json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
+ const char *tools_chars = env->GetStringUTFChars(tools, nullptr);
+ const char *tool_choice_chars = env->GetStringUTFChars(tool_choice, nullptr);

- jstring role_str = (jstring) env->CallObjectMethod(msg, getRoleMethod, roleKey);
- jstring content_str = (jstring) env->CallObjectMethod(msg, getRoleMethod, contentKey);
+ auto result = createWriteableMap(env);
+ try {
+ auto formatted = llama->getFormattedChatWithJinja(
+ messages_chars,
+ tmpl_chars,
+ json_schema_chars,
+ tools_chars,
+ parallel_tool_calls,
+ tool_choice_chars
+ );
+ putString(env, result, "prompt", formatted.prompt.c_str());
+ putInt(env, result, "chat_format", static_cast<int>(formatted.format));
+ putString(env, result, "grammar", formatted.grammar.c_str());
+ putBoolean(env, result, "grammar_lazy", formatted.grammar_lazy);
+ auto grammar_triggers = createWritableArray(env);
+ for (const auto &trigger : formatted.grammar_triggers) {
+ auto trigger_map = createWriteableMap(env);
+
+ putString(env, trigger_map, "word", trigger.value.c_str());
+ putBoolean(env, trigger_map, "at_start", trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START);
+ pushMap(env, grammar_triggers, trigger_map);
+ }
+ putArray(env, result, "grammar_triggers", grammar_triggers);
+ auto preserved_tokens = createWritableArray(env);
+ for (const auto &token : formatted.preserved_tokens) {
+ pushString(env, preserved_tokens, token.c_str());
+ }
+ putArray(env, result, "preserved_tokens", preserved_tokens);
+ auto additional_stops = createWritableArray(env);
+ for (const auto &stop : formatted.additional_stops) {
+ pushString(env, additional_stops, stop.c_str());
+ }
+ putArray(env, result, "additional_stops", additional_stops);
+ } catch (const std::runtime_error &e) {
+ LOGI("[RNLlama] Error: %s", e.what());
+ putString(env, result, "_error", e.what());
+ }
+ env->ReleaseStringUTFChars(tools, tools_chars);
+ env->ReleaseStringUTFChars(messages, messages_chars);
+ env->ReleaseStringUTFChars(chat_template, tmpl_chars);
+ env->ReleaseStringUTFChars(json_schema, json_schema_chars);
+ env->ReleaseStringUTFChars(tool_choice, tool_choice_chars);
+ return reinterpret_cast<jobject>(result);
+ }

- const char *role = env->GetStringUTFChars(role_str, nullptr);
- const char *content = env->GetStringUTFChars(content_str, nullptr);
+ JNIEXPORT jobject JNICALL
+ Java_com_rnllama_LlamaContext_getFormattedChat(
+ JNIEnv *env,
+ jobject thiz,
+ jlong context_ptr,
+ jstring messages,
+ jstring chat_template
+ ) {
+ UNUSED(thiz);
+ auto llama = context_map[(long) context_ptr];

- chat.push_back({ role, content });
+ const char *messages_chars = env->GetStringUTFChars(messages, nullptr);
+ const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);

- env->ReleaseStringUTFChars(role_str, role);
- env->ReleaseStringUTFChars(content_str, content);
- }
+ std::string formatted_chat = llama->getFormattedChat(messages_chars, tmpl_chars);

- const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
- std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);
+ env->ReleaseStringUTFChars(messages, messages_chars);
+ env->ReleaseStringUTFChars(chat_template, tmpl_chars);

  return env->NewStringUTF(formatted_chat.c_str());
  }
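For reference, the map assembled by getFormattedChatWithJinja above carries the following fields across the bridge (key names taken directly from the putString/putInt/putBoolean/putArray calls); the TypeScript type name below is illustrative only:

```typescript
// Shape of the Jinja formatting result built in the hunk above.
interface JinjaFormattedChatResult {
  prompt: string
  chat_format: number          // common_chat_format enum value, fed back into doCompletion
  grammar: string              // GBNF grammar, possibly empty
  grammar_lazy: boolean
  grammar_triggers: Array<{ word: string; at_start: boolean }>
  preserved_tokens: string[]
  additional_stops: string[]
  _error?: string              // only present when formatting threw std::runtime_error
}
```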
@@ -552,7 +652,12 @@ Java_com_rnllama_LlamaContext_doCompletion(
  jobject thiz,
  jlong context_ptr,
  jstring prompt,
+ jint chat_format,
  jstring grammar,
+ jstring json_schema,
+ jboolean grammar_lazy,
+ jobject grammar_triggers,
+ jobject preserved_tokens,
  jfloat temperature,
  jint n_threads,
  jint n_predict,
@@ -578,6 +683,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
  jfloat dry_base,
  jint dry_allowed_length,
  jint dry_penalty_last_n,
+ jfloat top_n_sigma,
  jobjectArray dry_sequence_breakers,
  jobject partial_completion_callback
  ) {
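The new doCompletion arguments mostly carry the Jinja formatter's output back into sampling, plus the new top_n_sigma sampler setting. A hedged sketch of the corresponding bridge-level fields (names mirror the JNI arguments; the authoritative TypeScript spec is in package/src/NativeRNLlama.ts):

```typescript
// Illustrative sketch of the completion parameters added above.
interface NewCompletionNativeParams {
  chat_format: number                                    // from getFormattedChatWithJinja
  json_schema: string                                    // converted to a grammar when no grammar is given
  grammar_lazy: boolean
  grammar_triggers: Array<{ word: string; at_start: boolean }>
  preserved_tokens: string[]
  top_n_sigma: number                                    // new sampler parameter
}
```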
@@ -588,7 +694,8 @@ Java_com_rnllama_LlamaContext_doCompletion(

  //llama_reset_timings(llama->ctx);

- llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
+ auto prompt_chars = env->GetStringUTFChars(prompt, nullptr);
+ llama->params.prompt = prompt_chars;
  llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;

  int max_threads = std::thread::hardware_concurrency();
@@ -613,13 +720,59 @@ Java_com_rnllama_LlamaContext_doCompletion(
  sparams.min_p = min_p;
  sparams.typ_p = typical_p;
  sparams.n_probs = n_probs;
- sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
  sparams.xtc_threshold = xtc_threshold;
  sparams.xtc_probability = xtc_probability;
  sparams.dry_multiplier = dry_multiplier;
  sparams.dry_base = dry_base;
  sparams.dry_allowed_length = dry_allowed_length;
  sparams.dry_penalty_last_n = dry_penalty_last_n;
+ sparams.top_n_sigma = top_n_sigma;
+
+ // grammar
+ auto grammar_chars = env->GetStringUTFChars(grammar, nullptr);
+ if (grammar_chars && grammar_chars[0] != '\0') {
+ sparams.grammar = grammar_chars;
+ }
+ sparams.grammar_lazy = grammar_lazy;
+ if (grammar_triggers != nullptr) {
+ int grammar_triggers_size = readablearray::size(env, grammar_triggers);
+ for (int i = 0; i < grammar_triggers_size; i++) {
+ common_grammar_trigger trigger;
+ auto trigger_map = readablearray::getMap(env, grammar_triggers, i);
+ jstring trigger_word = readablemap::getString(env, trigger_map, "word", nullptr);
+ jboolean trigger_at_start = readablemap::getBool(env, trigger_map, "at_start", false);
+ trigger.value = env->GetStringUTFChars(trigger_word, nullptr);
+ trigger.type = trigger_at_start ? COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START : COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
+
+ auto ids = common_tokenize(llama->ctx, trigger.value, /* add_special= */ false, /* parse_special= */ true);
+ if (ids.size() == 1) {
+ sparams.grammar_triggers.push_back(trigger);
+ sparams.preserved_tokens.insert(ids[0]);
+ continue;
+ }
+ sparams.grammar_triggers.push_back(trigger);
+ }
+ }
+
+ auto json_schema_chars = env->GetStringUTFChars(json_schema, nullptr);
+ if ((!grammar_chars || grammar_chars[0] == '\0') && json_schema_chars && json_schema_chars[0] != '\0') {
+ auto schema = json::parse(json_schema_chars);
+ sparams.grammar = json_schema_to_grammar(schema);
+ }
+ env->ReleaseStringUTFChars(json_schema, json_schema_chars);
+
+ if (preserved_tokens != nullptr) {
+ int preserved_tokens_size = readablearray::size(env, preserved_tokens);
+ for (int i = 0; i < preserved_tokens_size; i++) {
+ jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
+ auto ids = common_tokenize(llama->ctx, env->GetStringUTFChars(preserved_token, nullptr), /* add_special= */ false, /* parse_special= */ true);
+ if (ids.size() == 1) {
+ sparams.preserved_tokens.insert(ids[0]);
+ } else {
+ LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", env->GetStringUTFChars(preserved_token, nullptr));
+ }
+ }
+ }

  const llama_model * model = llama_get_model(llama->ctx);
  const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -744,11 +897,51 @@ Java_com_rnllama_LlamaContext_doCompletion(
  }
  }

+ env->ReleaseStringUTFChars(grammar, grammar_chars);
+ env->ReleaseStringUTFChars(prompt, prompt_chars);
  llama_perf_context_print(llama->ctx);
  llama->is_predicting = false;

+ auto toolCalls = createWritableArray(env);
+ std::string reasoningContent = "";
+ std::string *content = nullptr;
+ auto toolCallsSize = 0;
+ if (!llama->is_interrupted) {
+ try {
+ common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
+ if (!message.reasoning_content.empty()) {
+ reasoningContent = message.reasoning_content;
+ }
+ content = &message.content;
+ for (const auto &tc : message.tool_calls) {
+ auto toolCall = createWriteableMap(env);
+ putString(env, toolCall, "type", "function");
+ auto functionMap = createWriteableMap(env);
+ putString(env, functionMap, "name", tc.name.c_str());
+ putString(env, functionMap, "arguments", tc.arguments.c_str());
+ putMap(env, toolCall, "function", functionMap);
+ if (!tc.id.empty()) {
+ putString(env, toolCall, "id", tc.id.c_str());
+ }
+ pushMap(env, toolCalls, toolCall);
+ toolCallsSize++;
+ }
+ } catch (const std::exception &e) {
+ // LOGI("Error parsing tool calls: %s", e.what());
+ }
+ }
+
  auto result = createWriteableMap(env);
  putString(env, result, "text", llama->generated_text.c_str());
+ if (content) {
+ putString(env, result, "content", content->c_str());
+ }
+ if (!reasoningContent.empty()) {
+ putString(env, result, "reasoning_content", reasoningContent.c_str());
+ }
+ if (toolCallsSize > 0) {
+ putArray(env, result, "tool_calls", toolCalls);
+ }
  putArray(env, result, "completion_probabilities", tokenProbsToMap(env, llama, llama->generated_token_probs));
  putInt(env, result, "tokens_predicted", llama->num_tokens_predicted);
  putInt(env, result, "tokens_evaluated", llama->num_prompt_tokens);
@@ -977,11 +1170,76 @@ Java_com_rnllama_LlamaContext_freeContext(
  delete llama;
  }

+ struct log_callback_context {
+ JavaVM *jvm;
+ jobject callback;
+ };
+
+ static void rnllama_log_callback_to_j(lm_ggml_log_level level, const char * text, void * data) {
+ auto level_c = "";
+ if (level == LM_GGML_LOG_LEVEL_ERROR) {
+ __android_log_print(ANDROID_LOG_ERROR, TAG, text, nullptr);
+ level_c = "error";
+ } else if (level == LM_GGML_LOG_LEVEL_INFO) {
+ __android_log_print(ANDROID_LOG_INFO, TAG, text, nullptr);
+ level_c = "info";
+ } else if (level == LM_GGML_LOG_LEVEL_WARN) {
+ __android_log_print(ANDROID_LOG_WARN, TAG, text, nullptr);
+ level_c = "warn";
+ } else {
+ __android_log_print(ANDROID_LOG_DEFAULT, TAG, text, nullptr);
+ }
+
+ log_callback_context *cb_ctx = (log_callback_context *) data;
+
+ JNIEnv *env;
+ bool need_detach = false;
+ int getEnvResult = cb_ctx->jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
+
+ if (getEnvResult == JNI_EDETACHED) {
+ if (cb_ctx->jvm->AttachCurrentThread(&env, nullptr) == JNI_OK) {
+ need_detach = true;
+ } else {
+ return;
+ }
+ } else if (getEnvResult != JNI_OK) {
+ return;
+ }
+
+ jobject callback = cb_ctx->callback;
+ jclass cb_class = env->GetObjectClass(callback);
+ jmethodID emitNativeLog = env->GetMethodID(cb_class, "emitNativeLog", "(Ljava/lang/String;Ljava/lang/String;)V");
+
+ jstring level_str = env->NewStringUTF(level_c);
+ jstring text_str = env->NewStringUTF(text);
+ env->CallVoidMethod(callback, emitNativeLog, level_str, text_str);
+ env->DeleteLocalRef(level_str);
+ env->DeleteLocalRef(text_str);
+
+ if (need_detach) {
+ cb_ctx->jvm->DetachCurrentThread();
+ }
+ }
+
+ JNIEXPORT void JNICALL
+ Java_com_rnllama_LlamaContext_setupLog(JNIEnv *env, jobject thiz, jobject logCallback) {
+ UNUSED(thiz);
+
+ log_callback_context *cb_ctx = new log_callback_context;
+
+ JavaVM *jvm;
+ env->GetJavaVM(&jvm);
+ cb_ctx->jvm = jvm;
+ cb_ctx->callback = env->NewGlobalRef(logCallback);
+
+ llama_log_set(rnllama_log_callback_to_j, cb_ctx);
+ }
+
  JNIEXPORT void JNICALL
- Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
+ Java_com_rnllama_LlamaContext_unsetLog(JNIEnv *env, jobject thiz) {
  UNUSED(env);
  UNUSED(thiz);
- llama_log_set(log_callback, NULL);
+ llama_log_set(rnllama_log_callback_default, NULL);
  }

  } // extern "C"
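setupLog installs a llama.cpp log callback that mirrors each line to logcat and then invokes emitNativeLog(level, text) on the supplied Java callback object, attaching the calling thread to the JVM when needed; unsetLog restores the plain logcat-only logger. A sketch of what a JavaScript-side listener would receive, assuming the Java layer forwards the call as an event (the actual subscription helper lives in package/src/index.ts):

```typescript
// Levels as set in rnllama_log_callback_to_j above; '' is used for any other ggml level.
type NativeLogLevel = 'error' | 'info' | 'warn' | ''

// Assumed listener shape: emitNativeLog(level, text) carries exactly these two strings.
type NativeLogListener = (level: NativeLogLevel, text: string) => void
```

toggleNativeLog(enabled), added in the RNLlamaModule hunks below, switches this forwarding on and off from JavaScript.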
package/android/src/newarch/java/com/rnllama/RNLlamaModule.java:
@@ -33,6 +33,11 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
  return NAME;
  }

+ @ReactMethod
+ public void toggleNativeLog(boolean enabled, Promise promise) {
+ rnllama.toggleNativeLog(enabled, promise);
+ }
+
  @ReactMethod
  public void setContextLimit(double limit, Promise promise) {
  rnllama.setContextLimit(limit, promise);
@@ -49,8 +54,8 @@ public class RNLlamaModule extends NativeRNLlamaSpec {
  }

  @ReactMethod
- public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
- rnllama.getFormattedChat(id, messages, chatTemplate, promise);
+ public void getFormattedChat(double id, String messages, String chatTemplate, ReadableMap params, Promise promise) {
+ rnllama.getFormattedChat(id, messages, chatTemplate, params, promise);
  }

  @ReactMethod
package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java:
@@ -34,6 +34,11 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
  return NAME;
  }

+ @ReactMethod
+ public void toggleNativeLog(boolean enabled, Promise promise) {
+ rnllama.toggleNativeLog(enabled, promise);
+ }
+
  @ReactMethod
  public void setContextLimit(double limit, Promise promise) {
  rnllama.setContextLimit(limit, promise);
@@ -50,8 +55,8 @@ public class RNLlamaModule extends ReactContextBaseJavaModule {
  }

  @ReactMethod
- public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
- rnllama.getFormattedChat(id, messages, chatTemplate, promise);
+ public void getFormattedChat(double id, String messages, String chatTemplate, ReadableMap params, Promise promise) {
+ rnllama.getFormattedChat(id, messages, chatTemplate, params, promise);
  }

  @ReactMethod
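On both the new and old architecture the module now exposes toggleNativeLog, and getFormattedChat takes the messages as a JSON string plus a params map instead of a ReadableArray. A rough TypeScript rendering of the updated bridge signatures (parameter names and return types are assumptions; the generated spec in package/src/NativeRNLlama.ts is the source of truth):

```typescript
// Illustrative sketch only; see package/lib/typescript/NativeRNLlama.d.ts for the real spec.
export interface RNLlamaModuleSketch {
  toggleNativeLog(enabled: boolean): Promise<void>
  getFormattedChat(
    contextId: number,
    messages: string,        // JSON-encoded array of chat messages (was a ReadableArray before this diff)
    chatTemplate: string,
    params: object,          // e.g. jinja / json_schema / tools options
  ): Promise<object | string>  // assumption: object for the Jinja path, string otherwise
}
```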