@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -2,8 +2,8 @@
2
2
 
3
3
  #include "llama-batch.h"
4
4
  #include "llama-graph.h"
5
- #include "llama-kv-cache.h"
6
5
  #include "llama-kv-cells.h"
6
+ #include "llama-memory.h"
7
7
 
8
8
  #include <unordered_map>
9
9
  #include <vector>
@@ -17,13 +17,26 @@ struct llama_context;
17
17
  // llama_kv_cache_unified
18
18
  //
19
19
 
20
- class llama_kv_cache_unified : public llama_kv_cache {
20
+ class llama_kv_cache_unified : public llama_memory_i {
21
21
  public:
22
22
  static uint32_t get_padding(const llama_cparams & cparams);
23
23
 
24
24
  // this callback is used to filter out layers that should not be included in the cache
25
25
  using layer_filter_cb = std::function<bool(int32_t il)>;
26
26
 
27
+ using ubatch_heads = std::vector<uint32_t>;
28
+
29
+ struct defrag_info {
30
+ bool empty() const {
31
+ return ids.empty();
32
+ }
33
+
34
+ // contains information about which cell moves where:
35
+ // - cell i moves to ids[i]
36
+ // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
37
+ std::vector<uint32_t> ids;
38
+ };
39
+
27
40
  llama_kv_cache_unified(
28
41
  const llama_model & model,
29
42
  layer_filter_cb && filter,
@@ -43,7 +56,18 @@ public:
43
56
  // llama_memory_i
44
57
  //
45
58
 
46
- void clear() override;
59
+ llama_memory_state_ptr init_batch(
60
+ const llama_batch & batch,
61
+ uint32_t n_ubatch,
62
+ bool embd_all) override;
63
+
64
+ llama_memory_state_ptr init_full() override;
65
+
66
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
67
+
68
+ bool get_can_shift() const override;
69
+
70
+ void clear(bool data) override;
47
71
 
48
72
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
49
73
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -54,24 +78,6 @@ public:
54
78
  llama_pos seq_pos_min(llama_seq_id seq_id) const override;
55
79
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
56
80
 
57
- //
58
- // llama_kv_cache
59
- //
60
-
61
- llama_memory_state_ptr init_batch(
62
- const llama_batch & batch,
63
- uint32_t n_ubatch,
64
- bool embd_pooled,
65
- bool logits_all) override;
66
-
67
- llama_memory_state_ptr init_full() override;
68
-
69
- bool update(llama_context & lctx) override;
70
-
71
- void defrag_sched(float thold) override;
72
-
73
- bool get_can_shift() const override;
74
-
75
81
  // state write/load
76
82
 
77
83
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -83,6 +89,8 @@ public:
83
89
 
84
90
  uint32_t get_size() const;
85
91
 
92
+ bool get_has_shift() const;
93
+
86
94
  //
87
95
  // graph_build API
88
96
  //
@@ -103,7 +111,9 @@ public:
103
111
 
104
112
  // find places for the provided ubatches in the cache, returns the head locations
105
113
  // return empty vector on failure
106
- std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
114
+ ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
115
+
116
+ bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
107
117
 
108
118
  // return the cell position where we can insert the ubatch
109
119
  // return -1 on failure to find a contiguous slot of kv cells
@@ -133,8 +143,7 @@ private:
133
143
  ggml_tensor * v;
134
144
  };
135
145
 
136
- bool do_defrag = false;
137
- bool v_trans = true; // the value tensor is transposed
146
+ bool v_trans = true; // the value tensor is transposed
138
147
 
139
148
  // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
140
149
  // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@@ -148,6 +157,8 @@ private:
148
157
  // SWA
149
158
  const uint32_t n_swa = 0;
150
159
 
160
+ int debug = 0;
161
+
151
162
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
152
163
 
153
164
  std::vector<ggml_context_ptr> ctxs;
@@ -160,13 +171,8 @@ private:
160
171
  // model layer id -> KV cache layer id
161
172
  std::unordered_map<int32_t, int32_t> map_layer_ids;
162
173
 
163
- // defrag
164
- struct {
165
- std::vector<uint32_t> ids;
166
- } defrag_info;
167
-
168
- // return true if cells have been moved
169
- bool defrag_prepare(int32_t n_max_nodes);
174
+ // return a non-empty defrag_info (see defrag_info::empty()) if cells have to be moved
175
+ defrag_info defrag_prepare(int32_t n_max_nodes) const;
170
176
 
171
177
  size_t total_size() const;
172
178
 
@@ -192,7 +198,8 @@ private:
192
198
  llm_graph_result_ptr build_graph_defrag(
193
199
  const llama_cparams & cparams,
194
200
  ggml_context * ctx,
195
- ggml_cgraph * gf) const;
201
+ ggml_cgraph * gf,
202
+ const defrag_info & dinfo) const;
196
203
 
197
204
  void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
198
205
  void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -203,20 +210,29 @@ private:
203
210
 
204
211
  class llama_kv_cache_unified_state : public llama_memory_state_i {
205
212
  public:
213
+ // some shorthands
214
+ using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
215
+ using defrag_info = llama_kv_cache_unified::defrag_info;
216
+
206
217
  // used for errors
207
218
  llama_kv_cache_unified_state(llama_memory_status status);
208
219
 
209
220
  // used to create a full-cache state
210
221
  llama_kv_cache_unified_state(
211
- llama_memory_status status,
212
222
  llama_kv_cache_unified * kv);
213
223
 
214
- // used to create a state from a batch
224
+ // used to create an update state
225
+ llama_kv_cache_unified_state(
226
+ llama_kv_cache_unified * kv,
227
+ llama_context * lctx,
228
+ bool do_shift,
229
+ defrag_info dinfo);
230
+
231
+ // used to create a decode state from a batch
215
232
  llama_kv_cache_unified_state(
216
- llama_memory_status status,
217
233
  llama_kv_cache_unified * kv,
218
234
  llama_sbatch sbatch,
219
- std::vector<uint32_t> heads,
235
+ ubatch_heads heads,
220
236
  std::vector<llama_ubatch> ubatches);
221
237
 
222
238
  virtual ~llama_kv_cache_unified_state();
@@ -253,16 +269,30 @@ public:
253
269
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
254
270
 
255
271
  private:
256
- const llama_memory_status status;
272
+ llama_memory_status status;
257
273
 
258
274
  llama_kv_cache_unified * kv;
275
+ llama_context * lctx;
276
+
277
+ //
278
+ // update state
279
+ //
280
+
281
+ bool do_shift = false;
282
+
283
+ defrag_info dinfo;
284
+
285
+ //
286
+ // batch processing state
287
+ //
259
288
 
260
289
  llama_sbatch sbatch;
261
290
 
262
291
  // the index of the next ubatch to process
263
292
  size_t i_next = 0;
264
293
 
265
- std::vector<uint32_t> heads;
294
+ ubatch_heads heads;
295
+
266
296
  std::vector<llama_ubatch> ubatches;
267
297
 
268
298
  //
@@ -23,7 +23,7 @@ public:
23
23
 
24
24
  used.clear();
25
25
 
26
- for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
26
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
27
27
  seq_pos[s].clear();
28
28
  }
29
29
  }
