cui-llama.rn 1.0.4 → 1.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. package/README.md +36 -39
  2. package/android/src/main/CMakeLists.txt +11 -2
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +24 -8
  4. package/android/src/main/java/com/rnllama/RNLlama.java +33 -1
  5. package/android/src/main/jni.cpp +63 -9
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  8. package/cpp/common.cpp +3237 -3231
  9. package/cpp/common.h +469 -468
  10. package/cpp/ggml-aarch64.c +2193 -2193
  11. package/cpp/ggml-aarch64.h +39 -39
  12. package/cpp/ggml-alloc.c +1036 -1042
  13. package/cpp/ggml-backend-impl.h +153 -153
  14. package/cpp/ggml-backend.c +2240 -2234
  15. package/cpp/ggml-backend.h +238 -238
  16. package/cpp/ggml-common.h +1833 -1829
  17. package/cpp/ggml-impl.h +755 -655
  18. package/cpp/ggml-metal.h +65 -65
  19. package/cpp/ggml-metal.m +3269 -3269
  20. package/cpp/ggml-quants.c +14872 -14860
  21. package/cpp/ggml-quants.h +132 -132
  22. package/cpp/ggml.c +22099 -22044
  23. package/cpp/ggml.h +2453 -2447
  24. package/cpp/llama-grammar.cpp +539 -0
  25. package/cpp/llama-grammar.h +39 -0
  26. package/cpp/llama-impl.h +26 -0
  27. package/cpp/llama-sampling.cpp +635 -0
  28. package/cpp/llama-sampling.h +56 -0
  29. package/cpp/llama-vocab.cpp +1721 -0
  30. package/cpp/llama-vocab.h +130 -0
  31. package/cpp/llama.cpp +19173 -21892
  32. package/cpp/llama.h +1240 -1217
  33. package/cpp/log.h +737 -737
  34. package/cpp/rn-llama.hpp +209 -29
  35. package/cpp/sampling.cpp +460 -460
  36. package/cpp/sgemm.cpp +1027 -1027
  37. package/cpp/sgemm.h +14 -14
  38. package/cpp/unicode.cpp +6 -0
  39. package/cpp/unicode.h +3 -0
  40. package/ios/RNLlama.mm +15 -6
  41. package/ios/RNLlamaContext.h +2 -8
  42. package/ios/RNLlamaContext.mm +41 -34
  43. package/jest/mock.js +3 -0
  44. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  45. package/lib/commonjs/chat.js +37 -0
  46. package/lib/commonjs/chat.js.map +1 -0
  47. package/lib/commonjs/index.js +14 -1
  48. package/lib/commonjs/index.js.map +1 -1
  49. package/lib/module/NativeRNLlama.js.map +1 -1
  50. package/lib/module/chat.js +31 -0
  51. package/lib/module/chat.js.map +1 -0
  52. package/lib/module/index.js +14 -1
  53. package/lib/module/index.js.map +1 -1
  54. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  55. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  56. package/lib/typescript/chat.d.ts +10 -0
  57. package/lib/typescript/chat.d.ts.map +1 -0
  58. package/lib/typescript/index.d.ts +9 -2
  59. package/lib/typescript/index.d.ts.map +1 -1
  60. package/package.json +1 -1
  61. package/src/NativeRNLlama.ts +10 -1
  62. package/src/chat.ts +44 -0
  63. package/src/index.ts +31 -4
package/cpp/rn-llama.hpp CHANGED
@@ -6,6 +6,13 @@
  #include "common.h"
  #include "llama.h"

+
+ #include <android/log.h>
+ #define LLAMA_ANDROID_TAG "RNLLAMA_LOG_ANDROID"
+ #define LLAMA_LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO , LLAMA_ANDROID_TAG, __VA_ARGS__)
+
+
+
  namespace rnllama {

  static void llama_batch_clear(llama_batch *batch) {
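The hunk above wires rn-llama.hpp's logging straight to Android logcat. A minimal usage sketch, assuming an Android NDK build (the `log_progress` helper is hypothetical, not part of the package):

```cpp
// Hypothetical caller showing how the macro forwards printf-style arguments
// to __android_log_print; view the output with `adb logcat -s RNLLAMA_LOG_ANDROID`.
#include <android/log.h>

#define LLAMA_ANDROID_TAG "RNLLAMA_LOG_ANDROID"
#define LLAMA_LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO, LLAMA_ANDROID_TAG, __VA_ARGS__)

static void log_progress(int n_past, size_t n_prompt) {
    LLAMA_LOG_INFO("%s: n_past=%d, prompt tokens=%zu", __func__, n_past, n_prompt);
}
```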
@@ -139,6 +146,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
  return ret;
  }

+
  struct llama_rn_context
  {
  bool is_predicting = false;
@@ -167,7 +175,7 @@ struct llama_rn_context
  bool stopped_word = false;
  bool stopped_limit = false;
  std::string stopping_word;
- int32_t multibyte_pending = 0;
+ bool incomplete = false;

  ~llama_rn_context()
  {
@@ -202,7 +210,7 @@ struct llama_rn_context
  stopped_word = false;
  stopped_limit = false;
  stopping_word = "";
- multibyte_pending = 0;
+ incomplete = false;
  n_remain = 0;
  n_past = 0;
  params.sparams.n_prev = n_ctx;
@@ -229,6 +237,16 @@ struct llama_rn_context
  return true;
  }

+ bool validateModelChatTemplate() const {
+ llama_chat_message chat[] = {{"user", "test"}};
+
+ std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+ std::string template_key = "tokenizer.chat_template";
+ int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+
+ return res >= 0;
+ }
+
  void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
  const int n_left = n_ctx - params.n_keep;
  const int n_block_size = n_left / 2;
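The new `validateModelChatTemplate` only probes the GGUF metadata for a `tokenizer.chat_template` entry; `llama_model_meta_val_str` fills the buffer and returns a negative value when the key is absent, so `res >= 0` means the model ships its own template. A standalone sketch of the same probe, assuming a `llama_model *model` loaded elsewhere:

```cpp
// Sketch only: probe a loaded model for a built-in chat template.
#include <vector>
#include "llama.h"

static bool model_has_chat_template(const llama_model * model) {
    std::vector<char> buf(2048, 0); // generous buffer; known templates are well under 2 KB
    int32_t res = llama_model_meta_val_str(model, "tokenizer.chat_template",
                                           buf.data(), buf.size());
    return res >= 0; // negative return: key not present in the GGUF metadata
}
```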
@@ -278,15 +296,20 @@ struct llama_rn_context

  LM_GGML_ASSERT(num_prompt_tokens < (size_t) n_ctx);
  }
+
+ // do Context Shift , may be buggy! TODO: Verify functionality
+ purge_missing_tokens(ctx, embd, prompt_tokens, params.n_predict, params.n_ctx);
+
  // push the prompt into the sampling context (do not apply grammar)
  for (auto & token : prompt_tokens)
  {
  llama_sampling_accept(ctx_sampling, ctx, token, false);
  }
-
  // compare the evaluated prompt with the new prompt
  n_past = common_part(embd, prompt_tokens);
-
+ LLAMA_LOG_INFO("%s: n_past: %zu", __func__, n_past);
+ LLAMA_LOG_INFO("%s: embd size: %zu", __func__, embd.size());
+ LLAMA_LOG_INFO("%s: prompt_tokens size: %zu", __func__, prompt_tokens.size());
  embd = prompt_tokens;
  if (n_past == num_prompt_tokens)
  {
@@ -470,35 +493,28 @@ struct llama_rn_context
  generated_token_probs.push_back(token_with_probs);
  }

- if (multibyte_pending > 0)
- {
- multibyte_pending -= token_text.size();
- }
- else if (token_text.size() == 1)
- {
- const char c = token_text[0];
- // 2-byte characters: 110xxxxx 10xxxxxx
- if ((c & 0xE0) == 0xC0)
- {
- multibyte_pending = 1;
- // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
- }
- else if ((c & 0xF0) == 0xE0)
- {
- multibyte_pending = 2;
- // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+ // check if there is incomplete UTF-8 character at the end
+ for (unsigned i = 1; i < 5 && i <= generated_text.size(); ++i) {
+ unsigned char c = generated_text[generated_text.size() - i];
+ if ((c & 0xC0) == 0x80) {
+ // continuation byte: 10xxxxxx
+ continue;
  }
- else if ((c & 0xF8) == 0xF0)
- {
- multibyte_pending = 3;
- }
- else
- {
- multibyte_pending = 0;
+ if ((c & 0xE0) == 0xC0) {
+ // 2-byte character: 110xxxxx ...
+ incomplete = i < 2;
+ } else if ((c & 0xF0) == 0xE0) {
+ // 3-byte character: 1110xxxx ...
+ incomplete = i < 3;
+ } else if ((c & 0xF8) == 0xF0) {
+ // 4-byte character: 11110xxx ...
+ incomplete = i < 4;
  }
+ // else 1-byte character or invalid byte
+ break;
  }

- if (multibyte_pending > 0 && !has_next_token)
+ if (incomplete && !has_next_token)
  {
  has_next_token = true;
  n_remain++;
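The removed `multibyte_pending` counter tracked how many continuation bytes were still expected after a lead byte; the replacement instead scans backwards over the last few bytes of `generated_text` for a UTF-8 lead byte and sets `incomplete` when continuation bytes are still missing. A standalone sketch of that check (hypothetical helper, not part of the package API):

```cpp
// Sketch: does the string end in a truncated UTF-8 sequence?
#include <string>

static bool ends_with_incomplete_utf8(const std::string &s) {
    for (unsigned i = 1; i < 5 && i <= s.size(); ++i) {
        unsigned char c = s[s.size() - i];
        if ((c & 0xC0) == 0x80) continue;     // 10xxxxxx: continuation byte, keep walking back
        if ((c & 0xE0) == 0xC0) return i < 2; // 110xxxxx: lead of a 2-byte sequence
        if ((c & 0xF0) == 0xE0) return i < 3; // 1110xxxx: lead of a 3-byte sequence
        if ((c & 0xF8) == 0xF0) return i < 4; // 11110xxx: lead of a 4-byte sequence
        return false;                         // ASCII or invalid byte: nothing pending
    }
    return false; // no lead byte within 4 bytes; treat as complete
}
```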
@@ -638,6 +654,170 @@ struct llama_rn_context
  std::to_string(tg_std) +
  std::string("]");
  }
+
+
+ // Context Shifting from KoboldCpp <https://github.com/LostRuins/koboldcpp>
+ // Implementation obtained with special permission from @concedo
+
+ std::vector<int> longest_common_subseq(const std::vector<int> x, const std::vector<int> y){
+ int m = x.size(), n = y.size();
+
+ //int LCSuff[m+1][n+1];
+ std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
+
+ for (int j = 0; j <= n; j++)
+ LCSuff[0][j] = 0;
+ for (int i = 0; i <= m; i++)
+ LCSuff[i][0] = 0;
+
+ for (int i = 1; i <= m; i++)
+ {
+ for (int j = 1; j <= n; j++)
+ {
+ if (x[i - 1] == y[j - 1])
+ LCSuff[i][j] = LCSuff[i - 1][j - 1] + 1;
+ else
+ LCSuff[i][j] = 0;
+ }
+ }
+
+ std::vector<int> longest;
+ for (int i = 1; i <= m; i++)
+ {
+ for (int j = 1; j <= n; j++)
+ {
+ if (LCSuff[i][j] > longest.size())
+ {
+ auto off1 = ((i - LCSuff[i][j] + 1) - 1);
+ auto off2 = off1 + LCSuff[i][j];
+ longest.clear();
+ // std::vector<int>().swap(longest);
+ longest = std::vector<int>(x.begin() + off1, x.begin() + off2);
+ // x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]);
+ }
+ }
+ }
+ return longest;
+ }
+
+ bool arr_start_with(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+ {
+ int ss = searchSeq.size();
+ if(targetArray.size()<ss)
+ {
+ return false;
+ }
+ for(int i=0;i<ss;++i)
+ {
+ if(targetArray[i]!=searchSeq[i])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ int arr_find_index_of(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+ {
+ int ss = searchSeq.size();
+ int tas = targetArray.size();
+ if(tas<ss)
+ {
+ return -1;
+ }
+ for(int i=0;i<tas;++i)
+ {
+ int srch = 0;
+ bool fail = false;
+ for(int srch=0;srch<ss;++srch)
+ {
+ if ((i + srch) >= tas || targetArray[i + srch] != searchSeq[srch])
+ {
+ fail = true;
+ break;
+ }
+ }
+ if(!fail)
+ {
+ return i;
+ }
+ }
+ return -1;
+ }
+
+ void purge_missing_tokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
+ {
+ //scan from start old and new ctx, until first mismatch found, save as p0
+ //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
+ //test: longest common subseq (LCQ) MUST start within 0 tokens from end of memory, otherwise purge fails
+ //if passed, save beginning of LCQ from old ctx as p1
+ //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
+
+ const int short_fall_threshold = 200 + (nctx/30); //dont trigger shifting if the distance between trimstart and currhead < this
+ const int stack_allowance = 60 + (nctx/50); //in case the end text is slightly modified, be forgiving
+
+ int trimstart = 0;
+ int new_tokens_len = new_context_tokens.size();
+ bool purge_needed = true;
+
+ for (int i = 0; i < current_context_tokens.size(); ++i)
+ {
+ if (current_context_tokens[i] == new_context_tokens[i])
+ {
+ trimstart += 1;
+ }
+ else
+ {
+ break;
+ }
+ if ((i + 2) >= new_tokens_len)
+ {
+ purge_needed = false;
+ break; //no surgery required
+ }
+ }
+
+
+
+ if(!purge_needed || new_tokens_len < 6 || current_context_tokens.size() < 6 || new_tokens_len - trimstart < short_fall_threshold)
+ {
+ LLAMA_LOG_INFO("Fall Threshold: %d out of %d\n", new_tokens_len - trimstart, short_fall_threshold);
+ return; //no purge is needed
+ }
+
+ //at least this many tokens need to match, otherwise don't bother trimming
+ const int lc_tok_threshold = std::max(std::min((new_tokens_len - trimstart) - (genamt+stack_allowance), (int)(nctx*0.45)), short_fall_threshold - stack_allowance);
+
+ auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
+ auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
+
+ auto shared = longest_common_subseq(curr_ctx_without_memory, new_ctx_without_memory);
+
+ if (shared.size() > lc_tok_threshold && arr_start_with(new_ctx_without_memory, shared)) // enough tokens in common
+ {
+ int found = arr_find_index_of(current_context_tokens,shared);
+ if(found>=0 && found > trimstart)
+ {
+
+ //extract the unwanted tokens out from context and KV
+ int diff = found - trimstart;
+ llama_kv_cache_seq_rm(ctx, 0, trimstart, trimstart + diff);
+ llama_kv_cache_seq_add(ctx, 0, trimstart + diff, -1, -diff);
+
+ for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
+ {
+ current_context_tokens[i - diff] = current_context_tokens[i];
+ }
+
+ LLAMA_LOG_INFO("\n[Context Shifting: Erased %d tokens at position %d]", diff, trimstart + 1);
+
+ current_context_tokens.resize(current_context_tokens.size() - diff);
+ }
+ }
+
+ }
+
+ // End Context Shifting
  };

  }
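For orientation, `purge_missing_tokens` keeps the common prefix of the old and new token lists (`trimstart`), finds the longest run the two tails still share, and, if that run sits far enough ahead in the old context, erases the stale middle section from the KV cache (`llama_kv_cache_seq_rm`) and shifts the remainder left (`llama_kv_cache_seq_add`). A toy illustration of just the token-vector compaction, with hand-picked values far below the real thresholds:

```cpp
// Toy example (not the package code): drop the stale middle {9, 9, 9} and
// slide the shared tail {4, 5, 6} left, as the KV-cache shift does for real tokens.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> current = {1, 2, 3, 9, 9, 9, 4, 5, 6}; // old context
    // new prompt would be {1, 2, 3, 4, 5, 6, ...}
    int trimstart = 3;            // length of the common prefix {1, 2, 3}
    int found = 6;                // where the shared tail starts in `current`
    int diff = found - trimstart; // 3 stale tokens to purge

    for (size_t i = trimstart + diff; i < current.size(); ++i)
        current[i - diff] = current[i];
    current.resize(current.size() - diff);

    for (int t : current) std::printf("%d ", t); // prints: 1 2 3 4 5 6
    std::printf("\n");
    return 0;
}
```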