cui-llama.rn 1.0.4 → 1.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +36 -39
- package/android/src/main/CMakeLists.txt +11 -2
- package/android/src/main/java/com/rnllama/LlamaContext.java +24 -8
- package/android/src/main/java/com/rnllama/RNLlama.java +33 -1
- package/android/src/main/jni.cpp +63 -9
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
- package/cpp/common.cpp +3237 -3231
- package/cpp/common.h +469 -468
- package/cpp/ggml-aarch64.c +2193 -2193
- package/cpp/ggml-aarch64.h +39 -39
- package/cpp/ggml-alloc.c +1036 -1042
- package/cpp/ggml-backend-impl.h +153 -153
- package/cpp/ggml-backend.c +2240 -2234
- package/cpp/ggml-backend.h +238 -238
- package/cpp/ggml-common.h +1833 -1829
- package/cpp/ggml-impl.h +755 -655
- package/cpp/ggml-metal.h +65 -65
- package/cpp/ggml-metal.m +3269 -3269
- package/cpp/ggml-quants.c +14872 -14860
- package/cpp/ggml-quants.h +132 -132
- package/cpp/ggml.c +22099 -22044
- package/cpp/ggml.h +2453 -2447
- package/cpp/llama-grammar.cpp +539 -0
- package/cpp/llama-grammar.h +39 -0
- package/cpp/llama-impl.h +26 -0
- package/cpp/llama-sampling.cpp +635 -0
- package/cpp/llama-sampling.h +56 -0
- package/cpp/llama-vocab.cpp +1721 -0
- package/cpp/llama-vocab.h +130 -0
- package/cpp/llama.cpp +19173 -21892
- package/cpp/llama.h +1240 -1217
- package/cpp/log.h +737 -737
- package/cpp/rn-llama.hpp +209 -29
- package/cpp/sampling.cpp +460 -460
- package/cpp/sgemm.cpp +1027 -1027
- package/cpp/sgemm.h +14 -14
- package/cpp/unicode.cpp +6 -0
- package/cpp/unicode.h +3 -0
- package/ios/RNLlama.mm +15 -6
- package/ios/RNLlamaContext.h +2 -8
- package/ios/RNLlamaContext.mm +41 -34
- package/jest/mock.js +3 -0
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/chat.js +37 -0
- package/lib/commonjs/chat.js.map +1 -0
- package/lib/commonjs/index.js +14 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/chat.js +31 -0
- package/lib/module/chat.js.map +1 -0
- package/lib/module/index.js +14 -1
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +5 -1
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/chat.d.ts +10 -0
- package/lib/typescript/chat.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +9 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +10 -1
- package/src/chat.ts +44 -0
- package/src/index.ts +31 -4
package/cpp/rn-llama.hpp
CHANGED
@@ -6,6 +6,13 @@
 #include "common.h"
 #include "llama.h"
 
+
+#include <android/log.h>
+#define LLAMA_ANDROID_TAG "RNLLAMA_LOG_ANDROID"
+#define LLAMA_LOG_INFO(...)  __android_log_print(ANDROID_LOG_INFO , LLAMA_ANDROID_TAG, __VA_ARGS__)
+
+
+
 namespace rnllama {
 
 static void llama_batch_clear(llama_batch *batch) {
@@ -139,6 +146,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
     return ret;
 }
 
+
 struct llama_rn_context
 {
     bool is_predicting = false;
@@ -167,7 +175,7 @@ struct llama_rn_context
     bool stopped_word = false;
     bool stopped_limit = false;
     std::string stopping_word;
-
+    bool incomplete = false;
 
     ~llama_rn_context()
     {
@@ -202,7 +210,7 @@ struct llama_rn_context
         stopped_word = false;
         stopped_limit = false;
         stopping_word = "";
-
+        incomplete = false;
         n_remain = 0;
         n_past = 0;
         params.sparams.n_prev = n_ctx;
@@ -229,6 +237,16 @@ struct llama_rn_context
         return true;
     }
 
+    bool validateModelChatTemplate() const {
+        llama_chat_message chat[] = {{"user", "test"}};
+
+        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+
+        return res >= 0;
+    }
+
     void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
         const int n_left = n_ctx - params.n_keep;
         const int n_block_size = n_left / 2;
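For reference on the `validateModelChatTemplate()` check added above: `llama_model_meta_val_str` copies the value of a GGUF metadata key into the caller's buffer and returns its length, or a negative value when the key is missing, so `res >= 0` simply means the model file ships a `tokenizer.chat_template`. A minimal standalone sketch of the same probe, assuming an already-loaded `llama_model *` (hypothetical helper, not part of this diff):

```cpp
#include <vector>
#include "llama.h"

// True when the loaded GGUF model carries a built-in chat template.
// Sketch only: `model` is assumed to be a valid, already-loaded llama_model *.
static bool has_builtin_chat_template(const llama_model * model) {
    std::vector<char> buf(2048, 0); // same 2 KiB headroom the diff uses
    int32_t len = llama_model_meta_val_str(model, "tokenizer.chat_template",
                                           buf.data(), buf.size());
    return len >= 0; // negative return means the key is absent
}
```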
@@ -278,15 +296,20 @@ struct llama_rn_context
 
             LM_GGML_ASSERT(num_prompt_tokens < (size_t) n_ctx);
         }
+
+        // do Context Shift , may be buggy! TODO: Verify functionality
+        purge_missing_tokens(ctx, embd, prompt_tokens, params.n_predict, params.n_ctx);
+
         // push the prompt into the sampling context (do not apply grammar)
         for (auto & token : prompt_tokens)
         {
             llama_sampling_accept(ctx_sampling, ctx, token, false);
         }
-
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
-
+        LLAMA_LOG_INFO("%s: n_past: %zu", __func__, n_past);
+        LLAMA_LOG_INFO("%s: embd size: %zu", __func__, embd.size());
+        LLAMA_LOG_INFO("%s: prompt_tokens size: %zu", __func__, prompt_tokens.size());
         embd = prompt_tokens;
         if (n_past == num_prompt_tokens)
         {
@@ -470,35 +493,28 @@ struct llama_rn_context
             generated_token_probs.push_back(token_with_probs);
         }
 
-        if
-        {
-
-
-
-
-            const char c = token_text[0];
-            // 2-byte characters: 110xxxxx 10xxxxxx
-            if ((c & 0xE0) == 0xC0)
-            {
-                multibyte_pending = 1;
-                // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
-            }
-            else if ((c & 0xF0) == 0xE0)
-            {
-                multibyte_pending = 2;
-                // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+        // check if there is incomplete UTF-8 character at the end
+        for (unsigned i = 1; i < 5 && i <= generated_text.size(); ++i) {
+            unsigned char c = generated_text[generated_text.size() - i];
+            if ((c & 0xC0) == 0x80) {
+                // continuation byte: 10xxxxxx
+                continue;
             }
-
-
-
-            }
-
-
-
+            if ((c & 0xE0) == 0xC0) {
+                // 2-byte character: 110xxxxx ...
+                incomplete = i < 2;
+            } else if ((c & 0xF0) == 0xE0) {
+                // 3-byte character: 1110xxxx ...
+                incomplete = i < 3;
+            } else if ((c & 0xF8) == 0xF0) {
+                // 4-byte character: 11110xxx ...
+                incomplete = i < 4;
             }
+            // else 1-byte character or invalid byte
+            break;
         }
 
-        if (
+        if (incomplete && !has_next_token)
         {
             has_next_token = true;
             n_remain++;
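The replacement above drops the old `multibyte_pending` counter: instead of predicting how many continuation bytes are still owed after each token, it scans backwards over at most the last four bytes of `generated_text` and marks the output `incomplete` when the last lead byte promises more continuation bytes than have actually arrived. A small standalone sketch of the same check (hypothetical helper name, not part of the package):

```cpp
#include <string>

// True when `s` ends in the middle of a UTF-8 multi-byte sequence.
static bool ends_with_incomplete_utf8(const std::string & s) {
    for (size_t i = 1; i <= 4 && i <= s.size(); ++i) {
        unsigned char c = s[s.size() - i];
        if ((c & 0xC0) == 0x80) continue;      // continuation byte: keep scanning
        if ((c & 0xE0) == 0xC0) return i < 2;  // lead byte of a 2-byte sequence
        if ((c & 0xF0) == 0xE0) return i < 3;  // lead byte of a 3-byte sequence
        if ((c & 0xF8) == 0xF0) return i < 4;  // lead byte of a 4-byte sequence
        return false;                          // ASCII byte (or invalid lead)
    }
    return false;
}

// Example: "€" is 0xE2 0x82 0xAC. With only the first two bytes emitted, the scan
// finds the 3-byte lead 0xE2 at i == 2, and 2 < 3, so the text is incomplete:
//   ends_with_incomplete_utf8("\xE2\x82")     == true
//   ends_with_incomplete_utf8("\xE2\x82\xAC") == false
```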
@@ -638,6 +654,170 @@ struct llama_rn_context
                 std::to_string(tg_std) +
                 std::string("]");
     }
+
+
+    // Context Shifting from KoboldCpp <https://github.com/LostRuins/koboldcpp>
+    // Implementation obtained with special permission from @concedo
+
+    std::vector<int> longest_common_subseq(const std::vector<int> x, const std::vector<int> y){
+        int m = x.size(), n = y.size();
+
+        //int LCSuff[m+1][n+1];
+        std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
+
+        for (int j = 0; j <= n; j++)
+            LCSuff[0][j] = 0;
+        for (int i = 0; i <= m; i++)
+            LCSuff[i][0] = 0;
+
+        for (int i = 1; i <= m; i++)
+        {
+            for (int j = 1; j <= n; j++)
+            {
+                if (x[i - 1] == y[j - 1])
+                    LCSuff[i][j] = LCSuff[i - 1][j - 1] + 1;
+                else
+                    LCSuff[i][j] = 0;
+            }
+        }
+
+        std::vector<int> longest;
+        for (int i = 1; i <= m; i++)
+        {
+            for (int j = 1; j <= n; j++)
+            {
+                if (LCSuff[i][j] > longest.size())
+                {
+                    auto off1 = ((i - LCSuff[i][j] + 1) - 1);
+                    auto off2 = off1 + LCSuff[i][j];
+                    longest.clear();
+                    // std::vector<int>().swap(longest);
+                    longest = std::vector<int>(x.begin() + off1, x.begin() + off2);
+                    // x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]);
+                }
+            }
+        }
+        return longest;
+    }
+
+    bool arr_start_with(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+    {
+        int ss = searchSeq.size();
+        if(targetArray.size()<ss)
+        {
+            return false;
+        }
+        for(int i=0;i<ss;++i)
+        {
+            if(targetArray[i]!=searchSeq[i])
+            {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    int arr_find_index_of(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+    {
+        int ss = searchSeq.size();
+        int tas = targetArray.size();
+        if(tas<ss)
+        {
+            return -1;
+        }
+        for(int i=0;i<tas;++i)
+        {
+            int srch = 0;
+            bool fail = false;
+            for(int srch=0;srch<ss;++srch)
+            {
+                if ((i + srch) >= tas || targetArray[i + srch] != searchSeq[srch])
+                {
+                    fail = true;
+                    break;
+                }
+            }
+            if(!fail)
+            {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    void purge_missing_tokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
+    {
+        //scan from start old and new ctx, until first mismatch found, save as p0
+        //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
+        //test: longest common subseq (LCQ) MUST start within 0 tokens from end of memory, otherwise purge fails
+        //if passed, save beginning of LCQ from old ctx as p1
+        //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
+
+        const int short_fall_threshold = 200 + (nctx/30); //dont trigger shifting if the distance between trimstart and currhead < this
+        const int stack_allowance = 60 + (nctx/50); //in case the end text is slightly modified, be forgiving
+
+        int trimstart = 0;
+        int new_tokens_len = new_context_tokens.size();
+        bool purge_needed = true;
+
+        for (int i = 0; i < current_context_tokens.size(); ++i)
+        {
+            if (current_context_tokens[i] == new_context_tokens[i])
+            {
+                trimstart += 1;
+            }
+            else
+            {
+                break;
+            }
+            if ((i + 2) >= new_tokens_len)
+            {
+                purge_needed = false;
+                break; //no surgery required
+            }
+        }
+
+
+
+        if(!purge_needed || new_tokens_len < 6 || current_context_tokens.size() < 6 || new_tokens_len - trimstart < short_fall_threshold)
+        {
+            LLAMA_LOG_INFO("Fall Threshold: %d out of %d\n", new_tokens_len - trimstart, short_fall_threshold);
+            return; //no purge is needed
+        }
+
+        //at least this many tokens need to match, otherwise don't bother trimming
+        const int lc_tok_threshold = std::max(std::min((new_tokens_len - trimstart) - (genamt+stack_allowance), (int)(nctx*0.45)), short_fall_threshold - stack_allowance);
+
+        auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
+        auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
+
+        auto shared = longest_common_subseq(curr_ctx_without_memory, new_ctx_without_memory);
+
+        if (shared.size() > lc_tok_threshold && arr_start_with(new_ctx_without_memory, shared)) // enough tokens in common
+        {
+            int found = arr_find_index_of(current_context_tokens,shared);
+            if(found>=0 && found > trimstart)
+            {
+
+                //extract the unwanted tokens out from context and KV
+                int diff = found - trimstart;
+                llama_kv_cache_seq_rm(ctx, 0, trimstart, trimstart + diff);
+                llama_kv_cache_seq_add(ctx, 0, trimstart + diff, -1, -diff);
+
+                for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
+                {
+                    current_context_tokens[i - diff] = current_context_tokens[i];
+                }
+
+                LLAMA_LOG_INFO("\n[Context Shifting: Erased %d tokens at position %d]", diff, trimstart + 1);
+
+                current_context_tokens.resize(current_context_tokens.size() - diff);
+            }
+        }
+
+    }
+
+    // End Context Shifting
 };
 
 }