cui-llama.rn 1.0.3 → 1.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. package/README.md +35 -39
  2. package/android/src/main/CMakeLists.txt +12 -2
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +29 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +33 -1
  5. package/android/src/main/jni.cpp +62 -8
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  8. package/cpp/common.cpp +3237 -3231
  9. package/cpp/common.h +469 -468
  10. package/cpp/ggml-aarch64.c +2193 -2193
  11. package/cpp/ggml-aarch64.h +39 -39
  12. package/cpp/ggml-alloc.c +1036 -1042
  13. package/cpp/ggml-backend-impl.h +153 -153
  14. package/cpp/ggml-backend.c +2240 -2234
  15. package/cpp/ggml-backend.h +238 -238
  16. package/cpp/ggml-common.h +1833 -1829
  17. package/cpp/ggml-impl.h +755 -655
  18. package/cpp/ggml-metal.h +65 -65
  19. package/cpp/ggml-metal.m +3269 -3269
  20. package/cpp/ggml-quants.c +14872 -14860
  21. package/cpp/ggml-quants.h +132 -132
  22. package/cpp/ggml.c +22055 -22044
  23. package/cpp/ggml.h +2453 -2447
  24. package/cpp/llama-grammar.cpp +539 -0
  25. package/cpp/llama-grammar.h +39 -0
  26. package/cpp/llama-impl.h +26 -0
  27. package/cpp/llama-sampling.cpp +635 -0
  28. package/cpp/llama-sampling.h +56 -0
  29. package/cpp/llama-vocab.cpp +1721 -0
  30. package/cpp/llama-vocab.h +130 -0
  31. package/cpp/llama.cpp +19171 -21892
  32. package/cpp/llama.h +1240 -1217
  33. package/cpp/log.h +737 -737
  34. package/cpp/rn-llama.hpp +207 -29
  35. package/cpp/sampling.cpp +460 -460
  36. package/cpp/sgemm.cpp +1027 -1027
  37. package/cpp/sgemm.h +14 -14
  38. package/cpp/unicode.cpp +6 -0
  39. package/cpp/unicode.h +3 -0
  40. package/ios/RNLlama.mm +15 -6
  41. package/ios/RNLlamaContext.h +2 -8
  42. package/ios/RNLlamaContext.mm +41 -34
  43. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  44. package/lib/commonjs/chat.js +37 -0
  45. package/lib/commonjs/chat.js.map +1 -0
  46. package/lib/commonjs/index.js +14 -1
  47. package/lib/commonjs/index.js.map +1 -1
  48. package/lib/module/NativeRNLlama.js.map +1 -1
  49. package/lib/module/chat.js +31 -0
  50. package/lib/module/chat.js.map +1 -0
  51. package/lib/module/index.js +14 -1
  52. package/lib/module/index.js.map +1 -1
  53. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  54. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  55. package/lib/typescript/chat.d.ts +10 -0
  56. package/lib/typescript/chat.d.ts.map +1 -0
  57. package/lib/typescript/index.d.ts +9 -2
  58. package/lib/typescript/index.d.ts.map +1 -1
  59. package/package.json +1 -1
  60. package/src/NativeRNLlama.ts +10 -1
  61. package/src/chat.ts +44 -0
  62. package/src/index.ts +31 -4
package/cpp/rn-llama.hpp CHANGED
@@ -6,6 +6,13 @@
 #include "common.h"
 #include "llama.h"
 
+
+#include <android/log.h>
+#define LLAMA_ANDROID_TAG "RNLLAMA_LOG_ANDROID"
+#define LLAMA_LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO , LLAMA_ANDROID_TAG, __VA_ARGS__)
+
+
+
 namespace rnllama {
 
 static void llama_batch_clear(llama_batch *batch) {
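
Note on this hunk: LLAMA_LOG_INFO now routes rn-llama diagnostics to Android logcat under the RNLLAMA_LOG_ANDROID tag (viewable with `adb logcat -s RNLLAMA_LOG_ANDROID`). Below is a minimal sketch of how the same macro could be guarded per platform; the `#if defined(__ANDROID__)` fallback is our assumption and is not part of this release:

```cpp
// Hypothetical platform guard (not in the released header): log to logcat on
// Android and fall back to stderr everywhere else.
#if defined(__ANDROID__)
#include <android/log.h>
#define LLAMA_ANDROID_TAG "RNLLAMA_LOG_ANDROID"
#define LLAMA_LOG_INFO(...) \
    __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_TAG, __VA_ARGS__)
#else
#include <cstdio>
#define LLAMA_LOG_INFO(...) fprintf(stderr, __VA_ARGS__)
#endif
```
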
@@ -139,6 +146,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
     return ret;
 }
 
+
 struct llama_rn_context
 {
     bool is_predicting = false;
@@ -167,7 +175,7 @@ struct llama_rn_context
     bool stopped_word = false;
     bool stopped_limit = false;
     std::string stopping_word;
-    int32_t multibyte_pending = 0;
+    bool incomplete = false;
 
     ~llama_rn_context()
     {
@@ -202,7 +210,7 @@ struct llama_rn_context
         stopped_word = false;
         stopped_limit = false;
         stopping_word = "";
-        multibyte_pending = 0;
+        incomplete = false;
         n_remain = 0;
         n_past = 0;
         params.sparams.n_prev = n_ctx;
@@ -229,6 +237,14 @@ struct llama_rn_context
         return true;
     }
 
+    bool validateModelChatTemplate() const {
+        llama_chat_message chat[] = {{"user", "test"}};
+
+        const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
+
+        return res > 0;
+    }
+
     void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
         const int n_left = n_ctx - params.n_keep;
         const int n_block_size = n_left / 2;
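
Note on this hunk: validateModelChatTemplate() probes whether the loaded model ships a usable built-in chat template by calling llama_chat_apply_template with a null output buffer and treating a positive required length as success. A sketch of the same API doing actual formatting follows; the `messages` contents and the grow-and-retry pattern are illustrative, not taken from the package:

```cpp
// Illustrative sketch: render a conversation with the model's built-in
// template. `model` is an already-loaded llama_model *.
std::vector<llama_chat_message> messages = {
    {"system", "You are a helpful assistant."},
    {"user",   "Hello!"},
};
std::vector<char> buf(1024);
int32_t len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(),
                                        /*add_ass=*/true, buf.data(), buf.size());
if (len > (int32_t) buf.size()) {   // buffer too small: grow to the reported size and retry
    buf.resize(len);
    len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(),
                                    true, buf.data(), buf.size());
}
if (len > 0) {
    std::string prompt(buf.data(), len);   // ready to tokenize and evaluate
}
```
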
@@ -278,15 +294,20 @@ struct llama_rn_context
 
             LM_GGML_ASSERT(num_prompt_tokens < (size_t) n_ctx);
         }
+
+        // do Context Shift , may be buggy! TODO: Verify functionality
+        purge_missing_tokens(ctx, embd, prompt_tokens, params.n_predict, params.n_ctx);
+
         // push the prompt into the sampling context (do not apply grammar)
         for (auto & token : prompt_tokens)
        {
            llama_sampling_accept(ctx_sampling, ctx, token, false);
        }
-
        // compare the evaluated prompt with the new prompt
        n_past = common_part(embd, prompt_tokens);
-
+        LLAMA_LOG_INFO("%s: n_past: %zu", __func__, n_past);
+        LLAMA_LOG_INFO("%s: embd size: %zu", __func__, embd.size());
+        LLAMA_LOG_INFO("%s: prompt_tokens size: %zu", __func__, prompt_tokens.size());
        embd = prompt_tokens;
        if (n_past == num_prompt_tokens)
        {
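
Note on this hunk: the new log lines report how much of the previous context can be reused. common_part (defined earlier in rn-llama.hpp) returns the length of the shared token prefix, so only the prompt suffix after n_past is evaluated again, and purge_missing_tokens (added at the end of this file) may additionally cut a mismatched middle section out of the KV cache first. A sketch of that prefix count, under the assumption that common_part behaves as described (the function name below is ours):

```cpp
#include <vector>
#include "llama.h"

// Sketch of the prefix-reuse count logged above: number of leading tokens the
// cached context and the new prompt have in common.
static size_t common_prefix_len(const std::vector<llama_token> & a,
                                const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}
```
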
@@ -470,35 +491,28 @@ struct llama_rn_context
             generated_token_probs.push_back(token_with_probs);
         }
 
-        if (multibyte_pending > 0)
-        {
-            multibyte_pending -= token_text.size();
-        }
-        else if (token_text.size() == 1)
-        {
-            const char c = token_text[0];
-            // 2-byte characters: 110xxxxx 10xxxxxx
-            if ((c & 0xE0) == 0xC0)
-            {
-                multibyte_pending = 1;
-                // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
-            }
-            else if ((c & 0xF0) == 0xE0)
-            {
-                multibyte_pending = 2;
-                // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+        // check if there is incomplete UTF-8 character at the end
+        for (unsigned i = 1; i < 5 && i <= generated_text.size(); ++i) {
+            unsigned char c = generated_text[generated_text.size() - i];
+            if ((c & 0xC0) == 0x80) {
+                // continuation byte: 10xxxxxx
+                continue;
             }
-            else if ((c & 0xF8) == 0xF0)
-            {
-                multibyte_pending = 3;
-            }
-            else
-            {
-                multibyte_pending = 0;
+            if ((c & 0xE0) == 0xC0) {
+                // 2-byte character: 110xxxxx ...
+                incomplete = i < 2;
+            } else if ((c & 0xF0) == 0xE0) {
+                // 3-byte character: 1110xxxx ...
+                incomplete = i < 3;
+            } else if ((c & 0xF8) == 0xF0) {
+                // 4-byte character: 11110xxx ...
+                incomplete = i < 4;
             }
+            // else 1-byte character or invalid byte
+            break;
         }
 
-        if (multibyte_pending > 0 && !has_next_token)
+        if (incomplete && !has_next_token)
         {
             has_next_token = true;
             n_remain++;
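
Note on this hunk: the old multibyte_pending counter only tracked pending continuation bytes when a token emitted exactly one byte; the replacement scans the tail of generated_text directly, which also handles tokens that emit several bytes at once. The same check as a free function, for reference (the function name is ours, not the package's):

```cpp
#include <string>

// Reference sketch of the new check: true when `text` ends in a truncated
// multi-byte UTF-8 sequence, so the streamed chunk should be held back.
static bool ends_with_incomplete_utf8(const std::string & text) {
    for (unsigned i = 1; i < 5 && i <= text.size(); ++i) {
        unsigned char c = text[text.size() - i];
        if ((c & 0xC0) == 0x80) continue;      // continuation byte 10xxxxxx, keep scanning
        if ((c & 0xE0) == 0xC0) return i < 2;  // lead byte of a 2-byte sequence
        if ((c & 0xF0) == 0xE0) return i < 3;  // lead byte of a 3-byte sequence
        if ((c & 0xF8) == 0xF0) return i < 4;  // lead byte of a 4-byte sequence
        return false;                          // ASCII or invalid lead byte
    }
    return false;
}
```
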
@@ -638,6 +652,170 @@ struct llama_rn_context
             std::to_string(tg_std) +
             std::string("]");
     }
+
+
+    // Context Shifting from KoboldCpp <https://github.com/LostRuins/koboldcpp>
+    // Implementation obtained with special permission from @concedo
+
+    std::vector<int> longest_common_subseq(const std::vector<int> x, const std::vector<int> y){
+        int m = x.size(), n = y.size();
+
+        //int LCSuff[m+1][n+1];
+        std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
+
+        for (int j = 0; j <= n; j++)
+            LCSuff[0][j] = 0;
+        for (int i = 0; i <= m; i++)
+            LCSuff[i][0] = 0;
+
+        for (int i = 1; i <= m; i++)
+        {
+            for (int j = 1; j <= n; j++)
+            {
+                if (x[i - 1] == y[j - 1])
+                    LCSuff[i][j] = LCSuff[i - 1][j - 1] + 1;
+                else
+                    LCSuff[i][j] = 0;
+            }
+        }
+
+        std::vector<int> longest;
+        for (int i = 1; i <= m; i++)
+        {
+            for (int j = 1; j <= n; j++)
+            {
+                if (LCSuff[i][j] > longest.size())
+                {
+                    auto off1 = ((i - LCSuff[i][j] + 1) - 1);
+                    auto off2 = off1 + LCSuff[i][j];
+                    longest.clear();
+                    // std::vector<int>().swap(longest);
+                    longest = std::vector<int>(x.begin() + off1, x.begin() + off2);
+                    // x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]);
+                }
+            }
+        }
+        return longest;
+    }
+
+    bool arr_start_with(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+    {
+        int ss = searchSeq.size();
+        if(targetArray.size()<ss)
+        {
+            return false;
+        }
+        for(int i=0;i<ss;++i)
+        {
+            if(targetArray[i]!=searchSeq[i])
+            {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    int arr_find_index_of(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+    {
+        int ss = searchSeq.size();
+        int tas = targetArray.size();
+        if(tas<ss)
+        {
+            return -1;
+        }
+        for(int i=0;i<tas;++i)
+        {
+            int srch = 0;
+            bool fail = false;
+            for(int srch=0;srch<ss;++srch)
+            {
+                if ((i + srch) >= tas || targetArray[i + srch] != searchSeq[srch])
+                {
+                    fail = true;
+                    break;
+                }
+            }
+            if(!fail)
+            {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    void purge_missing_tokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
+    {
+        //scan from start old and new ctx, until first mismatch found, save as p0
+        //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
+        //test: longest common subseq (LCQ) MUST start within 0 tokens from end of memory, otherwise purge fails
+        //if passed, save beginning of LCQ from old ctx as p1
+        //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
+
+        const int short_fall_threshold = 200 + (nctx/30); //dont trigger shifting if the distance between trimstart and currhead < this
+        const int stack_allowance = 60 + (nctx/50); //in case the end text is slightly modified, be forgiving
+
+        int trimstart = 0;
+        int new_tokens_len = new_context_tokens.size();
+        bool purge_needed = true;
+
+        for (int i = 0; i < current_context_tokens.size(); ++i)
+        {
+            if (current_context_tokens[i] == new_context_tokens[i])
+            {
+                trimstart += 1;
+            }
+            else
+            {
+                break;
+            }
+            if ((i + 2) >= new_tokens_len)
+            {
+                purge_needed = false;
+                break; //no surgery required
+            }
+        }
+
+
+
+        if(!purge_needed || new_tokens_len < 6 || current_context_tokens.size() < 6 || new_tokens_len - trimstart < short_fall_threshold)
+        {
+            LLAMA_LOG_INFO("Fall Threshold: %d out of %d\n", new_tokens_len - trimstart, short_fall_threshold);
+            return; //no purge is needed
+        }
+
+        //at least this many tokens need to match, otherwise don't bother trimming
+        const int lc_tok_threshold = std::max(std::min((new_tokens_len - trimstart) - (genamt+stack_allowance), (int)(nctx*0.45)), short_fall_threshold - stack_allowance);
+
+        auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
+        auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
+
+        auto shared = longest_common_subseq(curr_ctx_without_memory, new_ctx_without_memory);
+
+        if (shared.size() > lc_tok_threshold && arr_start_with(new_ctx_without_memory, shared)) // enough tokens in common
+        {
+            int found = arr_find_index_of(current_context_tokens,shared);
+            if(found>=0 && found > trimstart)
+            {
+
+                //extract the unwanted tokens out from context and KV
+                int diff = found - trimstart;
+                llama_kv_cache_seq_rm(ctx, 0, trimstart, trimstart + diff);
+                llama_kv_cache_seq_add(ctx, 0, trimstart + diff, -1, -diff);
+
+                for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
+                {
+                    current_context_tokens[i - diff] = current_context_tokens[i];
+                }
+
+                LLAMA_LOG_INFO("\n[Context Shifting: Erased %d tokens at position %d]", diff, trimstart + 1);
+
+                current_context_tokens.resize(current_context_tokens.size() - diff);
+            }
+        }
+
+    }
+
+    // End Context Shifting
 };
 
 }
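
Note on the Context Shifting block: purge_missing_tokens scans for the first mismatch between the cached tokens and the new prompt (trimstart), finds the longest token run the two still share, and, if that run is long enough and starts the remainder of the new prompt, removes the mismatched stretch from both the token array and the KV cache instead of re-evaluating everything. The KV-cache side of that surgery comes down to two llama.cpp calls; a minimal sketch, with `trimstart` and `diff` assumed to be computed as in the function above:

```cpp
#include "llama.h"

// Sketch of the cache edit done once the purge region is known: drop `diff`
// cells of sequence 0 starting at `trimstart`, then shift the later cells'
// positions down by `diff` so decoding continues with contiguous positions.
static void erase_kv_range(llama_context * ctx, int trimstart, int diff) {
    llama_kv_cache_seq_rm (ctx, 0, trimstart, trimstart + diff);
    llama_kv_cache_seq_add(ctx, 0, trimstart + diff, -1, -diff);
}
```

The token copy loop in the hunk performs the matching shift on current_context_tokens so the host-side array stays in sync with the cache.
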