cui-llama.rn 1.4.3 → 1.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. package/README.md +93 -114
  2. package/android/src/main/CMakeLists.txt +5 -0
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +91 -17
  4. package/android/src/main/java/com/rnllama/RNLlama.java +37 -4
  5. package/android/src/main/jni-utils.h +6 -0
  6. package/android/src/main/jni.cpp +289 -31
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +7 -2
  16. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +7 -2
  17. package/cpp/chat-template.hpp +529 -0
  18. package/cpp/chat.cpp +1779 -0
  19. package/cpp/chat.h +135 -0
  20. package/cpp/common.cpp +2064 -1873
  21. package/cpp/common.h +700 -699
  22. package/cpp/ggml-alloc.c +1039 -1042
  23. package/cpp/ggml-alloc.h +1 -1
  24. package/cpp/ggml-backend-impl.h +255 -255
  25. package/cpp/ggml-backend-reg.cpp +586 -582
  26. package/cpp/ggml-backend.cpp +2004 -2002
  27. package/cpp/ggml-backend.h +354 -354
  28. package/cpp/ggml-common.h +1851 -1853
  29. package/cpp/ggml-cpp.h +39 -39
  30. package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
  31. package/cpp/ggml-cpu-aarch64.h +8 -8
  32. package/cpp/ggml-cpu-impl.h +531 -386
  33. package/cpp/ggml-cpu-quants.c +12527 -10920
  34. package/cpp/ggml-cpu-traits.cpp +36 -36
  35. package/cpp/ggml-cpu-traits.h +38 -38
  36. package/cpp/ggml-cpu.c +15766 -14391
  37. package/cpp/ggml-cpu.cpp +655 -635
  38. package/cpp/ggml-cpu.h +138 -135
  39. package/cpp/ggml-impl.h +567 -567
  40. package/cpp/ggml-metal-impl.h +235 -0
  41. package/cpp/ggml-metal.h +1 -1
  42. package/cpp/ggml-metal.m +5146 -4884
  43. package/cpp/ggml-opt.cpp +854 -854
  44. package/cpp/ggml-opt.h +216 -216
  45. package/cpp/ggml-quants.c +5238 -5238
  46. package/cpp/ggml-threading.h +14 -14
  47. package/cpp/ggml.c +6529 -6514
  48. package/cpp/ggml.h +2198 -2194
  49. package/cpp/gguf.cpp +1329 -1329
  50. package/cpp/gguf.h +202 -202
  51. package/cpp/json-schema-to-grammar.cpp +1024 -1045
  52. package/cpp/json-schema-to-grammar.h +21 -8
  53. package/cpp/json.hpp +24766 -24766
  54. package/cpp/llama-adapter.cpp +347 -347
  55. package/cpp/llama-adapter.h +74 -74
  56. package/cpp/llama-arch.cpp +1513 -1487
  57. package/cpp/llama-arch.h +403 -400
  58. package/cpp/llama-batch.cpp +368 -368
  59. package/cpp/llama-batch.h +88 -88
  60. package/cpp/llama-chat.cpp +588 -578
  61. package/cpp/llama-chat.h +53 -52
  62. package/cpp/llama-context.cpp +1775 -1775
  63. package/cpp/llama-context.h +128 -128
  64. package/cpp/llama-cparams.cpp +1 -1
  65. package/cpp/llama-cparams.h +37 -37
  66. package/cpp/llama-cpp.h +30 -30
  67. package/cpp/llama-grammar.cpp +1219 -1139
  68. package/cpp/llama-grammar.h +173 -143
  69. package/cpp/llama-hparams.cpp +71 -71
  70. package/cpp/llama-hparams.h +139 -139
  71. package/cpp/llama-impl.cpp +167 -167
  72. package/cpp/llama-impl.h +61 -61
  73. package/cpp/llama-kv-cache.cpp +718 -718
  74. package/cpp/llama-kv-cache.h +219 -218
  75. package/cpp/llama-mmap.cpp +600 -590
  76. package/cpp/llama-mmap.h +68 -67
  77. package/cpp/llama-model-loader.cpp +1124 -1124
  78. package/cpp/llama-model-loader.h +167 -167
  79. package/cpp/llama-model.cpp +4087 -3997
  80. package/cpp/llama-model.h +370 -370
  81. package/cpp/llama-sampling.cpp +2558 -2408
  82. package/cpp/llama-sampling.h +32 -32
  83. package/cpp/llama-vocab.cpp +3264 -3247
  84. package/cpp/llama-vocab.h +125 -125
  85. package/cpp/llama.cpp +10284 -10077
  86. package/cpp/llama.h +1354 -1323
  87. package/cpp/log.cpp +393 -401
  88. package/cpp/log.h +132 -121
  89. package/cpp/minja/chat-template.hpp +529 -0
  90. package/cpp/minja/minja.hpp +2915 -0
  91. package/cpp/minja.hpp +2915 -0
  92. package/cpp/rn-llama.cpp +66 -6
  93. package/cpp/rn-llama.h +26 -1
  94. package/cpp/sampling.cpp +570 -505
  95. package/cpp/sampling.h +3 -0
  96. package/cpp/sgemm.cpp +2598 -2597
  97. package/cpp/sgemm.h +14 -14
  98. package/cpp/speculative.cpp +278 -277
  99. package/cpp/speculative.h +28 -28
  100. package/cpp/unicode.cpp +9 -2
  101. package/ios/CMakeLists.txt +6 -0
  102. package/ios/RNLlama.h +0 -8
  103. package/ios/RNLlama.mm +27 -3
  104. package/ios/RNLlamaContext.h +10 -1
  105. package/ios/RNLlamaContext.mm +269 -57
  106. package/jest/mock.js +21 -2
  107. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  108. package/lib/commonjs/grammar.js +3 -0
  109. package/lib/commonjs/grammar.js.map +1 -1
  110. package/lib/commonjs/index.js +87 -13
  111. package/lib/commonjs/index.js.map +1 -1
  112. package/lib/module/NativeRNLlama.js.map +1 -1
  113. package/lib/module/grammar.js +3 -0
  114. package/lib/module/grammar.js.map +1 -1
  115. package/lib/module/index.js +86 -13
  116. package/lib/module/index.js.map +1 -1
  117. package/lib/typescript/NativeRNLlama.d.ts +107 -2
  118. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  119. package/lib/typescript/grammar.d.ts.map +1 -1
  120. package/lib/typescript/index.d.ts +32 -7
  121. package/lib/typescript/index.d.ts.map +1 -1
  122. package/llama-rn.podspec +1 -1
  123. package/package.json +3 -2
  124. package/src/NativeRNLlama.ts +115 -3
  125. package/src/grammar.ts +3 -0
  126. package/src/index.ts +138 -21
  127. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  128. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  129. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  130. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  131. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  132. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  133. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -55
  134. package/cpp/rn-llama.hpp +0 -913
package/cpp/unicode.cpp CHANGED
@@ -618,7 +618,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
-        result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
     }
     return result;
 }
@@ -701,7 +708,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
     const auto cpts = unicode_cpts_from_utf8(text);
 
     // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
     std::string text_collapsed;
     if (need_collapse) {
         // collapse all unicode categories
package/ios/CMakeLists.txt CHANGED
@@ -15,6 +15,7 @@ add_definitions(
     -DLM_GGML_USE_CPU
     -DLM_GGML_USE_ACCELERATE
     -DLM_GGML_USE_METAL
+    -DLM_GGML_METAL_USE_BF16
 )
 
 set(SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../cpp)
@@ -66,6 +67,11 @@ add_library(rnllama SHARED
     ${SOURCE_DIR}/unicode.cpp
     ${SOURCE_DIR}/sgemm.cpp
     ${SOURCE_DIR}/common.cpp
+    ${SOURCE_DIR}/chat.cpp
+    ${SOURCE_DIR}/chat-template.hpp
+    ${SOURCE_DIR}/json-schema-to-grammar.cpp
+    ${SOURCE_DIR}/minja.hpp
+    ${SOURCE_DIR}/json.hpp
     ${SOURCE_DIR}/amx/amx.cpp
     ${SOURCE_DIR}/amx/mmq.cpp
     ${SOURCE_DIR}/rn-llama.cpp
package/ios/RNLlama.h CHANGED
@@ -1,11 +1,3 @@
-#ifdef __cplusplus
-#if RNLLAMA_BUILD_FROM_SOURCE
-#import "rn-llama.h"
-#else
-#import <rnllama/rn-llama.h>
-#endif
-#endif
-
 #import <React/RCTEventEmitter.h>
 #import <React/RCTBridgeModule.h>
 
package/ios/RNLlama.mm CHANGED
@@ -13,6 +13,16 @@ dispatch_queue_t llamaDQueue;
 
 RCT_EXPORT_MODULE()
 
+RCT_EXPORT_METHOD(toggleNativeLog:(BOOL)enabled) {
+  void (^onEmitLog)(NSString *level, NSString *text) = nil;
+  if (enabled) {
+    onEmitLog = ^(NSString *level, NSString *text) {
+      [self sendEventWithName:@"@RNLlama_onNativeLog" body:@{ @"level": level, @"text": text }];
+    };
+  }
+  [RNLlamaContext toggleNativeLog:enabled onEmitLog:onEmitLog];
+}
+
 RCT_EXPORT_METHOD(setContextLimit:(double)limit
                   withResolver:(RCTPromiseResolveBlock)resolve
                   withRejecter:(RCTPromiseRejectBlock)reject)
@@ -41,7 +51,7 @@ RCT_EXPORT_METHOD(initContext:(double)contextId
   }
 
   if (llamaDQueue == nil) {
-    llamaDQueue = dispatch_queue_create("com.rnllama", DISPATCH_QUEUE_SERIAL);
+    llamaDQueue = dispatch_queue_create("com.rnllama", DISPATCH_QUEUE_SERIAL);
   }
 
   if (llamaContexts == nil) {
@@ -77,8 +87,9 @@ RCT_EXPORT_METHOD(initContext:(double)contextId
 }
 
 RCT_EXPORT_METHOD(getFormattedChat:(double)contextId
-                  withMessages:(NSArray *)messages
+                  withMessages:(NSString *)messages
                   withTemplate:(NSString *)chatTemplate
+                  withParams:(NSDictionary *)params
                   withResolver:(RCTPromiseResolveBlock)resolve
                   withRejecter:(RCTPromiseRejectBlock)reject)
 {
@@ -87,7 +98,19 @@ RCT_EXPORT_METHOD(getFormattedChat:(double)contextId
     reject(@"llama_error", @"Context not found", nil);
     return;
   }
-  resolve([context getFormattedChat:messages withTemplate:chatTemplate]);
+  try {
+    if ([params[@"jinja"] boolValue]) {
+      NSString *jsonSchema = params[@"json_schema"];
+      NSString *tools = params[@"tools"];
+      NSString *parallelToolCalls = params[@"parallel_tool_calls"];
+      NSString *toolChoice = params[@"tool_choice"];\
+      resolve([context getFormattedChatWithJinja:messages withChatTemplate:chatTemplate withJsonSchema:jsonSchema withTools:tools withParallelToolCalls:parallelToolCalls withToolChoice:toolChoice]);
+    } else {
+      resolve([context getFormattedChat:messages withChatTemplate:chatTemplate]);
+    }
+  } catch (const std::exception& e) { // catch cpp exceptions
+    reject(@"llama_error", [NSString stringWithUTF8String:e.what()], nil);
+  }
 }
 
 RCT_EXPORT_METHOD(loadSession:(double)contextId
@@ -146,6 +169,7 @@ RCT_EXPORT_METHOD(saveSession:(double)contextId
   return @[
     @"@RNLlama_onInitContextProgress",
     @"@RNLlama_onToken",
+    @"@RNLlama_onNativeLog",
   ];
 }
 
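The hunks above wire a native log bridge into the module: toggleNativeLog installs a llama_log_set callback and forwards each log line to JS through the new @RNLlama_onNativeLog event. A minimal consumption sketch, assuming the JS wrapper re-exports toggleNativeLog and addNativeLogListener the way upstream llama.rn does (the import names are assumptions; check this release's src/index.ts):

import { toggleNativeLog, addNativeLogListener } from 'cui-llama.rn'

function enableNativeLogging() {
  // Forward llama.cpp logs to JS; the native side registers llama_log_set and
  // emits '@RNLlama_onNativeLog' events (see the RCT_EXPORT_METHOD above).
  toggleNativeLog(true)

  const sub = addNativeLogListener((level, text) => {
    // level is mapped to 'error' | 'info' | 'warn' in RNLlamaContext.mm
    console.log(`[llama:${level}] ${text}`)
  })

  // Returned cleanup function: stop listening and disable the bridge again.
  return () => {
    sub.remove()
    toggleNativeLog(false)
  }
}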
package/ios/RNLlamaContext.h CHANGED
@@ -4,11 +4,13 @@
 #import "llama-impl.h"
 #import "ggml.h"
 #import "rn-llama.h"
+#import "json-schema-to-grammar.h"
 #else
 #import <rnllama/llama.h>
 #import <rnllama/llama-impl.h>
 #import <rnllama/ggml.h>
 #import <rnllama/rn-llama.h>
+#import <rnllama/json-schema-to-grammar.h>
 #endif
 #endif
 
@@ -23,6 +25,7 @@
   rnllama::llama_rn_context * llama;
 }
 
++ (void)toggleNativeLog:(BOOL)enabled onEmitLog:(void (^)(NSString *level, NSString *text))onEmitLog;
 + (NSDictionary *)modelInfo:(NSString *)path skip:(NSArray *)skip;
 + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsigned int progress))onProgress;
 - (void)interruptLoad;
@@ -36,7 +39,13 @@
 - (NSArray *)tokenize:(NSString *)text;
 - (NSString *)detokenize:(NSArray *)tokens;
 - (NSDictionary *)embedding:(NSString *)text params:(NSDictionary *)params;
-- (NSString *)getFormattedChat:(NSArray *)messages withTemplate:(NSString *)chatTemplate;
+- (NSDictionary *)getFormattedChatWithJinja:(NSString *)messages
+                           withChatTemplate:(NSString *)chatTemplate
+                             withJsonSchema:(NSString *)jsonSchema
+                                  withTools:(NSString *)tools
+                      withParallelToolCalls:(BOOL)parallelToolCalls
+                             withToolChoice:(NSString *)toolChoice;
+- (NSString *)getFormattedChat:(NSString *)messages withChatTemplate:(NSString *)chatTemplate;
 - (NSDictionary *)loadSession:(NSString *)path;
 - (int)saveSession:(NSString *)path size:(int)size;
 - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr;
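The header now declares a Jinja-aware formatter that returns a dictionary (prompt, chat_format, grammar, trigger and stop metadata) instead of a flat string. A hypothetical TypeScript call shape inferred from those keys; initLlama, the option names, and the result fields follow upstream llama.rn and should be verified against src/index.ts and src/NativeRNLlama.ts of this release:

import { initLlama } from 'cui-llama.rn'

async function formatWithJinja(modelPath: string) {
  // `initLlama` is the package's context factory (as in upstream llama.rn).
  const context = await initLlama({ model: modelPath })

  const formatted = await context.getFormattedChat(
    [
      { role: 'system', content: 'You are a helpful assistant.' },
      { role: 'user', content: 'What is the weather in Tokyo?' },
    ],
    null, // fall back to the model's built-in template
    { jinja: true, tool_choice: 'auto', parallel_tool_calls: false },
  )

  if (typeof formatted !== 'string') {
    // Jinja path: a prompt plus grammar/trigger info for constrained tool calling
    console.log(formatted.prompt, formatted.chat_format, formatted.grammar_lazy)
  }
  return formatted
}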
package/ios/RNLlamaContext.mm CHANGED
@@ -3,6 +3,33 @@
 
 @implementation RNLlamaContext
 
++ (void)toggleNativeLog:(BOOL)enabled onEmitLog:(void (^)(NSString *level, NSString *text))onEmitLog {
+  if (enabled) {
+    void (^copiedBlock)(NSString *, NSString *) = [onEmitLog copy];
+    llama_log_set([](lm_ggml_log_level level, const char * text, void * data) {
+      llama_log_callback_default(level, text, data);
+      NSString *levelStr = @"";
+      if (level == LM_GGML_LOG_LEVEL_ERROR) {
+        levelStr = @"error";
+      } else if (level == LM_GGML_LOG_LEVEL_INFO) {
+        levelStr = @"info";
+      } else if (level == LM_GGML_LOG_LEVEL_WARN) {
+        levelStr = @"warn";
+      }
+
+      NSString *textStr = [NSString stringWithUTF8String:text];
+      // NOTE: Convert to UTF-8 string may fail
+      if (!textStr) {
+        return;
+      }
+      void (^block)(NSString *, NSString *) = (__bridge void (^)(NSString *, NSString *))(data);
+      block(levelStr, textStr);
+    }, copiedBlock);
+  } else {
+    llama_log_set(llama_log_callback_default, nullptr);
+  }
+}
+
 + (NSDictionary *)modelInfo:(NSString *)path skip:(NSArray *)skip {
   struct lm_gguf_init_params params = {
     /*.no_alloc = */ false,
@@ -57,42 +84,83 @@
   if (isAsset) path = [[NSBundle mainBundle] pathForResource:modelPath ofType:nil];
   defaultParams.model = [path UTF8String];
 
+  NSString *chatTemplate = params[@"chat_template"];
+  if (chatTemplate) {
+    defaultParams.chat_template = [chatTemplate UTF8String];
+    NSLog(@"chatTemplate: %@", chatTemplate);
+  }
+
+  NSString *reasoningFormat = params[@"reasoning_format"];
+  if (reasoningFormat && [reasoningFormat isEqualToString:@"deepseek"]) {
+    defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+  } else {
+    defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
+  }
+
   if (params[@"n_ctx"]) defaultParams.n_ctx = [params[@"n_ctx"] intValue];
   if (params[@"use_mlock"]) defaultParams.use_mlock = [params[@"use_mlock"]boolValue];
 
+  BOOL skipGpuDevices = params[@"no_gpu_devices"] && [params[@"no_gpu_devices"] boolValue];
+
   BOOL isMetalEnabled = false;
   NSString *reasonNoMetal = @"";
   defaultParams.n_gpu_layers = 0;
-  if (params[@"n_gpu_layers"] && [params[@"n_gpu_layers"] intValue] > 0) {
 #ifdef LM_GGML_USE_METAL
-    // Check ggml-metal availability
-    NSError * error = nil;
-    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-    id<MTLLibrary> library = [device
-      newLibraryWithSource:@"#include <metal_stdlib>\n"
-                            "using namespace metal;"
-                            "kernel void test() { simd_sum(0); }"
-      options:nil
-      error:&error
-    ];
-    if (error) {
+  // Check ggml-metal availability
+  NSError * error = nil;
+  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+  id<MTLLibrary> library = [device
+    newLibraryWithSource:@"#include <metal_stdlib>\n"
+                          "using namespace metal;"
+                          "typedef matrix<bfloat, 4, 4> bfloat4x4;"
+                          "kernel void test() { simd_sum(0); }"
+    options:nil
+    error:&error
+  ];
+  if (error) {
+    reasonNoMetal = [error localizedDescription];
+    skipGpuDevices = true;
+  } else {
+    id<MTLFunction> kernel = [library newFunctionWithName:@"test"];
+    id<MTLComputePipelineState> pipeline = [device newComputePipelineStateWithFunction:kernel error:&error];
+    if (pipeline == nil) {
       reasonNoMetal = [error localizedDescription];
+      skipGpuDevices = true;
    } else {
-      id<MTLFunction> kernel = [library newFunctionWithName:@"test"];
-      id<MTLComputePipelineState> pipeline = [device newComputePipelineStateWithFunction:kernel error:&error];
-      if (pipeline == nil) {
-        reasonNoMetal = [error localizedDescription];
-      } else {
-        defaultParams.n_gpu_layers = [params[@"n_gpu_layers"] intValue];
-        isMetalEnabled = true;
-      }
+#if TARGET_OS_SIMULATOR
+      // Use the backend, but no layers because not supported fully on simulator
+      defaultParams.n_gpu_layers = 0;
+      isMetalEnabled = true;
+#else
+      defaultParams.n_gpu_layers = [params[@"n_gpu_layers"] intValue];
+      isMetalEnabled = true;
+#endif
    }
-    device = nil;
+  }
+  device = nil;
 #else
-    reasonNoMetal = @"Metal is not enabled in this build";
-    isMetalEnabled = false;
+  reasonNoMetal = @"Metal is not enabled in this build";
+  isMetalEnabled = false;
 #endif
+
+  if (skipGpuDevices) {
+    std::vector<lm_ggml_backend_dev_t> cpu_devs;
+    for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
+      lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
+      switch (lm_ggml_backend_dev_type(dev)) {
+        case LM_GGML_BACKEND_DEVICE_TYPE_CPU:
+        case LM_GGML_BACKEND_DEVICE_TYPE_ACCEL:
+          cpu_devs.push_back(dev);
+          break;
+        case LM_GGML_BACKEND_DEVICE_TYPE_GPU:
+          break;
+      }
+    }
+    if (cpu_devs.size() > 0) {
+      defaultParams.devices = cpu_devs;
+    }
   }
+
   if (params[@"n_batch"]) defaultParams.n_batch = [params[@"n_batch"] intValue];
   if (params[@"n_ubatch"]) defaultParams.n_ubatch = [params[@"n_ubatch"] intValue];
   if (params[@"use_mmap"]) defaultParams.use_mmap = [params[@"use_mmap"] boolValue];
@@ -125,7 +193,6 @@
   const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads);
   defaultParams.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
 
-
   RNLlamaContext *context = [[RNLlamaContext alloc] init];
   context->llama = new rnllama::llama_rn_context();
   context->llama->is_load_interrupted = false;
@@ -218,13 +285,48 @@
     [meta setValue:valStr forKey:keyStr];
   }
 
+  auto template_tool_use = llama->templates.template_tool_use.get();
+  NSDictionary *tool_use_caps_dir = nil;
+  if (template_tool_use) {
+    auto tool_use_caps = template_tool_use->original_caps();
+    tool_use_caps_dir = @{
+      @"tools": @(tool_use_caps.supports_tools),
+      @"toolCalls": @(tool_use_caps.supports_tool_calls),
+      @"toolResponses": @(tool_use_caps.supports_tool_responses),
+      @"systemRole": @(tool_use_caps.supports_system_role),
+      @"parallelToolCalls": @(tool_use_caps.supports_parallel_tool_calls),
+      @"toolCallId": @(tool_use_caps.supports_tool_call_id)
+    };
+  }
+
+  auto default_tmpl = llama->templates.template_default.get();
+  auto default_tmpl_caps = default_tmpl->original_caps();
+
   return @{
     @"desc": [NSString stringWithUTF8String:desc],
     @"size": @(llama_model_size(llama->model)),
     @"nEmbd": @(llama_model_n_embd(llama->model)),
     @"nParams": @(llama_model_n_params(llama->model)),
-    @"isChatTemplateSupported": @(llama->validateModelChatTemplate()),
-    @"metadata": meta
+    @"chatTemplates": @{
+      @"llamaChat": @(llama->validateModelChatTemplate(false, nullptr)),
+      @"minja": @{
+        @"default": @(llama->validateModelChatTemplate(true, nullptr)),
+        @"defaultCaps": @{
+          @"tools": @(default_tmpl_caps.supports_tools),
+          @"toolCalls": @(default_tmpl_caps.supports_tool_calls),
+          @"toolResponses": @(default_tmpl_caps.supports_tool_responses),
+          @"systemRole": @(default_tmpl_caps.supports_system_role),
+          @"parallelToolCalls": @(default_tmpl_caps.supports_parallel_tool_calls),
+          @"toolCallId": @(default_tmpl_caps.supports_tool_call_id)
+        },
+        @"toolUse": @(llama->validateModelChatTemplate(true, "tool_use")),
+        @"toolUseCaps": tool_use_caps_dir ?: @{}
+      }
+    },
+    @"metadata": meta,
+
+    // deprecated
+    @"isChatTemplateSupported": @(llama->validateModelChatTemplate(false, nullptr))
   };
 }
 
@@ -236,18 +338,56 @@
   return llama->is_predicting;
 }
 
-- (NSString *)getFormattedChat:(NSArray *)messages withTemplate:(NSString *)chatTemplate {
-  std::vector<common_chat_msg> chat;
+- (NSDictionary *)getFormattedChatWithJinja:(NSString *)messages
+                           withChatTemplate:(NSString *)chatTemplate
+                             withJsonSchema:(NSString *)jsonSchema
+                                  withTools:(NSString *)tools
+                      withParallelToolCalls:(BOOL)parallelToolCalls
+                             withToolChoice:(NSString *)toolChoice
+{
+  auto tmpl_str = chatTemplate == nil ? "" : [chatTemplate UTF8String];
+
+  NSMutableDictionary *result = [[NSMutableDictionary alloc] init];
+  auto chatParams = llama->getFormattedChatWithJinja(
+    [messages UTF8String],
+    tmpl_str,
+    jsonSchema == nil ? "" : [jsonSchema UTF8String],
+    tools == nil ? "" : [tools UTF8String],
+    parallelToolCalls,
+    toolChoice == nil ? "" : [toolChoice UTF8String]
+  );
+  result[@"prompt"] = [NSString stringWithUTF8String:chatParams.prompt.get<std::string>().c_str()];
+  result[@"chat_format"] = @(static_cast<int>(chatParams.format));
+  result[@"grammar"] = [NSString stringWithUTF8String:chatParams.grammar.c_str()];
+  result[@"grammar_lazy"] = @(chatParams.grammar_lazy);
+  NSMutableArray *grammar_triggers = [[NSMutableArray alloc] init];
+  for (const auto & trigger : chatParams.grammar_triggers) {
+    [grammar_triggers addObject:@{
+      @"word": [NSString stringWithUTF8String:trigger.word.c_str()],
+      @"at_start": @(trigger.at_start),
+    }];
+  }
+  result[@"grammar_triggers"] = grammar_triggers;
+  NSMutableArray *preserved_tokens = [[NSMutableArray alloc] init];
+  for (const auto & token : chatParams.preserved_tokens) {
+    [preserved_tokens addObject:[NSString stringWithUTF8String:token.c_str()]];
+  }
+  result[@"preserved_tokens"] = preserved_tokens;
+  NSMutableArray *additional_stops = [[NSMutableArray alloc] init];
+  for (const auto & stop : chatParams.additional_stops) {
+    [additional_stops addObject:[NSString stringWithUTF8String:stop.c_str()]];
+  }
+  result[@"additional_stops"] = additional_stops;
 
-  for (NSDictionary *msg in messages) {
-    std::string role = [[msg objectForKey:@"role"] UTF8String];
-    std::string content = [[msg objectForKey:@"content"] UTF8String];
-    chat.push_back({ role, content });
-  }
+  return result;
+}
 
-  auto tmpl = chatTemplate == nil ? "" : [chatTemplate UTF8String];
-  auto formatted_chat = common_chat_apply_template(llama->model, tmpl, chat, true);
-  return [NSString stringWithUTF8String:formatted_chat.c_str()];
+- (NSString *)getFormattedChat:(NSString *)messages withChatTemplate:(NSString *)chatTemplate {
+  auto tmpl_str = chatTemplate == nil ? "" : [chatTemplate UTF8String];
+  return [NSString stringWithUTF8String:llama->getFormattedChat(
+    [messages UTF8String],
+    tmpl_str
+  ).c_str()];;
 }
 
 - (NSArray *)tokenProbsToDict:(std::vector<rnllama::completion_token_output>)probs {
@@ -321,6 +461,8 @@
   if (params[@"dry_allowed_length"]) sparams.dry_allowed_length = [params[@"dry_allowed_length"] intValue];
   if (params[@"dry_penalty_last_n"]) sparams.dry_penalty_last_n = [params[@"dry_penalty_last_n"] intValue];
 
+  if (params[@"top_n_sigma"]) sparams.top_n_sigma = [params[@"top_n_sigma"] doubleValue];
+
   // dry break seq
   if (params[@"dry_sequence_breakers"] && [params[@"dry_sequence_breakers"] isKindOfClass:[NSArray class]]) {
     NSArray *dry_sequence_breakers = params[@"dry_sequence_breakers"];
@@ -333,6 +475,45 @@
     sparams.grammar = [params[@"grammar"] UTF8String];
   }
 
+  if (params[@"json_schema"] && !params[@"grammar"]) {
+    sparams.grammar = json_schema_to_grammar(json::parse([params[@"json_schema"] UTF8String]));
+  }
+
+  if (params[@"grammar_lazy"]) {
+    sparams.grammar_lazy = [params[@"grammar_lazy"] boolValue];
+  }
+
+  if (params[@"grammar_triggers"] && [params[@"grammar_triggers"] isKindOfClass:[NSArray class]]) {
+    NSArray *grammar_triggers = params[@"grammar_triggers"];
+    for (NSDictionary *grammar_trigger in grammar_triggers) {
+      common_grammar_trigger trigger;
+      trigger.word = [grammar_trigger[@"word"] UTF8String];
+      trigger.at_start = [grammar_trigger[@"at_start"] boolValue];
+
+      auto ids = common_tokenize(llama->ctx, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+      if (ids.size() == 1) {
+        // LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
+        sparams.grammar_trigger_tokens.push_back(ids[0]);
+        sparams.preserved_tokens.insert(ids[0]);
+        continue;
+      }
+      // LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
+      sparams.grammar_trigger_words.push_back(trigger);
+    }
+  }
+
+  if (params[@"preserved_tokens"] && [params[@"preserved_tokens"] isKindOfClass:[NSArray class]]) {
+    NSArray *preserved_tokens = params[@"preserved_tokens"];
+    for (NSString *token in preserved_tokens) {
+      auto ids = common_tokenize(llama->ctx, [token UTF8String], /* add_special= */ false, /* parse_special= */ true);
+      if (ids.size() == 1) {
+        sparams.preserved_tokens.insert(ids[0]);
+      } else {
+        // LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", [token UTF8String]);
+      }
+    }
+  }
+
   llama->params.antiprompt.clear();
   if (params[@"stop"]) {
     NSArray *stop = params[@"stop"];
@@ -434,29 +615,60 @@
   llama->is_predicting = false;
 
   const auto timings = llama_perf_context(llama->ctx);
-  return @{
-    @"text": [NSString stringWithUTF8String:llama->generated_text.c_str()],
-    @"completion_probabilities": [self tokenProbsToDict:llama->generated_token_probs],
-    @"tokens_predicted": @(llama->num_tokens_predicted),
-    @"tokens_evaluated": @(llama->num_prompt_tokens),
-    @"truncated": @(llama->truncated),
-    @"stopped_eos": @(llama->stopped_eos),
-    @"stopped_word": @(llama->stopped_word),
-    @"stopped_limit": @(llama->stopped_limit),
-    @"stopping_word": [NSString stringWithUTF8String:llama->stopping_word.c_str()],
-    @"tokens_cached": @(llama->n_past),
-    @"timings": @{
-      @"prompt_n": @(timings.n_p_eval),
-      @"prompt_ms": @(timings.t_p_eval_ms),
-      @"prompt_per_token_ms": @(timings.t_p_eval_ms / timings.n_p_eval),
-      @"prompt_per_second": @(1e3 / timings.t_p_eval_ms * timings.n_p_eval),
-
-      @"predicted_n": @(timings.n_eval),
-      @"predicted_ms": @(timings.t_eval_ms),
-      @"predicted_per_token_ms": @(timings.t_eval_ms / timings.n_eval),
-      @"predicted_per_second": @(1e3 / timings.t_eval_ms * timings.n_eval),
+
+  NSMutableArray *toolCalls = nil;
+  NSString *reasoningContent = nil;
+  NSString *content = nil;
+  if (!llama->is_interrupted) {
+    try {
+      auto chat_format = params[@"chat_format"] ? [params[@"chat_format"] intValue] : COMMON_CHAT_FORMAT_CONTENT_ONLY;
+      common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
+      if (!message.reasoning_content.empty()) {
+        reasoningContent = [NSString stringWithUTF8String:message.reasoning_content.c_str()];
+      }
+      content = [NSString stringWithUTF8String:message.content.c_str()];
+      toolCalls = [[NSMutableArray alloc] init];
+      for (const auto &tc : message.tool_calls) {
+        [toolCalls addObject:@{
+          @"type": @"function",
+          @"function": @{
+            @"name": [NSString stringWithUTF8String:tc.name.c_str()],
+            @"arguments": [NSString stringWithUTF8String:tc.arguments.c_str()],
+          },
+          @"id": tc.id.empty() ? [NSNull null] : [NSString stringWithUTF8String:tc.id.c_str()],
+        }];
+      }
+    } catch (const std::exception &e) {
+      // NSLog(@"Error parsing tool calls: %s", e.what());
     }
+  }
+
+  NSMutableDictionary *result = [[NSMutableDictionary alloc] init];
+  result[@"text"] = [NSString stringWithUTF8String:llama->generated_text.c_str()]; // Original text
+  if (content) result[@"content"] = content;
+  if (reasoningContent) result[@"reasoning_content"] = reasoningContent;
+  if (toolCalls && toolCalls.count > 0) result[@"tool_calls"] = toolCalls;
+  result[@"completion_probabilities"] = [self tokenProbsToDict:llama->generated_token_probs];
+  result[@"tokens_predicted"] = @(llama->num_tokens_predicted);
+  result[@"tokens_evaluated"] = @(llama->num_prompt_tokens);
+  result[@"truncated"] = @(llama->truncated);
+  result[@"stopped_eos"] = @(llama->stopped_eos);
+  result[@"stopped_word"] = @(llama->stopped_word);
+  result[@"stopped_limit"] = @(llama->stopped_limit);
+  result[@"stopping_word"] = [NSString stringWithUTF8String:llama->stopping_word.c_str()];
+  result[@"tokens_cached"] = @(llama->n_past);
+  result[@"timings"] = @{
+    @"prompt_n": @(timings.n_p_eval),
+    @"prompt_ms": @(timings.t_p_eval_ms),
+    @"prompt_per_token_ms": @(timings.t_p_eval_ms / timings.n_p_eval),
+    @"prompt_per_second": @(1e3 / timings.t_p_eval_ms * timings.n_p_eval),
+    @"predicted_n": @(timings.n_eval),
+    @"predicted_n": @(timings.n_eval),
+    @"predicted_ms": @(timings.t_eval_ms),
+    @"predicted_per_token_ms": @(timings.t_eval_ms / timings.n_eval),
+    @"predicted_per_second": @(1e3 / timings.t_eval_ms * timings.n_eval),
   };
+  return result;
 }
 
 - (void)stopCompletion {
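With this change the completion result keeps the raw generation in text and, when a chat format is active, adds the parsed content, reasoning_content, and OpenAI-style tool_calls. A sketch of consuming that result; the LlamaContext type and the completion parameter names follow upstream llama.rn and are assumptions here, while the result field names mirror the native dictionary above:

import type { LlamaContext } from 'cui-llama.rn' // type name assumed from upstream llama.rn

async function runToolCall(context: LlamaContext) {
  const tools = [
    {
      type: 'function',
      function: {
        name: 'get_weather',
        description: 'Get the current weather for a city',
        parameters: {
          type: 'object',
          properties: { city: { type: 'string' } },
          required: ['city'],
        },
      },
    },
  ]

  const result = await context.completion({
    messages: [{ role: 'user', content: 'What is the weather in Tokyo?' }],
    jinja: true,
    tools,
  })

  console.log(result.text) // original, unparsed output
  if (result.reasoning_content) console.log('reasoning:', result.reasoning_content)
  for (const call of result.tool_calls ?? []) {
    // `arguments` is a JSON string, as produced by common_chat_parse
    console.log('tool requested:', call.function.name, JSON.parse(call.function.arguments))
  }
}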
package/jest/mock.js CHANGED
@@ -18,12 +18,31 @@ if (!NativeModules.RNLlama) {
         'general.architecture': 'llama',
         'llama.embedding_length': 768,
       },
+      chatTemplates: {
+        llamaChat: true,
+        minja: {
+          default: true,
+          defaultCaps: {
+            parallelToolCalls: false,
+            systemRole: true,
+            toolCallId: false,
+            toolCalls: false,
+            toolResponses: false,
+            tools: false,
+          },
+          toolUse: false,
+        },
+      },
     },
   }),
 ),
 
-  // TODO: Use jinja parser
-  getFormattedChat: jest.fn(() => ''),
+  getFormattedChat: jest.fn(async (messages, chatTemplate, options) => {
+    if (options.jinja) {
+      return { prompt: '', chat_format: 0 }
+    }
+    return ''
+  }),
 
   completion: jest.fn(async (contextId, jobId) => {
     const testResult = {
@@ -1 +1 @@
-{"version":3,"names":["_reactNative","require","_default","TurboModuleRegistry","get","exports","default"],"sourceRoot":"..\\..\\src","sources":["NativeRNLlama.ts"],"mappings":";;;;;;AACA,IAAAA,YAAA,GAAAC,OAAA;AAAkD,IAAAC,QAAA,GAsTnCC,gCAAmB,CAACC,GAAG,CAAO,SAAS,CAAC;AAAAC,OAAA,CAAAC,OAAA,GAAAJ,QAAA"}
+{"version":3,"names":["_reactNative","require","_default","TurboModuleRegistry","get","exports","default"],"sourceRoot":"..\\..\\src","sources":["NativeRNLlama.ts"],"mappings":";;;;;;AACA,IAAAA,YAAA,GAAAC,OAAA;AAAkD,IAAAC,QAAA,GAsanCC,gCAAmB,CAACC,GAAG,CAAO,SAAS,CAAC;AAAAC,OAAA,CAAAC,OAAA,GAAAJ,QAAA"}
package/lib/commonjs/grammar.js CHANGED
@@ -6,6 +6,9 @@ Object.defineProperty(exports, "__esModule", {
 exports.convertJsonSchemaToGrammar = exports.SchemaGrammarConverterBuiltinRule = exports.SchemaGrammarConverter = void 0;
 /* eslint-disable no-restricted-syntax */
 /* eslint-disable no-underscore-dangle */
+
+// NOTE: Deprecated, please use tools or response_format with json_schema instead
+
 const SPACE_RULE = '" "?';
 function buildRepetition(itemRule, minItems, maxItems) {
   let opts = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : {};
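The deprecation note steers callers toward tools or response_format with a JSON schema, which the native side now converts to a grammar via json_schema_to_grammar. A hypothetical call using that path; the exact response_format shape is an assumption and should be checked against src/index.ts of this release:

import type { LlamaContext } from 'cui-llama.rn' // type name assumed from upstream llama.rn

async function jsonOnlyCompletion(context: LlamaContext) {
  const res = await context.completion({
    messages: [{ role: 'user', content: 'List three fruits as JSON.' }],
    // Assumed OpenAI-style shape; replaces the deprecated convertJsonSchemaToGrammar helper.
    response_format: {
      type: 'json_schema',
      json_schema: {
        schema: {
          type: 'array',
          items: { type: 'string' },
          minItems: 3,
          maxItems: 3,
        },
      },
    },
  })
  // The grammar constrains generation, so the raw text should parse as JSON.
  return JSON.parse(res.text)
}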