@@ -80,6 +80,9 @@ public:
80
80
  assert(isrc < pos.size());
81
81
  assert(idst < pos.size());
82
82
 
83
+ assert(pos[idst] == -1);
84
+ assert(pos[isrc] != -1);
85
+
83
86
  pos [idst] = pos [isrc];
84
87
  shift[idst] = shift[isrc];
85
88
  seq [idst] = seq [isrc];
@@ -144,9 +147,10 @@ public:
144
147
  assert(pos[i] != -1);
145
148
 
146
149
  seq_pos_rm(i);
150
+ seq[i].reset();
147
151
 
148
152
  pos[i] = -1;
149
- seq[i].reset();
153
+ shift[i] = 0;
150
154
 
151
155
  used.erase(i);
152
156
  }
@@ -164,6 +168,7 @@ public:
164
168
 
165
169
  if (seq[i].none()) {
166
170
  pos[i] = -1;
171
+ shift[i] = 0;
167
172
 
168
173
  used.erase(i);
169
174
 
@@ -192,6 +197,7 @@ public:
192
197
  seq[i].reset();
193
198
 
194
199
  pos[i] = -1;
200
+ shift[i] = 0;
195
201
 
196
202
  used.erase(i);
197
203
 
@@ -234,7 +240,7 @@ public:
234
240
  llama_seq_id seq_get(uint32_t i) const {
235
241
  assert(seq[i].count() == 1);
236
242
 
237
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
243
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
238
244
  if (seq[i].test(s)) {
239
245
  return s;
240
246
  }
@@ -247,7 +253,7 @@ public:
247
253
  // return -1 if the sequence is not present
248
254
  llama_pos seq_pos_min(llama_seq_id seq_id) const {
249
255
  assert(seq_id >= 0);
250
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
256
+ assert(seq_id < LLAMA_MAX_SEQ);
251
257
 
252
258
  if (seq_pos[seq_id].empty()) {
253
259
  return -1;
@@ -260,7 +266,7 @@ public:
260
266
  // return -1 if the sequence is not present
261
267
  llama_pos seq_pos_max(llama_seq_id seq_id) const {
262
268
  assert(seq_id >= 0);
263
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
269
+ assert(seq_id < LLAMA_MAX_SEQ);
264
270
 
265
271
  if (seq_pos[seq_id].empty()) {
266
272
  return -1;
@@ -317,21 +323,20 @@ public:
317
323
  pos[i] += d;
318
324
  shift[i] += d;
319
325
 
320
- seq_pos_add(i);
321
-
322
326
  has_shift = true;
323
327
 
324
328
  if (pos[i] < 0) {
325
- seq_pos_rm(i);
326
-
327
329
  seq[i].reset();
328
330
  pos[i] = -1;
331
+ shift[i] = 0;
329
332
 
330
333
  used.erase(i);
331
334
 
332
335
  return true;
333
336
  }
334
337
 
338
+ seq_pos_add(i);
339
+
335
340
  return false;
336
341
  }
337
342
 
@@ -379,20 +384,20 @@ private:
379
384
  //
380
385
  std::vector<llama_pos> shift;
381
386
 
382
- using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
387
+ using bits_t = std::bitset<LLAMA_MAX_SEQ>;
383
388
 
384
389
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
385
390
  std::vector<bits_t> seq;
386
391
 
387
392
  // the set seq_pos[s] tells us which positions are currently present for sequence s
388
393
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
389
- std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
394
+ std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
390
395
 
391
396
  // helper functions for updating `seq_pos`, one cell at a time:
392
397
 
393
398
  // remove cell i
394
399
  void seq_pos_rm(uint32_t i) {
395
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
400
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
396
401
  if (seq[i].test(s)) {
397
402
  seq_pos[s].erase(pos[i]);
398
403
  }
@@ -401,7 +406,7 @@ private:
401
406
 
402
407
  // add cell i
403
408
  void seq_pos_add(uint32_t i) {
404
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
409
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
405
410
  if (seq[i].test(s)) {
406
411
  seq_pos[s].insert(pos[i]);
407
412
  }
@@ -0,0 +1,247 @@
1
+ #include "llama-memory-hybrid.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-model.h"
5
+ #include "llama-context.h"
6
+
7
+ //
8
+ // llama_memory_hybrid
9
+ //
10
+
11
+ llama_memory_hybrid::llama_memory_hybrid(
12
+ const llama_model & model,
13
+ /* attn */
14
+ ggml_type type_k,
15
+ ggml_type type_v,
16
+ bool v_trans,
17
+ uint32_t kv_size,
18
+ uint32_t n_pad,
19
+ uint32_t n_swa,
20
+ llama_swa_type swa_type,
21
+ /* recurrent */
22
+ ggml_type type_r,
23
+ ggml_type type_s,
24
+ uint32_t rs_size,
25
+ /* common */
26
+ uint32_t n_seq_max,
27
+ bool offload,
28
+ /* layer filters */
29
+ layer_filter_cb && filter_attn,
30
+ layer_filter_cb && filter_recr) :
31
+ hparams(model.hparams),
32
+ mem_attn(new llama_kv_cache_unified(
33
+ model,
34
+ filter_attn == nullptr ?
35
+ [&](int32_t il) { return !model.hparams.is_recurrent(il); }
36
+ : filter_attn,
37
+ type_k,
38
+ type_v,
39
+ v_trans,
40
+ offload,
41
+ kv_size,
42
+ n_seq_max,
43
+ n_pad,
44
+ n_swa,
45
+ swa_type
46
+ )),
47
+ mem_recr(new llama_memory_recurrent(
48
+ model,
49
+ filter_recr == nullptr ?
50
+ [&](int32_t il) { return model.hparams.is_recurrent(il); }
51
+ : filter_recr,
52
+ type_r,
53
+ type_s,
54
+ offload,
55
+ rs_size,
56
+ n_seq_max
57
+ )) {}
58
+
59
+ llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
60
+
61
+ // since this includes a recurrent cache, we cannot use split_simple
62
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
63
+
64
+ // follow the recurrent pattern for creating the ubatch splits
65
+ std::vector<llama_ubatch> ubatches;
66
+ while (sbatch.n_tokens > 0) {
67
+ llama_ubatch ubatch;
68
+
69
+ if (embd_pooled) {
70
+ // Pooled embeddings cannot be split across ubatches (yet)
71
+ ubatch = sbatch.split_seq(n_ubatch);
72
+ } else {
73
+ ubatch = sbatch.split_equal(n_ubatch);
74
+ }
75
+
76
+ ubatches.push_back(ubatch);
77
+ }
78
+
79
+ // prepare the recurrent batches first
80
+ if (!mem_recr->prepare(ubatches)) {
81
+ // TODO: will the recurrent cache be in an undefined state at this point?
82
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
83
+ return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
84
+ }
85
+
86
+ // prepare the attention cache
87
+ auto heads_attn = mem_attn->prepare(ubatches);
88
+ if (heads_attn.empty()) {
89
+ LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
90
+ return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
91
+ }
92
+
93
+ return std::make_unique<llama_memory_hybrid_state>(
94
+ this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
95
+ }
96
+
97
+ llama_memory_state_ptr llama_memory_hybrid::init_full() {
98
+ return std::make_unique<llama_memory_hybrid_state>(this);
99
+ }
100
+
101
+ llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
102
+ return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
103
+ }
104
+
105
+ bool llama_memory_hybrid::get_can_shift() const {
106
+ // Shifting is trivially supported for recurrent
107
+ return mem_attn->get_can_shift();
108
+ }
109
+
110
+ void llama_memory_hybrid::clear(bool data) {
111
+ mem_attn->clear(data);
112
+ mem_recr->clear(data);
113
+ }
114
+
115
+ bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
116
+ // Try removing from the recurrent cache first since it may fail. If it does
117
+ // fail, the cache will not have been mutated.
118
+ if (!mem_recr->seq_rm(seq_id, p0, p1)) {
119
+ return false;
120
+ }
121
+ return mem_attn->seq_rm(seq_id, p0, p1);
122
+ }
123
+
124
+ void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
125
+ mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
126
+ mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
127
+ }
128
+
129
+ void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
130
+ mem_attn->seq_keep(seq_id);
131
+ mem_recr->seq_keep(seq_id);
132
+ }
133
+
134
+ void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
135
+ mem_attn->seq_add(seq_id, p0, p1, shift);
136
+ mem_recr->seq_add(seq_id, p0, p1, shift);
137
+ }
138
+
139
+ void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
140
+ mem_attn->seq_div(seq_id, p0, p1, d);
141
+ mem_recr->seq_div(seq_id, p0, p1, d);
142
+ }
143
+
144
+ llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
145
+ // the min of the total cache is the max of the two caches' min values
146
+ return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
147
+ }
148
+
149
+ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
150
+ // the max of the total cache is the min of the two caches' max values
151
+ return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
152
+ }
153
+
154
+ void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
155
+ mem_attn->state_write(io, seq_id);
156
+ mem_recr->state_write(io, seq_id);
157
+ }
158
+
159
+ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
160
+ mem_attn->state_read(io, seq_id);
161
+ mem_recr->state_read(io, seq_id);
162
+ }
163
+
164
+ llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
165
+ return mem_attn.get();
166
+ }
167
+
168
+ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
169
+ return mem_recr.get();
170
+ }
171
+
172
+ llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
173
+
174
+ llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
175
+ state_attn(mem->get_mem_attn()->init_full()),
176
+ state_recr(mem->get_mem_recr()->init_full()),
177
+ status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
178
+ }
179
+
180
+ llama_memory_hybrid_state::llama_memory_hybrid_state(
181
+ llama_memory_hybrid * mem,
182
+ llama_context * lctx,
183
+ bool optimize) :
184
+ state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
185
+ state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
186
+ status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
187
+ }
188
+
189
+ llama_memory_hybrid_state::llama_memory_hybrid_state(
190
+ llama_memory_hybrid * mem,
191
+ llama_sbatch sbatch,
192
+ std::vector<uint32_t> heads_attn,
193
+ std::vector<llama_ubatch> ubatches) :
194
+ sbatch(std::move(sbatch)),
195
+ ubatches(std::move(ubatches)),
196
+ // note: here we copy the ubatches. not sure if this is ideal
197
+ state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
198
+ state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)),
199
+ status(LLAMA_MEMORY_STATUS_SUCCESS) {
200
+ }
201
+
202
+ bool llama_memory_hybrid_state::next() {
203
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
204
+
205
+ state_attn->next();
206
+ state_recr->next();
207
+
208
+ if (++i_next >= ubatches.size()) {
209
+ return false;
210
+ }
211
+
212
+ return true;
213
+ }
214
+
215
+ bool llama_memory_hybrid_state::apply() {
216
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
217
+
218
+ bool res = true;
219
+
220
+ res = res & state_attn->apply();
221
+ res = res & state_recr->apply();
222
+
223
+ return res;
224
+ }
225
+
226
+ std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
227
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
228
+
229
+ return sbatch.out_ids;
230
+ }
231
+
232
+ llama_memory_status llama_memory_hybrid_state::get_status() const {
233
+ return status;
234
+ }
235
+
236
+ const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
237
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
238
+ return ubatches[i_next];
239
+ }
240
+
241
+ const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
242
+ return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
243
+ }
244
+
245
+ const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
246
+ return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
247
+ }