cui-llama.rn 0.2.0

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (76)
  1. package/LICENSE +20 -0
  2. package/README.md +330 -0
  3. package/android/build.gradle +107 -0
  4. package/android/gradle.properties +5 -0
  5. package/android/src/main/AndroidManifest.xml +4 -0
  6. package/android/src/main/CMakeLists.txt +69 -0
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +353 -0
  8. package/android/src/main/java/com/rnllama/RNLlama.java +446 -0
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -0
  10. package/android/src/main/jni.cpp +635 -0
  11. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +94 -0
  12. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +95 -0
  13. package/cpp/README.md +4 -0
  14. package/cpp/common.cpp +3237 -0
  15. package/cpp/common.h +467 -0
  16. package/cpp/ggml-aarch64.c +2193 -0
  17. package/cpp/ggml-aarch64.h +39 -0
  18. package/cpp/ggml-alloc.c +1041 -0
  19. package/cpp/ggml-alloc.h +76 -0
  20. package/cpp/ggml-backend-impl.h +153 -0
  21. package/cpp/ggml-backend.c +2225 -0
  22. package/cpp/ggml-backend.h +236 -0
  23. package/cpp/ggml-common.h +1829 -0
  24. package/cpp/ggml-impl.h +655 -0
  25. package/cpp/ggml-metal.h +65 -0
  26. package/cpp/ggml-metal.m +3273 -0
  27. package/cpp/ggml-quants.c +15022 -0
  28. package/cpp/ggml-quants.h +132 -0
  29. package/cpp/ggml.c +22034 -0
  30. package/cpp/ggml.h +2444 -0
  31. package/cpp/grammar-parser.cpp +536 -0
  32. package/cpp/grammar-parser.h +29 -0
  33. package/cpp/json-schema-to-grammar.cpp +1045 -0
  34. package/cpp/json-schema-to-grammar.h +8 -0
  35. package/cpp/json.hpp +24766 -0
  36. package/cpp/llama.cpp +21789 -0
  37. package/cpp/llama.h +1201 -0
  38. package/cpp/log.h +737 -0
  39. package/cpp/rn-llama.hpp +630 -0
  40. package/cpp/sampling.cpp +460 -0
  41. package/cpp/sampling.h +160 -0
  42. package/cpp/sgemm.cpp +1027 -0
  43. package/cpp/sgemm.h +14 -0
  44. package/cpp/unicode-data.cpp +7032 -0
  45. package/cpp/unicode-data.h +20 -0
  46. package/cpp/unicode.cpp +812 -0
  47. package/cpp/unicode.h +64 -0
  48. package/ios/RNLlama.h +11 -0
  49. package/ios/RNLlama.mm +302 -0
  50. package/ios/RNLlama.xcodeproj/project.pbxproj +278 -0
  51. package/ios/RNLlamaContext.h +39 -0
  52. package/ios/RNLlamaContext.mm +426 -0
  53. package/jest/mock.js +169 -0
  54. package/lib/commonjs/NativeRNLlama.js +10 -0
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -0
  56. package/lib/commonjs/grammar.js +574 -0
  57. package/lib/commonjs/grammar.js.map +1 -0
  58. package/lib/commonjs/index.js +151 -0
  59. package/lib/commonjs/index.js.map +1 -0
  60. package/lib/module/NativeRNLlama.js +3 -0
  61. package/lib/module/NativeRNLlama.js.map +1 -0
  62. package/lib/module/grammar.js +566 -0
  63. package/lib/module/grammar.js.map +1 -0
  64. package/lib/module/index.js +129 -0
  65. package/lib/module/index.js.map +1 -0
  66. package/lib/typescript/NativeRNLlama.d.ts +107 -0
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -0
  68. package/lib/typescript/grammar.d.ts +38 -0
  69. package/lib/typescript/grammar.d.ts.map +1 -0
  70. package/lib/typescript/index.d.ts +46 -0
  71. package/lib/typescript/index.d.ts.map +1 -0
  72. package/llama-rn.podspec +56 -0
  73. package/package.json +230 -0
  74. package/src/NativeRNLlama.ts +132 -0
  75. package/src/grammar.ts +849 -0
  76. package/src/index.ts +182 -0
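Taken together, the files above form a React Native binding around llama.cpp: the bundled C/C++ sources under package/cpp, an Android JNI layer (jni.cpp, LlamaContext.java), an iOS Objective-C++ layer (RNLlamaContext.mm), and a TypeScript wrapper (src/index.ts, src/NativeRNLlama.ts). As a rough orientation before the file diffs below, here is a minimal consumption sketch; the initLlama/completion/release names and the import specifier are assumptions based on the upstream llama.rn API, not taken from this diff.

// Minimal usage sketch. Assumption: this fork keeps the upstream llama.rn
// JavaScript API (initLlama, context.completion, context.release) exported
// from package/src/index.ts; the parameter names match the keys read by the
// native layers shown in this diff.
import { initLlama } from 'cui-llama.rn'

export async function demo(): Promise<void> {
  const context = await initLlama({
    model: '/path/to/model.gguf', // or pass is_model_asset: true for a bundled model
    n_ctx: 2048,
    n_gpu_layers: 1, // requests Metal offload on iOS when the build supports it
  })

  const { text, timings } = await context.completion(
    {
      prompt: 'Hello, how are you?',
      n_predict: 64,
      temperature: 0.7,
      stop: ['</s>'],
    },
    (data) => {
      // Streaming callback: data.token is the newly generated piece of text.
      console.log(data.token)
    },
  )

  console.log(text, timings)
  await context.release()
}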
package/ios/RNLlamaContext.mm ADDED
@@ -0,0 +1,426 @@
+ #import "RNLlamaContext.h"
+ #import <Metal/Metal.h>
+
+ @implementation RNLlamaContext
+
+ + (instancetype)initWithParams:(NSDictionary *)params {
+     // llama_backend_init(false);
+     gpt_params defaultParams;
+
+     NSString *modelPath = params[@"model"];
+     BOOL isAsset = [params[@"is_model_asset"] boolValue];
+     NSString *path = modelPath;
+     if (isAsset) path = [[NSBundle mainBundle] pathForResource:modelPath ofType:nil];
+     defaultParams.model = [path UTF8String];
+
+     if (params[@"embedding"] && [params[@"embedding"] boolValue]) {
+         defaultParams.embedding = true;
+     }
+
+     if (params[@"n_ctx"]) defaultParams.n_ctx = [params[@"n_ctx"] intValue];
+     if (params[@"use_mlock"]) defaultParams.use_mlock = [params[@"use_mlock"] boolValue];
+
+     BOOL isMetalEnabled = false;
+     NSString *reasonNoMetal = @"";
+     defaultParams.n_gpu_layers = 0;
+     if (params[@"n_gpu_layers"] && [params[@"n_gpu_layers"] intValue] > 0) {
+ #ifdef LM_GGML_USE_METAL
+         // Check ggml-metal availability
+         NSError * error = nil;
+         id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+         id<MTLLibrary> library = [device
+             newLibraryWithSource:@"#include <metal_stdlib>\n"
+                 "using namespace metal;"
+                 "kernel void test() { simd_sum(0); }"
+             options:nil
+             error:&error
+         ];
+         if (error) {
+             reasonNoMetal = [error localizedDescription];
+         } else {
+             id<MTLFunction> kernel = [library newFunctionWithName:@"test"];
+             id<MTLComputePipelineState> pipeline = [device newComputePipelineStateWithFunction:kernel error:&error];
+             if (pipeline == nil) {
+                 reasonNoMetal = [error localizedDescription];
+             } else {
+                 defaultParams.n_gpu_layers = [params[@"n_gpu_layers"] intValue];
+                 isMetalEnabled = true;
+             }
+         }
+         device = nil;
+ #else
+         reasonNoMetal = @"Metal is not enabled in this build";
+         isMetalEnabled = false;
+ #endif
+     }
+     if (params[@"n_batch"]) defaultParams.n_batch = [params[@"n_batch"] intValue];
+     if (params[@"use_mmap"]) defaultParams.use_mmap = [params[@"use_mmap"] boolValue];
+
+     if (params[@"lora"]) {
+         float lora_scaled = 1.0f;
+         if (params[@"lora_scaled"]) lora_scaled = [params[@"lora_scaled"] floatValue];
+         defaultParams.lora_adapter.push_back({[params[@"lora"] UTF8String], lora_scaled});
+         defaultParams.use_mmap = false;
+     }
+     if (params[@"lora_base"]) defaultParams.lora_base = [params[@"lora_base"] UTF8String];
+
+     if (params[@"rope_freq_base"]) defaultParams.rope_freq_base = [params[@"rope_freq_base"] floatValue];
+     if (params[@"rope_freq_scale"]) defaultParams.rope_freq_scale = [params[@"rope_freq_scale"] floatValue];
+
+     if (params[@"seed"]) defaultParams.seed = [params[@"seed"] intValue];
+
+     int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : 0;
+     const int maxThreads = (int) [[NSProcessInfo processInfo] processorCount];
+     // Use 2 threads by default on 4-core devices, 4 threads on more cores
+     const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads);
+     defaultParams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
+
+     RNLlamaContext *context = [[RNLlamaContext alloc] init];
+     if (context->llama == nullptr) {
+         context->llama = new rnllama::llama_rn_context();
+     }
+     context->is_model_loaded = context->llama->loadModel(defaultParams);
+     context->is_metal_enabled = isMetalEnabled;
+     context->reason_no_metal = reasonNoMetal;
+
+     int count = llama_model_meta_count(context->llama->model);
+     NSDictionary *meta = [[NSMutableDictionary alloc] init];
+     for (int i = 0; i < count; i++) {
+         char key[256];
+         llama_model_meta_key_by_index(context->llama->model, i, key, sizeof(key));
+         char val[256];
+         llama_model_meta_val_str_by_index(context->llama->model, i, val, sizeof(val));
+
+         NSString *keyStr = [NSString stringWithUTF8String:key];
+         NSString *valStr = [NSString stringWithUTF8String:val];
+         [meta setValue:valStr forKey:keyStr];
+     }
+     context->metadata = meta;
+
+     char desc[1024];
+     llama_model_desc(context->llama->model, desc, sizeof(desc));
+     context->model_desc = [NSString stringWithUTF8String:desc];
+     context->model_size = llama_model_size(context->llama->model);
+     context->model_n_params = llama_model_n_params(context->llama->model);
+
+     return context;
+ }
+
+ - (bool)isMetalEnabled {
+     return is_metal_enabled;
+ }
+
+ - (NSString *)reasonNoMetal {
+     return reason_no_metal;
+ }
+
+ - (NSDictionary *)metadata {
+     return metadata;
+ }
+
+ - (NSString *)modelDesc {
+     return model_desc;
+ }
+
+ - (uint64_t)modelSize {
+     return model_size;
+ }
+
+ - (uint64_t)modelNParams {
+     return model_n_params;
+ }
+
+ - (bool)isModelLoaded {
+     return is_model_loaded;
+ }
+
+ - (bool)isPredicting {
+     return llama->is_predicting;
+ }
+
+ - (NSArray *)tokenProbsToDict:(std::vector<rnllama::completion_token_output>)probs {
+     NSMutableArray *out = [[NSMutableArray alloc] init];
+     for (const auto &prob : probs)
+     {
+         NSMutableArray *probsForToken = [[NSMutableArray alloc] init];
+         for (const auto &p : prob.probs)
+         {
+             std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, p.tok);
+             [probsForToken addObject:@{
+                 @"tok_str": [NSString stringWithUTF8String:tokStr.c_str()],
+                 @"prob": [NSNumber numberWithDouble:p.prob]
+             }];
+         }
+         std::string tokStr = rnllama::tokens_to_output_formatted_string(llama->ctx, prob.tok);
+         [out addObject:@{
+             @"content": [NSString stringWithUTF8String:tokStr.c_str()],
+             @"probs": probsForToken
+         }];
+     }
+     return out;
+ }
+
+ - (NSDictionary *)completion:(NSDictionary *)params
+     onToken:(void (^)(NSMutableDictionary * tokenResult))onToken
+ {
+     llama->rewind();
+
+     llama_reset_timings(llama->ctx);
+
+     NSString *prompt = [params objectForKey:@"prompt"];
+
+     llama->params.prompt = [prompt UTF8String];
+     llama->params.seed = params[@"seed"] ? [params[@"seed"] intValue] : -1;
+
+     if (params[@"n_threads"]) {
+         int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : llama->params.n_threads;
+         const int maxThreads = (int) [[NSProcessInfo processInfo] processorCount];
+         // Use 2 threads by default on 4-core devices, 4 threads on more cores
+         const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads);
+         llama->params.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
+     }
+     if (params[@"n_predict"]) llama->params.n_predict = [params[@"n_predict"] intValue];
+
+     auto & sparams = llama->params.sparams;
+
+     if (params[@"temperature"]) sparams.temp = [params[@"temperature"] doubleValue];
+
+     if (params[@"n_probs"]) sparams.n_probs = [params[@"n_probs"] intValue];
+
+     if (params[@"penalty_last_n"]) sparams.penalty_last_n = [params[@"penalty_last_n"] intValue];
+     if (params[@"penalty_repeat"]) sparams.penalty_repeat = [params[@"penalty_repeat"] doubleValue];
+     if (params[@"penalty_freq"]) sparams.penalty_freq = [params[@"penalty_freq"] doubleValue];
+     if (params[@"penalty_present"]) sparams.penalty_present = [params[@"penalty_present"] doubleValue];
+
+     if (params[@"mirostat"]) sparams.mirostat = [params[@"mirostat"] intValue];
+     if (params[@"mirostat_tau"]) sparams.mirostat_tau = [params[@"mirostat_tau"] doubleValue];
+     if (params[@"mirostat_eta"]) sparams.mirostat_eta = [params[@"mirostat_eta"] doubleValue];
+     if (params[@"penalize_nl"]) sparams.penalize_nl = [params[@"penalize_nl"] boolValue];
+
+     if (params[@"top_k"]) sparams.top_k = [params[@"top_k"] intValue];
+     if (params[@"top_p"]) sparams.top_p = [params[@"top_p"] doubleValue];
+     if (params[@"min_p"]) sparams.min_p = [params[@"min_p"] doubleValue];
+     if (params[@"tfs_z"]) sparams.tfs_z = [params[@"tfs_z"] doubleValue];
+
+     if (params[@"typical_p"]) sparams.typical_p = [params[@"typical_p"] doubleValue];
+
+     if (params[@"grammar"]) {
+         sparams.grammar = [params[@"grammar"] UTF8String];
+     }
+
+     llama->params.antiprompt.clear();
+     if (params[@"stop"]) {
+         NSArray *stop = params[@"stop"];
+         for (NSString *s in stop) {
+             llama->params.antiprompt.push_back([s UTF8String]);
+         }
+     }
+
+     sparams.logit_bias.clear();
+     if (params[@"ignore_eos"] && [params[@"ignore_eos"] boolValue]) {
+         sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
+     }
+
+     if (params[@"logit_bias"] && [params[@"logit_bias"] isKindOfClass:[NSArray class]]) {
+         const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
+         NSArray *logit_bias = params[@"logit_bias"];
+         for (NSArray *el in logit_bias) {
+             if ([el isKindOfClass:[NSArray class]] && [el count] == 2) {
+                 llama_token tok = [el[0] intValue];
+                 if (tok >= 0 && tok < n_vocab) {
+                     if ([el[1] isKindOfClass:[NSNumber class]]) {
+                         sparams.logit_bias[tok] = [el[1] doubleValue];
+                     } else if ([el[1] isKindOfClass:[NSNumber class]] && ![el[1] boolValue]) {
+                         sparams.logit_bias[tok] = -INFINITY;
+                     }
+                 }
+             }
+         }
+     }
+
+     if (!llama->initSampling()) {
+         @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to initialize sampling" userInfo:nil];
+     }
+     llama->beginCompletion();
+     llama->loadPrompt();
+
+     size_t sent_count = 0;
+     size_t sent_token_probs_index = 0;
+
+     while (llama->has_next_token && !llama->is_interrupted) {
+         const rnllama::completion_token_output token_with_probs = llama->doCompletion();
+         if (token_with_probs.tok == -1 || llama->multibyte_pending > 0) {
+             continue;
+         }
+         const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
+
+         size_t pos = std::min(sent_count, llama->generated_text.size());
+
+         const std::string str_test = llama->generated_text.substr(pos);
+         bool is_stop_full = false;
+         size_t stop_pos =
+             llama->findStoppingStrings(str_test, token_text.size(), rnllama::STOP_FULL);
+         if (stop_pos != std::string::npos) {
+             is_stop_full = true;
+             llama->generated_text.erase(
+                 llama->generated_text.begin() + pos + stop_pos,
+                 llama->generated_text.end());
+             pos = std::min(sent_count, llama->generated_text.size());
+         } else {
+             is_stop_full = false;
+             stop_pos = llama->findStoppingStrings(str_test, token_text.size(),
+                 rnllama::STOP_PARTIAL);
+         }
+
+         if (
+             stop_pos == std::string::npos ||
+             // Send rest of the text if we are at the end of the generation
+             (!llama->has_next_token && !is_stop_full && stop_pos > 0)
+         ) {
+             const std::string to_send = llama->generated_text.substr(pos, std::string::npos);
+
+             sent_count += to_send.size();
+
+             std::vector<rnllama::completion_token_output> probs_output = {};
+
+             NSMutableDictionary *tokenResult = [[NSMutableDictionary alloc] init];
+             tokenResult[@"token"] = [NSString stringWithUTF8String:to_send.c_str()];
+
+             if (llama->params.sparams.n_probs > 0) {
+                 const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+                 size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
+                 size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
+                 if (probs_pos < probs_stop_pos) {
+                     probs_output = std::vector<rnllama::completion_token_output>(llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
+                 }
+                 sent_token_probs_index = probs_stop_pos;
+
+                 tokenResult[@"completion_probabilities"] = [self tokenProbsToDict:probs_output];
+             }
+
+             onToken(tokenResult);
+         }
+     }
+
+     llama_print_timings(llama->ctx);
+     llama->is_predicting = false;
+
+     const auto timings = llama_get_timings(llama->ctx);
+     return @{
+         @"text": [NSString stringWithUTF8String:llama->generated_text.c_str()],
+         @"completion_probabilities": [self tokenProbsToDict:llama->generated_token_probs],
+         @"tokens_predicted": @(llama->num_tokens_predicted),
+         @"tokens_evaluated": @(llama->num_prompt_tokens),
+         @"truncated": @(llama->truncated),
+         @"stopped_eos": @(llama->stopped_eos),
+         @"stopped_word": @(llama->stopped_word),
+         @"stopped_limit": @(llama->stopped_limit),
+         @"stopping_word": [NSString stringWithUTF8String:llama->stopping_word.c_str()],
+         @"tokens_cached": @(llama->n_past),
+         @"timings": @{
+             @"prompt_n": @(timings.n_p_eval),
+             @"prompt_ms": @(timings.t_p_eval_ms),
+             @"prompt_per_token_ms": @(timings.t_p_eval_ms / timings.n_p_eval),
+             @"prompt_per_second": @(1e3 / timings.t_p_eval_ms * timings.n_p_eval),
+
+             @"predicted_n": @(timings.n_eval),
+             @"predicted_ms": @(timings.t_eval_ms),
+             @"predicted_per_token_ms": @(timings.t_eval_ms / timings.n_eval),
+             @"predicted_per_second": @(1e3 / timings.t_eval_ms * timings.n_eval),
+         }
+     };
+ }
+
+ - (void)stopCompletion {
+     llama->is_interrupted = true;
+ }
+
+ - (NSArray *)tokenize:(NSString *)text {
+     const std::vector<llama_token> toks = llama_tokenize(llama->ctx, [text UTF8String], false);
+     NSMutableArray *result = [[NSMutableArray alloc] init];
+     for (llama_token tok : toks) {
+         [result addObject:@(tok)];
+     }
+     return result;
+ }
+
+ - (NSString *)detokenize:(NSArray *)tokens {
+     std::vector<llama_token> toks;
+     for (NSNumber *tok in tokens) {
+         toks.push_back([tok intValue]);
+     }
+     const std::string text = rnllama::tokens_to_str(llama->ctx, toks.cbegin(), toks.cend());
+     return [NSString stringWithUTF8String:text.c_str()];
+ }
+
+ - (NSArray *)embedding:(NSString *)text {
+     if (llama->params.embedding != true) {
+         @throw [NSException exceptionWithName:@"LlamaException" reason:@"Embedding is not enabled" userInfo:nil];
+     }
+
+     llama->rewind();
+
+     llama_reset_timings(llama->ctx);
+
+     llama->params.prompt = [text UTF8String];
+
+     llama->params.n_predict = 0;
+     llama->loadPrompt();
+     llama->beginCompletion();
+     llama->doCompletion();
+
+     std::vector<float> result = llama->getEmbedding();
+
+     NSMutableArray *embeddingResult = [[NSMutableArray alloc] init];
+     for (float f : result) {
+         [embeddingResult addObject:@(f)];
+     }
+
+     llama->is_predicting = false;
+     return embeddingResult;
+ }
+
+ - (NSDictionary *)loadSession:(NSString *)path {
+     if (!path || [path length] == 0) {
+         @throw [NSException exceptionWithName:@"LlamaException" reason:@"Session path is empty" userInfo:nil];
+     }
+     if (![[NSFileManager defaultManager] fileExistsAtPath:path]) {
+         @throw [NSException exceptionWithName:@"LlamaException" reason:@"Session file does not exist" userInfo:nil];
+     }
+
+     size_t n_token_count_out = 0;
+     llama->embd.resize(llama->params.n_ctx);
+     if (!llama_state_load_file(llama->ctx, [path UTF8String], llama->embd.data(), llama->embd.capacity(), &n_token_count_out)) {
+         @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to load session" userInfo:nil];
+     }
+     llama->embd.resize(n_token_count_out);
+     const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
+     return @{
+         @"tokens_loaded": @(n_token_count_out),
+         @"prompt": [NSString stringWithUTF8String:text.c_str()]
+     };
+ }
+
+ - (int)saveSession:(NSString *)path size:(int)size {
+     if (!path || [path length] == 0) {
+         @throw [NSException exceptionWithName:@"LlamaException" reason:@"Session path is empty" userInfo:nil];
+     }
+     std::vector<llama_token> session_tokens = llama->embd;
+     int default_size = session_tokens.size();
+     int save_size = size > 0 && size <= default_size ? size : default_size;
+     if (!llama_state_save_file(llama->ctx, [path UTF8String], session_tokens.data(), save_size)) {
+         @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to save session" userInfo:nil];
+     }
+     return session_tokens.size();
+ }
+
+ - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr {
+     return [NSString stringWithUTF8String:llama->bench(pp, tg, pl, nr).c_str()];
+ }
+
+ - (void)invalidate {
+     delete llama;
+     // llama_backend_free();
+ }
+
+ @end
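The keys read from params in the completion: method above correspond one-to-one to the options the JavaScript side can pass for a completion call. For reference, a hedged TypeScript sketch of that shape follows; the field names are taken from the Objective-C code above, while the authoritative exported type lives in package/src/NativeRNLlama.ts and is not reproduced in this excerpt.

// Descriptive sketch only; names mirror the keys read by RNLlamaContext.mm above.
type CompletionParamsSketch = {
  prompt: string
  seed?: number // treated as -1 (random) when omitted
  n_threads?: number
  n_predict?: number
  temperature?: number
  n_probs?: number // > 0 adds completion_probabilities to tokens and result
  penalty_last_n?: number
  penalty_repeat?: number
  penalty_freq?: number
  penalty_present?: number
  mirostat?: number
  mirostat_tau?: number
  mirostat_eta?: number
  penalize_nl?: boolean
  top_k?: number
  top_p?: number
  min_p?: number
  tfs_z?: number
  typical_p?: number
  grammar?: string // GBNF grammar text
  stop?: string[] // mapped to antiprompt strings
  ignore_eos?: boolean // biases the EOS token to -INFINITY
  logit_bias?: Array<[number, number | boolean]> // [token, bias] pairs
}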
package/jest/mock.js ADDED
@@ -0,0 +1,169 @@
+ const { NativeModules, DeviceEventEmitter } = require('react-native')
+
+ if (!NativeModules.RNLlama) {
+   NativeModules.RNLlama = {
+     initContext: jest.fn(() =>
+       Promise.resolve({
+         contextId: 1,
+         gpu: false,
+         reasonNoGPU: 'Test',
+       }),
+     ),
+
+     completion: jest.fn(async (contextId, jobId) => {
+       const testResult = {
+         text: '*giggles*',
+         completion_probabilities: [
+           {
+             content: ' *',
+             probs: [
+               {
+                 prob: 0.9658700227737427,
+                 tok_str: ' *',
+               },
+               {
+                 prob: 0.021654844284057617,
+                 tok_str: ' Hi',
+               },
+               {
+                 prob: 0.012475099414587021,
+                 tok_str: ' Hello',
+               },
+             ],
+           },
+           {
+             content: 'g',
+             probs: [
+               {
+                 prob: 0.5133139491081238,
+                 tok_str: 'g',
+               },
+               {
+                 prob: 0.3046242296695709,
+                 tok_str: 'ch',
+               },
+               {
+                 prob: 0.18206188082695007,
+                 tok_str: 'bl',
+               },
+             ],
+           },
+           {
+             content: 'igg',
+             probs: [
+               {
+                 prob: 0.9886618852615356,
+                 tok_str: 'igg',
+               },
+               {
+                 prob: 0.008458126336336136,
+                 tok_str: 'ig',
+               },
+               {
+                 prob: 0.002879939740523696,
+                 tok_str: 'reet',
+               },
+             ],
+           },
+           {
+             content: 'les',
+             probs: [
+               {
+                 prob: 1,
+                 tok_str: 'les',
+               },
+               {
+                 prob: 1.8753286923356427e-8,
+                 tok_str: 'ling',
+               },
+               {
+                 prob: 3.312444318837038e-9,
+                 tok_str: 'LES',
+               },
+             ],
+           },
+           {
+             content: '*',
+             probs: [
+               {
+                 prob: 1,
+                 tok_str: '*',
+               },
+               {
+                 prob: 4.459857905203535e-8,
+                 tok_str: '*.',
+               },
+               {
+                 prob: 3.274198334679568e-8,
+                 tok_str: '**',
+               },
+             ],
+           },
+         ],
+         stopped_eos: true,
+         stopped_limit: false,
+         stopped_word: false,
+         stopping_word: '',
+         timings: {
+           predicted_ms: 1330.6290000000001,
+           predicted_n: 5,
+           predicted_per_second: 16.533534140620713,
+           predicted_per_token_ms: 60.48313636363637,
+           prompt_ms: 3805.6730000000002,
+           prompt_n: 5,
+           prompt_per_second: 8.408499626741445,
+           prompt_per_token_ms: 118.92728125000001,
+         },
+         tokens_cached: 54,
+         tokens_evaluated: 15,
+         tokens_predicted: 6,
+         truncated: false,
+       }
+       const emitEvent = async (data) => {
+         await new Promise((resolve) => setTimeout(resolve))
+         DeviceEventEmitter.emit('@RNLlama_onToken', data)
+       }
+       await testResult.completion_probabilities.reduce(
+         (promise, item) =>
+           promise.then(() =>
+             emitEvent({
+               contextId,
+               jobId,
+               tokenResult: {
+                 token: item.content,
+                 completion_probabilities: item.probs,
+               },
+             }),
+           ),
+         Promise.resolve(),
+       )
+       return Promise.resolve(testResult)
+     }),
+
+     stopCompletion: jest.fn(),
+
+     tokenize: jest.fn(async () => []),
+     detokenize: jest.fn(async () => ''),
+     embedding: jest.fn(async () => []),
+
+     loadSession: jest.fn(async () => ({
+       tokens_loaded: 1,
+       prompt: 'Hello',
+     })),
+     saveSession: jest.fn(async () => 1),
+
+     bench: jest.fn(
+       async () =>
+         '["test 3B Q4_0",1600655360,2779683840,16.211304,0.021748,38.570646,1.195800]',
+     ),
+
+     releaseContext: jest.fn(() => Promise.resolve()),
+     releaseAllContexts: jest.fn(() => Promise.resolve()),
+
+     // For NativeEventEmitter
+     addListener: jest.fn(),
+     removeListeners: jest.fn(),
+   }
+ }
+
+ module.exports = jest.requireActual('llama.rn')
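The mock above installs jest.fn() stubs for the native RNLlama module and then re-exports the real JavaScript wrapper via jest.requireActual('llama.rn'), so tests exercise the actual JS layer against fake native calls. A minimal sketch of wiring it into a test follows; the jest.mock pattern and the initLlama/completion API are assumptions based on upstream llama.rn, and this fork may be imported as 'cui-llama.rn' rather than 'llama.rn'.

// Hypothetical test file, e.g. __tests__/llama.test.ts.
import { initLlama } from 'llama.rn'

jest.mock('llama.rn', () => require('llama.rn/jest/mock'))

test('completion streams the mocked tokens', async () => {
  const context = await initLlama({ model: 'test.gguf' })

  const pieces: string[] = []
  const result = await context.completion({ prompt: 'Hi' }, (data) => {
    // Each callback is driven by a '@RNLlama_onToken' event emitted by the mock.
    pieces.push(data.token)
  })

  expect(result.text).toBe('*giggles*') // final text from the mocked result
  expect(pieces.join('')).toBe(' *giggles*') // streamed pieces: ' *', 'g', 'igg', 'les', '*'
})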
package/lib/commonjs/NativeRNLlama.js ADDED
@@ -0,0 +1,10 @@
+ "use strict";
+
+ Object.defineProperty(exports, "__esModule", {
+   value: true
+ });
+ exports.default = void 0;
+ var _reactNative = require("react-native");
+ var _default = _reactNative.TurboModuleRegistry.get('RNLlama');
+ exports.default = _default;
+ //# sourceMappingURL=NativeRNLlama.js.map
package/lib/commonjs/NativeRNLlama.js.map ADDED
@@ -0,0 +1 @@
+ {"version":3,"names":["_reactNative","require","_default","TurboModuleRegistry","get","exports","default"],"sourceRoot":"..\\..\\src","sources":["NativeRNLlama.ts"],"mappings":";;;;;;AACA,IAAAA,YAAA,GAAAC,OAAA;AAAmD,IAAAC,QAAA,GAkIpCC,gCAAmB,CAACC,GAAG,CAAO,SAAS,CAAC;AAAAC,OAAA,CAAAC,OAAA,GAAAJ,QAAA"}