cui-llama.rn 1.4.4 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. package/android/src/main/CMakeLists.txt +9 -2
  2. package/android/src/main/jni.cpp +54 -34
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/binary-ops.cpp +158 -0
  12. package/cpp/binary-ops.h +16 -0
  13. package/cpp/chat.cpp +1769 -1085
  14. package/cpp/chat.h +143 -0
  15. package/cpp/common.cpp +1562 -1996
  16. package/cpp/common.h +677 -744
  17. package/cpp/cpu-common.h +72 -0
  18. package/cpp/ggml-alloc.c +1039 -1030
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +255 -255
  21. package/cpp/ggml-backend-reg.cpp +586 -582
  22. package/cpp/ggml-backend.cpp +2004 -2002
  23. package/cpp/ggml-backend.h +354 -354
  24. package/cpp/ggml-common.h +1857 -1851
  25. package/cpp/ggml-cpp.h +39 -39
  26. package/cpp/ggml-cpu-aarch64.cpp +5725 -4247
  27. package/cpp/ggml-cpu-aarch64.h +8 -8
  28. package/cpp/ggml-cpu-impl.h +512 -380
  29. package/cpp/ggml-cpu-quants.c +13026 -11517
  30. package/cpp/ggml-cpu-traits.cpp +36 -36
  31. package/cpp/ggml-cpu-traits.h +38 -38
  32. package/cpp/ggml-cpu.c +3438 -14485
  33. package/cpp/ggml-cpu.cpp +655 -633
  34. package/cpp/ggml-cpu.h +138 -135
  35. package/cpp/ggml-impl.h +594 -567
  36. package/cpp/ggml-metal-impl.h +312 -3
  37. package/cpp/ggml-metal.h +66 -66
  38. package/cpp/ggml-metal.m +5360 -5002
  39. package/cpp/ggml-opt.cpp +854 -854
  40. package/cpp/ggml-opt.h +216 -216
  41. package/cpp/ggml-quants.c +5238 -5238
  42. package/cpp/ggml-threading.h +14 -14
  43. package/cpp/ggml.c +6618 -6524
  44. package/cpp/ggml.h +2222 -2194
  45. package/cpp/gguf.cpp +1330 -1329
  46. package/cpp/gguf.h +202 -202
  47. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  48. package/cpp/json-schema-to-grammar.h +21 -22
  49. package/cpp/json.hpp +24766 -24766
  50. package/cpp/llama-adapter.cpp +382 -347
  51. package/cpp/llama-adapter.h +76 -74
  52. package/cpp/llama-arch.cpp +1714 -1492
  53. package/cpp/llama-arch.h +428 -402
  54. package/cpp/llama-batch.cpp +368 -368
  55. package/cpp/llama-batch.h +88 -88
  56. package/cpp/llama-chat.cpp +640 -587
  57. package/cpp/llama-chat.h +56 -53
  58. package/cpp/llama-context.cpp +2831 -1775
  59. package/cpp/llama-context.h +265 -128
  60. package/cpp/llama-cparams.cpp +1 -1
  61. package/cpp/llama-cparams.h +38 -37
  62. package/cpp/llama-cpp.h +30 -30
  63. package/cpp/llama-grammar.cpp +1219 -1219
  64. package/cpp/llama-grammar.h +173 -164
  65. package/cpp/llama-graph.cpp +1695 -0
  66. package/cpp/llama-graph.h +592 -0
  67. package/cpp/llama-hparams.cpp +79 -71
  68. package/cpp/llama-hparams.h +156 -139
  69. package/cpp/llama-impl.cpp +167 -167
  70. package/cpp/llama-impl.h +61 -61
  71. package/cpp/llama-io.cpp +15 -0
  72. package/cpp/llama-io.h +35 -0
  73. package/cpp/llama-kv-cache.cpp +1380 -718
  74. package/cpp/llama-kv-cache.h +213 -218
  75. package/cpp/llama-memory.cpp +1 -0
  76. package/cpp/llama-memory.h +21 -0
  77. package/cpp/llama-mmap.cpp +600 -590
  78. package/cpp/llama-mmap.h +68 -68
  79. package/cpp/llama-model-loader.cpp +1129 -1124
  80. package/cpp/llama-model-loader.h +169 -167
  81. package/cpp/llama-model.cpp +13080 -4023
  82. package/cpp/llama-model.h +409 -370
  83. package/cpp/llama-sampling.cpp +2563 -2525
  84. package/cpp/llama-sampling.h +32 -32
  85. package/cpp/llama-vocab.cpp +3295 -3252
  86. package/cpp/llama-vocab.h +125 -125
  87. package/cpp/llama.cpp +351 -10137
  88. package/cpp/llama.h +1434 -1340
  89. package/cpp/log.cpp +427 -423
  90. package/cpp/log.h +132 -132
  91. package/cpp/{chat-template.hpp → minja/chat-template.hpp} +537 -529
  92. package/cpp/{minja.hpp → minja/minja.hpp} +2941 -2883
  93. package/cpp/ops.cpp +8723 -0
  94. package/cpp/ops.h +128 -0
  95. package/cpp/rn-llama.cpp +45 -71
  96. package/cpp/rn-llama.h +3 -3
  97. package/cpp/sampling.cpp +573 -532
  98. package/cpp/sgemm.cpp +3043 -2598
  99. package/cpp/sgemm.h +14 -14
  100. package/cpp/simd-mappings.h +888 -0
  101. package/cpp/speculative.cpp +278 -277
  102. package/cpp/speculative.h +28 -28
  103. package/cpp/unary-ops.cpp +186 -0
  104. package/cpp/unary-ops.h +28 -0
  105. package/cpp/vec.cpp +258 -0
  106. package/cpp/vec.h +802 -0
  107. package/ios/CMakeLists.txt +5 -2
  108. package/ios/RNLlama.mm +2 -2
  109. package/ios/RNLlamaContext.mm +40 -24
  110. package/package.json +1 -1
  111. package/src/NativeRNLlama.ts +6 -4
  112. package/src/index.ts +3 -1
  113. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  114. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  115. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  116. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  117. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  118. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  119. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  120. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  122. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  124. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  125. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  126. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  127. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  128. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  129. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  130. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  131. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  132. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  133. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  134. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  135. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  136. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  194. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  195. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  196. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  197. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  198. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  199. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  200. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  201. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  202. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  203. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  204. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  205. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  206. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  207. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  208. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  209. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  210. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  211. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  212. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  213. package/android/src/main/build-arm64/Makefile +0 -1862
  214. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  215. package/cpp/chat.hpp +0 -55
  216. package/cpp/rn-llama.hpp +0 -913
package/cpp/rn-llama.hpp DELETED
@@ -1,913 +0,0 @@
- #ifndef RNLLAMA_H
- #define RNLLAMA_H
-
- #include <sstream>
- #include <iostream>
- #include "common.h"
- #include "ggml.h"
- #include "gguf.h"
- #include "llama.h"
- #include "llama-impl.h"
- #include "sampling.h"
- #if defined(__ANDROID__)
- #include <android/log.h>
- #endif
-
- namespace rnllama {
-
- const std::vector<lm_ggml_type> kv_cache_types = {
-     LM_GGML_TYPE_F32,
-     LM_GGML_TYPE_F16,
-     LM_GGML_TYPE_BF16,
-     LM_GGML_TYPE_Q8_0,
-     LM_GGML_TYPE_Q4_0,
-     LM_GGML_TYPE_Q4_1,
-     LM_GGML_TYPE_IQ4_NL,
-     LM_GGML_TYPE_Q5_0,
-     LM_GGML_TYPE_Q5_1,
- };
-
- static lm_ggml_type kv_cache_type_from_str(const std::string & s) {
-     for (const auto & type : kv_cache_types) {
-         if (lm_ggml_type_name(type) == s) {
-             return type;
-         }
-     }
-     throw std::runtime_error("Unsupported cache type: " + s);
- }
-
- static void llama_batch_clear(llama_batch *batch) {
-     batch->n_tokens = 0;
- }
-
- static void llama_batch_add(llama_batch *batch, llama_token id, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
-     batch->token [batch->n_tokens] = id;
-     batch->pos [batch->n_tokens] = pos;
-     batch->n_seq_id[batch->n_tokens] = seq_ids.size();
-     for (size_t i = 0; i < seq_ids.size(); i++) {
-         batch->seq_id[batch->n_tokens][i] = seq_ids[i];
-     }
-     batch->logits [batch->n_tokens] = logits ? 1 : 0;
-     batch->n_tokens += 1;
- }
-
-
- // NOTE: Edit from https://github.com/ggerganov/llama.cpp/blob/master/examples/server/server.cpp
-
- static void log(const char *level, const char *function, int line,
-                 const char *format, ...)
- {
-     va_list args;
- #if defined(__ANDROID__)
-     char prefix[256];
-     snprintf(prefix, sizeof(prefix), "%s:%d %s", function, line, format);
-
-     va_start(args, format);
-     android_LogPriority priority;
-     if (strcmp(level, "ERROR") == 0) {
-         priority = ANDROID_LOG_ERROR;
-     } else if (strcmp(level, "WARNING") == 0) {
-         priority = ANDROID_LOG_WARN;
-     } else if (strcmp(level, "INFO") == 0) {
-         priority = ANDROID_LOG_INFO;
-     } else {
-         priority = ANDROID_LOG_DEBUG;
-     }
-     __android_log_vprint(priority, "RNLlama", prefix, args);
-     va_end(args);
- #else
-     printf("[%s] %s:%d ", level, function, line);
-     va_start(args, format);
-     vprintf(format, args);
-     va_end(args);
-     printf("\n");
- #endif
- }
- static bool rnllama_verbose = false;
-
- #if RNLLAMA_VERBOSE != 1
- #define LOG_VERBOSE(MSG, ...)
- #else
- #define LOG_VERBOSE(MSG, ...)                                       \
-     do                                                              \
-     {                                                               \
-         if (rnllama_verbose)                                        \
-         {                                                           \
-             log("VERBOSE", __func__, __LINE__, MSG, ##__VA_ARGS__); \
-         }                                                           \
-     } while (0)
- #endif
-
- #define LOG_ERROR(MSG, ...) log("ERROR", __func__, __LINE__, MSG, ##__VA_ARGS__)
- #define LOG_WARNING(MSG, ...) log("WARNING", __func__, __LINE__, MSG, ##__VA_ARGS__)
- #define LOG_INFO(MSG, ...) log("INFO", __func__, __LINE__, MSG, ##__VA_ARGS__)
-
- enum stop_type
- {
-     STOP_FULL,
-     STOP_PARTIAL,
- };
-
- // completion token output with probabilities
- struct completion_token_output
- {
-     struct token_prob
-     {
-         llama_token tok;
-         float prob;
-     };
-
-     std::vector<token_prob> probs;
-     llama_token tok;
- };
-
- static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
- {
-     size_t i;
-     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++)
-     {
-     }
-     return i;
- }
-
- static bool ends_with(const std::string &str, const std::string &suffix)
- {
-     return str.size() >= suffix.size() &&
-            0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
- }
-
- static size_t find_partial_stop_string(const std::string &stop,
-                                        const std::string &text)
- {
-     if (!text.empty() && !stop.empty())
-     {
-         const char text_last_char = text.back();
-         for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--)
-         {
-             if (stop[char_index] == text_last_char)
-             {
-                 const std::string current_partial = stop.substr(0, char_index + 1);
-                 if (ends_with(text, current_partial))
-                 {
-                     return text.size() - char_index - 1;
-                 }
-             }
-         }
-     }
-     return std::string::npos;
- }
-
- // format incomplete utf-8 multibyte character for output
- static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
- {
-     std::string out = token == -1 ? "" : common_token_to_piece(ctx, token);
-     // if the size is 1 and first bit is 1, meaning it's a partial character
-     // (size > 1 meaning it's already a known token)
-     if (out.size() == 1 && (out[0] & 0x80) == 0x80)
-     {
-         std::stringstream ss;
-         ss << std::hex << (out[0] & 0xff);
-         std::string res(ss.str());
-         out = "byte: \\x" + res;
-     }
-     return out;
- }
-
- template <class Iter>
- static std::string tokens_to_str(llama_context* ctx, Iter begin, Iter end)
- {
-     std::string ret;
-     for (; begin != end; ++begin)
-     {
-         ret += common_token_to_piece(ctx, *begin);
-     }
-     return ret;
- }
-
- struct llama_rn_context
- {
-     bool is_predicting = false;
-     bool is_interrupted = false;
-     bool has_next_token = false;
-     std::string generated_text;
-     std::vector<completion_token_output> generated_token_probs;
-
-     size_t num_prompt_tokens = 0;
-     size_t num_tokens_predicted = 0;
-     size_t n_past = 0;
-     size_t n_remain = 0;
-
-     std::vector<llama_token> embd;
-
-     common_params params;
-
-     common_init_result llama_init;
-
-     llama_model *model = nullptr;
-     float loading_progress = 0;
-     bool is_load_interrupted = false;
-
-     llama_context *ctx = nullptr;
-     common_sampler *ctx_sampling = nullptr;
-
-     int n_ctx;
-
-     bool truncated = false;
-     bool stopped_eos = false;
-     bool stopped_word = false;
-     bool stopped_limit = false;
-     std::string stopping_word;
-     bool incomplete = false;
-
-     std::vector<common_adapter_lora_info> lora;
-
-     ~llama_rn_context()
-     {
-         if (ctx_sampling != nullptr)
-         {
-             common_sampler_free(ctx_sampling);
-         }
-     }
-
-     void rewind()
-     {
-         is_interrupted = false;
-         params.antiprompt.clear();
-         params.sampling.grammar.clear();
-         num_prompt_tokens = 0;
-         num_tokens_predicted = 0;
-         generated_text = "";
-         generated_text.reserve(params.n_ctx);
-         generated_token_probs.clear();
-         truncated = false;
-         stopped_eos = false;
-         stopped_word = false;
-         stopped_limit = false;
-         stopping_word = "";
-         incomplete = false;
-         n_remain = 0;
-         n_past = 0;
-         params.sampling.n_prev = n_ctx;
-     }
-
-     bool initSampling() {
-         if (ctx_sampling != nullptr) {
-             common_sampler_free(ctx_sampling);
-         }
-         ctx_sampling = common_sampler_init(model, params.sampling);
-         return ctx_sampling != nullptr;
-     }
-
-     bool loadModel(common_params &params_)
-     {
-         params = params_;
-         llama_init = common_init_from_params(params);
-         model = llama_init.model.get();
-         ctx = llama_init.context.get();
-         if (model == nullptr)
-         {
-             LOG_ERROR("unable to load model: %s", params_.model.c_str());
-             return false;
-         }
-         n_ctx = llama_n_ctx(ctx);
-
-         // We can uncomment for debugging or after this fix: https://github.com/ggerganov/llama.cpp/pull/11101
-         // LOG_INFO("%s\n", common_params_get_system_info(params).c_str());
-
-         return true;
-     }
-
-     bool validateModelChatTemplate() const {
-         llama_chat_message chat[] = {{"user", "test"}};
-         int32_t chat_res = llama_chat_apply_template(llama_model_chat_template(model), chat, 1, true, nullptr, 0);
-         return chat_res > 0;
-     }
-
-     void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
-         const int n_left = n_ctx - params.n_keep;
-         const int n_block_size = n_left / 2;
-         const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
-
-         // Keep n_keep tokens at start of prompt (at most n_ctx - 4)
-         std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-
-         new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
-
-         LOG_VERBOSE("input truncated, n_ctx: %d, n_keep: %d, n_left: %d, new_tokens: %s, num_prompt_tokens: %d",
-             n_ctx,
-             params.n_keep,
-             n_left,
-             tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()).c_str(),
-             new_tokens.size()
-         );
-
-         truncated = true;
-         prompt_tokens = new_tokens;
-     }
-
-     void loadPrompt()
-     {
-         std::vector<llama_token> prompt_tokens = ::common_tokenize(llama_model_get_vocab(model), params.prompt, true, true);
-         num_prompt_tokens = prompt_tokens.size();
-
-         // LOG tokens
-         std::stringstream ss;
-         ss << "\n" << __func__ << ": prompt_tokens = ";
-         for (auto& token : prompt_tokens) {
-             ss << token << " ";
-         }
-         LOG_INFO("%s\n", ss.str().c_str());
-
-         if (params.n_keep < 0)
-         {
-             params.n_keep = (int)num_prompt_tokens;
-         }
-         params.n_keep = std::min(n_ctx - 4, params.n_keep);
-
-         // if input prompt is too big, truncate like normal
-         if (num_prompt_tokens >= (size_t) n_ctx)
-         {
-             truncatePrompt(prompt_tokens);
-             num_prompt_tokens = prompt_tokens.size();
-
-             LM_GGML_ASSERT(num_prompt_tokens < (size_t) n_ctx);
-         }
-
-         // do Context Shift , may be buggy! TODO: Verify functionality
-         if(!params.embedding){
-             purge_missing_tokens(ctx, embd, prompt_tokens, params.n_predict, params.n_ctx);
-         }
-
-         // push the prompt into the sampling context (do not apply grammar)
-         for (auto & token : prompt_tokens)
-         {
-             common_sampler_accept(ctx_sampling, token, false);
-         }
-         // compare the evaluated prompt with the new prompt
-         n_past = params.embedding? 0 : common_part(embd, prompt_tokens);
-         LOG_INFO("%s: n_past: %zu", __func__, n_past);
-         LOG_INFO("%s: embd size: %zu", __func__, embd.size());
-         LOG_INFO("%s: prompt_tokens size: %zu", __func__, prompt_tokens.size());
-         embd = prompt_tokens;
-         if (n_past == num_prompt_tokens)
-         {
-             // we have to evaluate at least 1 token to generate logits.
-             n_past--;
-         }
-
-         // since #3228 we now have to manually manage the KV cache
-         llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
-
-         LOG_VERBOSE("prompt ingested, n_past: %d, cached: %s, to_eval: %s",
-             n_past,
-             tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past).c_str(),
-             tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()).c_str()
-         );
-
-         has_next_token = true;
-     }
-
-     void beginCompletion()
-     {
-         // number of tokens to keep when resetting context
-         n_remain = params.n_predict;
-         llama_perf_context_reset(ctx);
-         is_predicting = true;
-     }
-
-     completion_token_output nextToken()
-     {
-         completion_token_output result;
-         result.tok = -1;
-
-         // this truncation should never trigger with good context shifting
-         if (embd.size() >= (size_t)params.n_ctx)
-         {
-
-             const int n_left = n_past - params.n_keep - 1;
-             const int n_discard = n_left/2;
-
-             llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
-             llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
-
-             for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
-             {
-                 embd[i - n_discard] = embd[i];
-             }
-             embd.resize(embd.size() - n_discard);
-
-             n_past -= n_discard;
-
-             LOG_VERBOSE("input truncated, n_ctx: %d, n_keep: %d, n_left: %d, new_tokens: %s",
-                 params.n_ctx,
-                 params.n_keep,
-                 n_left
-             );
-         }
-
-         bool tg = true;
-         while (n_past < embd.size())
-         {
-             int n_eval = (int)embd.size() - n_past;
-             tg = n_eval == 1;
-             if (n_eval > params.n_batch)
-             {
-                 n_eval = params.n_batch;
-             }
-             if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval)))
-             {
-
-                 LOG_ERROR("failed to eval, n_eval: %d, n_past: %d, n_threads: %d, embd: %s",
-                     n_eval,
-                     n_past,
-                     params.cpuparams.n_threads,
-                     tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()).c_str()
-                 );
-                 has_next_token = false;
-                 return result;
-             }
-             n_past += n_eval;
-
-             if(is_interrupted) {
-                 LOG_INFO("Decoding Interrupted");
-                 embd.resize(n_past);
-                 has_next_token = false;
-                 return result;
-             }
-         }
-
-         if (params.n_predict == 0)
-         {
-             has_next_token = false;
-             result.tok = llama_vocab_eos(llama_model_get_vocab(model));
-             return result;
-         }
-
-         {
-             // out of user input, sample next token
-             std::vector<llama_token_data> candidates;
-             candidates.reserve(llama_vocab_n_tokens(llama_model_get_vocab(model)));
-
-             result.tok = common_sampler_sample(ctx_sampling, ctx, -1);
-
-             llama_token_data_array cur_p = *common_sampler_get_candidates(ctx_sampling);
-
-             const int32_t n_probs = params.sampling.n_probs;
-
-             // deprecated
-             /*if (params.sampling.temp <= 0 && n_probs > 0)
-             {
-                 // For llama_sample_token_greedy we need to sort candidates
-                 llama_sampler_init_softmax();
-
-             }*/
-
-
-             for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
-             {
-                 result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
-             }
-
-             common_sampler_accept(ctx_sampling, result.tok, true);
-             if (tg) {
-                 num_tokens_predicted++;
-             }
-         }
-
-         // add it to the context
-         embd.push_back(result.tok);
-         // decrement remaining sampling budget
-         --n_remain;
-
-         if (!embd.empty() && embd.back() == llama_vocab_eos(llama_model_get_vocab(model)))
-         {
-             // stopping_word = llama_token_to_piece(ctx, embd.back());
-             has_next_token = false;
-             stopped_eos = true;
-             LOG_VERBOSE("eos token found", "");
-             return result;
-         }
-
-         has_next_token = params.n_predict == -1 || n_remain != 0;
-         return result;
-     }
-
-     size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
-                                const stop_type type)
-     {
-         size_t stop_pos = std::string::npos;
-         for (const std::string &word : params.antiprompt)
-         {
-             size_t pos;
-             if (type == STOP_FULL)
-             {
-                 const size_t tmp = word.size() + last_token_size;
-                 const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
-                 pos = text.find(word, from_pos);
-             }
-             else
-             {
-                 pos = find_partial_stop_string(word, text);
-             }
-             if (pos != std::string::npos &&
-                 (stop_pos == std::string::npos || pos < stop_pos))
-             {
-                 if (type == STOP_FULL)
-                 {
-                     stopping_word = word;
-                     stopped_word = true;
-                     has_next_token = false;
-                 }
-                 stop_pos = pos;
-             }
-         }
-         return stop_pos;
-     }
-
-     completion_token_output doCompletion()
-     {
-         const completion_token_output token_with_probs = nextToken();
-
-         const std::string token_text = token_with_probs.tok == -1 ? "" : common_token_to_piece(ctx, token_with_probs.tok);
-         generated_text += token_text;
-
-         if (params.sampling.n_probs > 0)
-         {
-             generated_token_probs.push_back(token_with_probs);
-         }
-
-         // check if there is incomplete UTF-8 character at the end
-         for (unsigned i = 1; i < 5 && i <= generated_text.size(); ++i) {
-             unsigned char c = generated_text[generated_text.size() - i];
-             if ((c & 0xC0) == 0x80) {
-                 // continuation byte: 10xxxxxx
-                 continue;
-             }
-             if ((c & 0xE0) == 0xC0) {
-                 // 2-byte character: 110xxxxx ...
-                 incomplete = i < 2;
-             } else if ((c & 0xF0) == 0xE0) {
-                 // 3-byte character: 1110xxxx ...
-                 incomplete = i < 3;
-             } else if ((c & 0xF8) == 0xF0) {
-                 // 4-byte character: 11110xxx ...
-                 incomplete = i < 4;
-             }
-             // else 1-byte character or invalid byte
-             break;
-         }
-
-         if (incomplete && !has_next_token)
-         {
-             has_next_token = true;
-             n_remain++;
-         }
-
-         if (!has_next_token && n_remain == 0)
-         {
-             stopped_limit = true;
-         }
-
-         LOG_VERBOSE("next token, token: %s, token_text: %s, has_next_token: %d, n_remain: %d, num_tokens_predicted: %d, stopped_eos: %d, stopped_word: %d, stopped_limit: %d, stopping_word: %s",
-             common_token_to_piece(ctx, token_with_probs.tok),
-             tokens_to_output_formatted_string(ctx, token_with_probs.tok).c_str(),
-             has_next_token,
-             n_remain,
-             num_tokens_predicted,
-             stopped_eos,
-             stopped_word,
-             stopped_limit,
-             stopping_word.c_str()
-         );
-         return token_with_probs;
-     }
-
-     std::vector<float> getEmbedding(common_params &embd_params)
-     {
-         static const int n_embd = llama_model_n_embd(llama_get_model(ctx));
-         if (!embd_params.embedding)
-         {
-             LOG_WARNING("embedding disabled, embedding: %s", embd_params.embedding);
-             return std::vector<float>(n_embd, 0.0f);
-         }
-         float *data;
-
-         const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-         printf("pooling_type: %d\n", pooling_type);
-         if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-             data = llama_get_embeddings(ctx);
-         } else {
-             data = llama_get_embeddings_seq(ctx, 0);
-         }
-
-         if (!data) {
-             return std::vector<float>(n_embd, 0.0f);
-         }
-
-         std::vector<float> embedding(data, data + n_embd), out(data, data + n_embd);
-         common_embd_normalize(embedding.data(), out.data(), n_embd, params.embd_normalize);
-         return out;
-     }
-
-     std::string bench(int pp, int tg, int pl, int nr)
-     {
-         if (is_predicting) {
-             LOG_ERROR("cannot benchmark while predicting", "");
-             return std::string("[]");
-         }
-
-         is_predicting = true;
-
-         double pp_avg = 0;
-         double tg_avg = 0;
-
-         double pp_std = 0;
-         double tg_std = 0;
-
-         // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
-         llama_batch batch = llama_batch_init(
-             std::min(pp, params.n_ubatch), // max n_tokens is limited by n_ubatch
-             0, // No embeddings
-             1 // Single sequence
-         );
-
-         for (int i = 0; i < nr; i++)
-         {
-             llama_batch_clear(&batch);
-
-             const int n_tokens = pp;
-
-             for (int i = 0; i < n_tokens; i++)
-             {
-                 llama_batch_add(&batch, 0, i, {0}, false);
-             }
-             batch.logits[batch.n_tokens - 1] = 1; // true
-
-             llama_kv_cache_clear(ctx);
-
-             const int64_t t_pp_start = llama_time_us();
-             if (llama_decode(ctx, batch) != 0)
-             {
-                 LOG_ERROR("llama_decode() failed during prompt", "");
-             }
-             const int64_t t_pp_end = llama_time_us();
-             llama_kv_cache_clear(ctx);
-
-             if (is_interrupted) break;
-
-             const int64_t t_tg_start = llama_time_us();
-
-             for (int i = 0; i < tg; i++)
-             {
-                 llama_batch_clear(&batch);
-
-                 for (int j = 0; j < pl; j++)
-                 {
-                     llama_batch_add(&batch, 0, i, {j}, true);
-                 }
-
-                 if (llama_decode(ctx, batch) != 0)
-                 {
-                     LOG_ERROR("llama_decode() failed during text generation", "");
-                 }
-                 if (is_interrupted) break;
-             }
-
-             const int64_t t_tg_end = llama_time_us();
-
-             llama_kv_cache_clear(ctx);
-
-             const double t_pp = (t_pp_end - t_pp_start) / 1000000.0;
-             const double t_tg = (t_tg_end - t_tg_start) / 1000000.0;
-
-             const double speed_pp = pp / t_pp;
-             const double speed_tg = (pl * tg) / t_tg;
-
-             pp_avg += speed_pp;
-             tg_avg += speed_tg;
-
-             pp_std += speed_pp * speed_pp;
-             tg_std += speed_tg * speed_tg;
-         }
-
-         pp_avg /= nr;
-         tg_avg /= nr;
-
-         if (nr > 1) {
-             pp_std = sqrt(pp_std / (nr - 1) - pp_avg * pp_avg * nr / (nr - 1));
-             tg_std = sqrt(tg_std / (nr - 1) - tg_avg * tg_avg * nr / (nr - 1));
-         } else {
-             pp_std = 0;
-             tg_std = 0;
-         }
-
-         if (is_interrupted) llama_kv_cache_clear(ctx);
-         is_predicting = false;
-
-         char model_desc[128];
-         llama_model_desc(model, model_desc, sizeof(model_desc));
-         return std::string("[\"") + model_desc + std::string("\",") +
-             std::to_string(llama_model_size(model)) + std::string(",") +
-             std::to_string(llama_model_n_params(model)) + std::string(",") +
-             std::to_string(pp_avg) + std::string(",") +
-             std::to_string(pp_std) + std::string(",") +
-             std::to_string(tg_avg) + std::string(",") +
-             std::to_string(tg_std) +
-             std::string("]");
-     }
-
-     int applyLoraAdapters(std::vector<common_adapter_lora_info> lora) {
-         for (auto &la : lora) {
-             la.ptr = llama_adapter_lora_init(model, la.path.c_str());
-             if (la.ptr == nullptr) {
-                 LOG_ERROR("failed to apply lora adapter '%s'\n", la.path.c_str());
-                 return -1;
-             }
-         }
-         this->lora = lora;
-         for (auto &la : lora) {
-             llama_set_adapter_lora(ctx, la.ptr, 1);
-         }
-
-         return 0;
-     }
-
-     void removeLoraAdapters() {
-         for (auto &la : this->lora) {
-             llama_adapter_lora_free(la.ptr);
-         }
-         this->lora.clear();
-         llama_clear_adapter_lora(ctx);
-     }
-
-     std::vector<common_adapter_lora_info> getLoadedLoraAdapters() {
-         return this->lora;
-     }
-     // Context Shifting from KoboldCpp <https://github.com/LostRuins/koboldcpp>
-     // Implementation obtained with special permission from @concedo
-
-     std::vector<int> longest_common_subseq(const std::vector<int> x, const std::vector<int> y){
-         int m = x.size(), n = y.size();
-
-         //int LCSuff[m+1][n+1];
-         std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
-
-         for (int j = 0; j <= n; j++)
-             LCSuff[0][j] = 0;
-         for (int i = 0; i <= m; i++)
-             LCSuff[i][0] = 0;
-
-         for (int i = 1; i <= m; i++)
-         {
-             for (int j = 1; j <= n; j++)
-             {
-                 if (x[i - 1] == y[j - 1])
-                     LCSuff[i][j] = LCSuff[i - 1][j - 1] + 1;
-                 else
-                     LCSuff[i][j] = 0;
-             }
-         }
-
-         std::vector<int> longest;
-         for (int i = 1; i <= m; i++)
-         {
-             for (int j = 1; j <= n; j++)
-             {
-                 if (LCSuff[i][j] > longest.size())
-                 {
-                     auto off1 = ((i - LCSuff[i][j] + 1) - 1);
-                     auto off2 = off1 + LCSuff[i][j];
-                     longest.clear();
-                     // std::vector<int>().swap(longest);
-                     longest = std::vector<int>(x.begin() + off1, x.begin() + off2);
-                     // x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]);
-                 }
-             }
-         }
-         return longest;
-     }
-
-     bool arr_start_with(const std::vector<int> targetArray, const std::vector<int> searchSeq)
-     {
-         int ss = searchSeq.size();
-         if(targetArray.size()<ss)
-         {
-             return false;
-         }
-         for(int i=0;i<ss;++i)
-         {
-             if(targetArray[i]!=searchSeq[i])
-             {
-                 return false;
-             }
-         }
-         return true;
-     }
-
-     int arr_find_index_of(const std::vector<int> targetArray, const std::vector<int> searchSeq)
-     {
-         int ss = searchSeq.size();
-         int tas = targetArray.size();
-         if(tas<ss)
-         {
-             return -1;
-         }
-         for(int i=0;i<tas;++i)
-         {
-             int srch = 0;
-             bool fail = false;
-             for(int srch=0;srch<ss;++srch)
-             {
-                 if ((i + srch) >= tas || targetArray[i + srch] != searchSeq[srch])
-                 {
-                     fail = true;
-                     break;
-                 }
-             }
-             if(!fail)
-             {
-                 return i;
-             }
-         }
-         return -1;
-     }
-
-     void purge_missing_tokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
-     {
-         //scan from start old and new ctx, until first mismatch found, save as p0
-         //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
-         //test: longest common subseq (LCQ) MUST start within 0 tokens from end of memory, otherwise purge fails
-         //if passed, save beginning of LCQ from old ctx as p1
-         //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
-
-         const int short_fall_threshold = 200 + (nctx/30); //dont trigger shifting if the distance between trimstart and currhead < this
-         const int stack_allowance = 60 + (nctx/50); //in case the end text is slightly modified, be forgiving
-
-         int trimstart = 0;
-         int new_tokens_len = new_context_tokens.size();
-         bool purge_needed = true;
-
-         for (int i = 0; i < current_context_tokens.size(); ++i)
-         {
-             if (current_context_tokens[i] == new_context_tokens[i])
-             {
-                 trimstart += 1;
-             }
-             else
-             {
-                 break;
-             }
-             if ((i + 2) >= new_tokens_len)
-             {
-                 purge_needed = false;
-                 break; //no surgery required
-             }
-         }
-
-
-
-         if(!purge_needed || new_tokens_len < 6 || current_context_tokens.size() < 6 || new_tokens_len - trimstart < short_fall_threshold)
-         {
-             LOG_INFO("Fall Threshold: %d out of %d\n", new_tokens_len - trimstart, short_fall_threshold);
-             return; //no purge is needed
-         }
-
-         //at least this many tokens need to match, otherwise don't bother trimming
-         const int lc_tok_threshold = std::max(std::min((new_tokens_len - trimstart) - (genamt+stack_allowance), (int)(nctx*0.45)), short_fall_threshold - stack_allowance);
-
-         auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
-         auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
-
-         auto shared = longest_common_subseq(curr_ctx_without_memory, new_ctx_without_memory);
-
-         if (shared.size() > lc_tok_threshold && arr_start_with(new_ctx_without_memory, shared)) // enough tokens in common
-         {
-             int found = arr_find_index_of(current_context_tokens,shared);
-             if(found>=0 && found > trimstart)
-             {
-
-                 //extract the unwanted tokens out from context and KV
-                 int diff = found - trimstart;
-                 llama_kv_cache_seq_rm(ctx, 0, trimstart, trimstart + diff);
-                 llama_kv_cache_seq_add(ctx, 0, trimstart + diff, -1, -diff);
-
-                 for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
-                 {
-                     current_context_tokens[i - diff] = current_context_tokens[i];
-                 }
-
-                 LOG_INFO("\n[Context Shifting: Erased %d tokens at position %d]", diff, trimstart + 1);
-
-                 current_context_tokens.resize(current_context_tokens.size() - diff);
-             }
-         }
-
-     }
-
-     // End Context Shifting
-
- };
-
- }
-
- #endif /* LLAMA_H */
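
Usage note (added for context): the deleted rn-llama.hpp above declared the llama_rn_context driver that the native bridges call into; in 1.5.0 the corresponding declarations appear to move into the rn-llama.h / rn-llama.cpp entries in the file list. The sketch below is a minimal, hypothetical driver based only on the members declared in the deleted header; the common_params setup (model path, prompt, n_predict) is illustrative and not taken from the package.

```cpp
// Hypothetical example, not part of the package sources: it only calls
// members declared in the deleted rn-llama.hpp shown above.
#include <cstdio>
#include "rn-llama.hpp"

int main() {
    common_params params;                      // from llama.cpp's common.h
    params.model     = "/path/to/model.gguf";  // placeholder path (assumption)
    params.prompt    = "Hello";
    params.n_predict = 64;

    rnllama::llama_rn_context rn;
    if (!rn.loadModel(params)) return 1;       // common_init_from_params + n_ctx
    if (!rn.initSampling())    return 1;       // builds the common_sampler

    rn.rewind();                               // reset per-request state
    rn.loadPrompt();                           // tokenize, reuse cached prefix, trim KV cache
    rn.beginCompletion();                      // n_remain = n_predict

    while (rn.has_next_token && !rn.is_interrupted) {
        rn.doCompletion();                     // nextToken() + stop-word/UTF-8 bookkeeping
    }
    std::printf("%s\n", rn.generated_text.c_str());
    return 0;
}
```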