@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -1,6 +1,7 @@
1
1
  #include "llama-kv-cache-unified.h"
2
2
 
3
3
  #include "llama-impl.h"
4
+ #include "llama-io.h"
4
5
  #include "llama-model.h"
5
6
  #include "llama-context.h"
6
7
 
@@ -67,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
67
68
  continue;
68
69
  }
69
70
 
70
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
71
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
71
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
72
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
72
73
 
73
74
  const char * dev_name = "CPU";
74
75
 
@@ -126,15 +127,20 @@ llama_kv_cache_unified::llama_kv_cache_unified(
126
127
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
127
128
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
128
129
  }
130
+
131
+ const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
132
+ debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
129
133
  }
130
134
 
131
- void llama_kv_cache_unified::clear() {
135
+ void llama_kv_cache_unified::clear(bool data) {
132
136
  cells.reset();
133
137
 
134
138
  head = 0;
135
139
 
136
- for (auto & buf : bufs) {
137
- ggml_backend_buffer_clear(buf.get(), 0);
140
+ if (data) {
141
+ for (auto & buf : bufs) {
142
+ ggml_backend_buffer_clear(buf.get(), 0);
143
+ }
138
144
  }
139
145
  }
140
146
 
@@ -149,12 +155,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
149
155
  p1 = std::numeric_limits<llama_pos>::max();
150
156
  }
151
157
 
152
- for (uint32_t i = 0; i < cells.size(); ++i) {
153
- if (!cells.pos_in(i, p0, p1)) {
154
- continue;
158
+ if (seq_id >= 0) {
159
+ for (uint32_t i = 0; i < cells.size(); ++i) {
160
+ if (!cells.pos_in(i, p0, p1)) {
161
+ continue;
162
+ }
163
+
164
+ if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
165
+ if (new_head == cells.size()) {
166
+ new_head = i;
167
+ }
168
+ }
155
169
  }
170
+ } else {
171
+ // match any sequence
172
+ for (uint32_t i = 0; i < cells.size(); ++i) {
173
+ if (!cells.pos_in(i, p0, p1)) {
174
+ continue;
175
+ }
176
+
177
+ cells.rm(i);
156
178
 
157
- if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
158
179
  if (new_head == cells.size()) {
159
180
  new_head = i;
160
181
  }
@@ -289,32 +310,68 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
289
310
  llama_memory_state_ptr llama_kv_cache_unified::init_batch(
290
311
  const llama_batch & batch,
291
312
  uint32_t n_ubatch,
292
- bool embd_pooled,
293
- bool logits_all) {
294
- GGML_UNUSED(embd_pooled);
313
+ bool embd_all) {
314
+ GGML_UNUSED(embd_all);
295
315
 
296
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
316
+ do {
317
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
297
318
 
298
- std::vector<llama_ubatch> ubatches;
299
- while (sbatch.n_tokens > 0) {
300
- ubatches.push_back(sbatch.split_simple(n_ubatch));
301
- }
319
+ std::vector<llama_ubatch> ubatches;
320
+ while (sbatch.n_tokens > 0) {
321
+ ubatches.push_back(sbatch.split_simple(n_ubatch));
322
+ }
302
323
 
303
- auto heads = prepare(ubatches);
304
- if (heads.empty()) {
305
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
306
- }
324
+ auto heads = prepare(ubatches);
325
+ if (heads.empty()) {
326
+ break;
327
+ }
307
328
 
308
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
309
- this, std::move(sbatch), std::move(heads), std::move(ubatches));
329
+ return std::make_unique<llama_kv_cache_unified_state>(
330
+ this, std::move(sbatch), std::move(heads), std::move(ubatches));
331
+ } while (false);
332
+
333
+ return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
310
334
  }
311
335
 
312
336
  llama_memory_state_ptr llama_kv_cache_unified::init_full() {
313
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
337
+ return std::make_unique<llama_kv_cache_unified_state>(this);
338
+ }
339
+
340
+ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
341
+ bool do_shift = get_has_shift();
342
+
343
+ defrag_info dinfo;
344
+
345
+ // see if we need to defrag
346
+ {
347
+ bool do_defrag = optimize;
348
+
349
+ const auto thold = lctx->get_cparams().defrag_thold;
350
+
351
+ if (!do_defrag && thold > 0.0f) {
352
+ const auto n_kv = cells.used_max_p1();
353
+
354
+ // - do not defrag small contexts (i.e. < 2048 tokens)
355
+ // - count the padding towards the number of used tokens
356
+ const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
357
+
358
+ if (fragmentation > thold) {
359
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
360
+
361
+ do_defrag = true;
362
+ }
363
+ }
364
+
365
+ if (do_defrag) {
366
+ dinfo = defrag_prepare(lctx->graph_max_nodes());
367
+ }
368
+ }
369
+
370
+ return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
314
371
  }
315
372
 
316
- std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
317
- std::vector<uint32_t> res;
373
+ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
374
+ llama_kv_cache_unified::ubatch_heads res;
318
375
 
319
376
  struct state {
320
377
  uint32_t head_old; // old position of the head, before placing the ubatch
@@ -359,12 +416,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
359
416
  return res;
360
417
  }
361
418
 
362
- bool llama_kv_cache_unified::update(llama_context & lctx) {
419
+ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
363
420
  bool updated = false;
364
421
 
365
- auto * sched = lctx.get_sched();
422
+ auto * sched = lctx->get_sched();
366
423
 
367
- if (cells.get_has_shift()) {
424
+ if (do_shift) {
368
425
  if (!get_can_shift()) {
369
426
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
370
427
  }
@@ -375,9 +432,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
375
432
  if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
376
433
  ggml_backend_sched_reset(sched);
377
434
 
378
- auto * gf = lctx.graph_init();
435
+ auto * gf = lctx->graph_init();
379
436
 
380
- auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
437
+ auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
381
438
  if (!res) {
382
439
  LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
383
440
  return updated;
@@ -390,7 +447,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
390
447
 
391
448
  res->set_inputs(nullptr);
392
449
 
393
- if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
450
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
394
451
  LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
395
452
  return updated;
396
453
  }
@@ -401,54 +458,53 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
401
458
  cells.reset_shift();
402
459
  }
403
460
 
404
- if (do_defrag) {
461
+ if (!dinfo.empty()) {
405
462
  LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
406
463
 
407
- if (defrag_prepare(lctx.graph_max_nodes())) {
408
- ggml_backend_sched_reset(sched);
464
+ // apply moves:
465
+ {
466
+ const auto n_kv = dinfo.ids.size();
409
467
 
410
- auto * gf = lctx.graph_init();
468
+ for (uint32_t i = 0; i < n_kv; ++i) {
469
+ assert(dinfo.ids[i] <= n_kv);
411
470
 
412
- auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
413
- if (!res) {
414
- LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
415
- return updated;
416
- }
417
-
418
- if (!ggml_backend_sched_alloc_graph(sched, gf)) {
419
- LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
420
- return updated;
421
- }
422
-
423
- res->set_inputs(nullptr);
471
+ if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
472
+ continue;
473
+ }
424
474
 
425
- if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
426
- LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
427
- return updated;
475
+ cells.mv(i, dinfo.ids[i]);
428
476
  }
429
477
 
430
- updated = true;
478
+ // reset the head so we can find the first free slot during the next ubatch
479
+ head = 0;
431
480
  }
432
481
 
433
- do_defrag = false;
434
- }
482
+ ggml_backend_sched_reset(sched);
435
483
 
436
- return updated;
437
- }
484
+ auto * gf = lctx->graph_init();
438
485
 
439
- void llama_kv_cache_unified::defrag_sched(float thold) {
440
- const auto n_kv = cells.used_max_p1();
486
+ auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
487
+ if (!res) {
488
+ LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
489
+ return updated;
490
+ }
491
+
492
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
493
+ LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
494
+ return updated;
495
+ }
441
496
 
442
- // - do not defrag small contexts (i.e. < 2048 tokens)
443
- // - count the padding towards the number of used tokens
444
- const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
497
+ res->set_inputs(nullptr);
445
498
 
446
- // queue defragmentation for next llama_kv_cache_update
447
- if (fragmentation > thold) {
448
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
499
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
500
+ LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
501
+ return updated;
502
+ }
449
503
 
450
- do_defrag = true;
504
+ updated = true;
451
505
  }
506
+
507
+ return updated;
452
508
  }
453
509
 
454
510
  int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
@@ -462,43 +518,68 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
462
518
  head_cur = 0;
463
519
  }
464
520
 
465
- // otherwise, one cell per token.
466
-
467
521
  if (n_tokens > cells.size()) {
468
522
  LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
469
523
  return -1;
470
524
  }
471
525
 
472
- //#define FIND_SLOT_DEBUG 1
473
- #if FIND_SLOT_DEBUG
474
- LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
526
+ if (debug > 0) {
527
+ LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
475
528
 
476
- // for debugging
477
- {
478
- std::string ss;
479
- if (n_swa > 0) {
529
+ if ((debug == 2 && n_swa > 0) || debug > 2) {
530
+ std::string ss;
480
531
  for (uint32_t i = 0; i < cells.size(); ++i) {
481
532
  if (cells.is_empty(i)) {
482
533
  ss += '.';
483
534
  } else {
484
- ss += std::to_string(cells.seq_get(i));
535
+ assert(cells.seq_count(i) >= 1);
536
+
537
+ if (cells.seq_count(i) == 1) {
538
+ ss += std::to_string(cells.seq_get(i));
539
+ } else {
540
+ ss += 'M';
541
+ }
485
542
  }
486
543
  if (i%256 == 255) {
544
+ ss += " *";
487
545
  ss += '\n';
488
546
  }
489
547
  }
548
+ LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
490
549
  }
491
- LLAMA_LOG_WARN("\n%s\n", ss.c_str());
492
- }
493
550
 
494
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
495
- if (cells.seq_pos_min(s) < 0) {
496
- continue;
551
+ if ((debug == 2 && n_swa > 0) || debug > 2) {
552
+ std::string ss;
553
+ for (uint32_t i = 0; i < cells.size(); ++i) {
554
+ std::string cur;
555
+ if (cells.is_empty(i)) {
556
+ cur = '.';
557
+ } else {
558
+ cur = std::to_string(cells.pos_get(i));
559
+ }
560
+ const int n = cur.size();
561
+ for (int j = 0; j < 5 - n; ++j) {
562
+ cur += ' ';
563
+ }
564
+ ss += cur;
565
+ if (i%256 == 255) {
566
+ ss += " *";
567
+ }
568
+ if (i%64 == 63) {
569
+ ss += '\n';
570
+ }
571
+ }
572
+ LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
497
573
  }
498
574
 
499
- LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
575
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
576
+ if (cells.seq_pos_min(s) < 0) {
577
+ continue;
578
+ }
579
+
580
+ LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
581
+ }
500
582
  }
501
- #endif
502
583
 
503
584
  uint32_t n_tested = 0;
504
585
 
@@ -509,21 +590,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
509
590
  continue;
510
591
  }
511
592
 
512
- // keep track of what the minimum sequence positions would be if we accept the ubatch
513
- llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
514
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
515
- seq_pos_min[s] = cells.seq_pos_min(s);
516
- }
517
-
518
593
  bool found = true;
519
594
  for (uint32_t i = 0; i < n_tokens; i++) {
520
- const llama_pos pos = ubatch.pos[i];
521
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
595
+ //const llama_pos pos = ubatch.pos[i];
596
+ //const llama_seq_id seq_id = ubatch.seq_id[i][0];
522
597
 
523
598
  // can we use this cell? either:
524
599
  // - the cell is empty
525
600
  // - the cell is occupied only by one sequence:
526
- // - mask causally, if the sequence is the same as the one we are inserting
601
+ // - (disabled) mask causally, if the sequence is the same as the one we are inserting
527
602
  // - mask SWA, using current max pos for that sequence in the cache
528
603
  // always insert in the cell with minimum pos
529
604
  bool can_use = cells.is_empty(head_cur + i);
@@ -531,21 +606,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
531
606
  if (!can_use && cells.seq_count(head_cur + i) == 1) {
532
607
  const llama_pos pos_cell = cells.pos_get(head_cur + i);
533
608
 
534
- // causal mask
535
- if (cells.seq_has(head_cur + i, seq_id)) {
536
- can_use = pos_cell >= pos;
537
- }
609
+ // (disabled) causal mask
610
+ // note: it's better to purge any "future" tokens beforehand
611
+ //if (cells.seq_has(head_cur + i, seq_id)) {
612
+ // can_use = pos_cell >= pos;
613
+ //}
538
614
 
539
615
  if (!can_use) {
540
616
  const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
541
617
 
542
618
  // SWA mask
543
- // note: we insert only in the cell with minimum pos in order to preserve the invariant that
544
- // all positions between [pos_min, pos_max] for each sequence will be present in the cache
545
- // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
546
- if (pos_cell == seq_pos_min[seq_id_cell] &&
547
- is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
548
- seq_pos_min[seq_id_cell]++;
619
+ if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
549
620
  can_use = true;
550
621
  }
551
622
  }
@@ -573,18 +644,58 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
573
644
  }
574
645
 
575
646
  void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
576
- for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
577
- if (!cells.is_empty(head_cur + i)) {
578
- cells.rm(head_cur + i);
579
- }
647
+ if (debug > 0) {
648
+ LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
649
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
650
+ LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
651
+ }
652
+
653
+ // keep track of the max sequence position that we would overwrite with this ubatch
654
+ // for non-SWA cache, this would be always empty
655
+ llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
656
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
657
+ seq_pos_max_rm[s] = -1;
658
+ }
659
+
660
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
661
+ for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
662
+ const uint32_t idx = s*ubatch.n_seq_tokens + j;
663
+
664
+ if (!cells.is_empty(head_cur + idx)) {
665
+ assert(cells.seq_count(head_cur + idx) == 1);
580
666
 
581
- cells.pos_set(head_cur + i, ubatch.pos[i]);
667
+ const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
668
+ const llama_pos pos = cells.pos_get(head_cur + idx);
582
669
 
583
- for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
584
- cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
670
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
671
+
672
+ cells.rm(head_cur + idx);
673
+ }
674
+
675
+ cells.pos_set(head_cur + idx, ubatch.pos[idx]);
676
+
677
+ // TODO: fix indexing [UBATCH_IDX]
678
+ for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
679
+ cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
680
+ }
585
681
  }
586
682
  }
587
683
 
684
+ // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
685
+ // will be present in the cache. so we have to purge any position which is less than those we would overwrite
686
+ // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
687
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
688
+ if (seq_pos_max_rm[s] == -1) {
689
+ continue;
690
+ }
691
+
692
+ if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
693
+ LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
694
+ __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
695
+
696
+ seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
697
+ }
698
+ }
588
699
  // move the head at the end of the slot
589
700
  head = head_cur + ubatch.n_tokens;
590
701
  }
@@ -597,6 +708,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
597
708
  return cells.size();
598
709
  }
599
710
 
711
+ bool llama_kv_cache_unified::get_has_shift() const {
712
+ return cells.get_has_shift();
713
+ }
714
+
600
715
  uint32_t llama_kv_cache_unified::get_n_kv() const {
601
716
  return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
602
717
  }
@@ -677,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
677
792
  }
678
793
 
679
794
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
680
- const int64_t n_tokens = ubatch->n_tokens;
681
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
682
- const int64_t n_seqs = ubatch->n_seqs;
795
+ const uint32_t n_tokens = ubatch->n_tokens;
796
+ const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
797
+ const uint32_t n_seqs = ubatch->n_seqs;
683
798
 
684
799
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
685
800
  float * data = (float *) dst->data;
686
801
 
687
- const auto n_kv = dst->ne[0];
802
+ const int64_t n_kv = dst->ne[0];
688
803
 
689
804
  // Use only the previous KV cells of the correct sequence for each token of the ubatch.
690
805
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -698,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
698
813
  // xxxxx-----
699
814
  // xxxxx-----
700
815
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
701
- for (int h = 0; h < 1; ++h) {
702
- for (int s = 0; s < n_seqs; ++s) {
816
+ for (uint32_t h = 0; h < 1; ++h) {
817
+ for (uint32_t s = 0; s < n_seqs; ++s) {
703
818
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
704
819
 
705
- for (int j = 0; j < n_seq_tokens; ++j) {
706
- const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
820
+ for (uint32_t j = 0; j < n_seq_tokens; ++j) {
821
+ const uint32_t idx = s*n_seq_tokens + j;
822
+
823
+ const llama_pos p1 = ubatch->pos[idx];
707
824
 
708
825
  for (uint32_t i = 0; i < n_kv; ++i) {
709
826
  float f = 0.0f;
@@ -733,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
733
850
  f = -INFINITY;
734
851
  }
735
852
 
736
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
853
+ data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
737
854
  }
738
855
  }
739
856
  }
740
857
 
741
858
  // mask padded tokens
742
859
  if (data) {
743
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
744
- for (uint32_t j = 0; j < n_kv; ++j) {
745
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
860
+ for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
861
+ for (uint32_t i = 0; i < n_kv; ++i) {
862
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
746
863
  }
747
864
  }
748
865
  }
@@ -890,11 +1007,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
890
1007
  const auto & n_embd_head_k = hparams.n_embd_head_k;
891
1008
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
892
1009
 
893
- //GGML_ASSERT(kv_self->size == n_ctx);
894
-
895
1010
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
896
1011
 
897
- inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
1012
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
898
1013
  ggml_set_input(inp->k_shift);
899
1014
 
900
1015
  for (const auto & layer : layers) {
@@ -926,12 +1041,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
926
1041
  }
927
1042
 
928
1043
  llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
929
- const llama_cparams & cparams,
930
- ggml_context * ctx,
931
- ggml_cgraph * gf) const {
1044
+ const llama_cparams & cparams,
1045
+ ggml_context * ctx,
1046
+ ggml_cgraph * gf,
1047
+ const defrag_info & dinfo) const {
932
1048
  auto res = std::make_unique<llm_graph_result>();
933
1049
 
934
- const auto & ids = defrag_info.ids;
1050
+ const auto & ids = dinfo.ids;
935
1051
 
936
1052
  #if 0
937
1053
  // CPU defrag
@@ -1072,7 +1188,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1072
1188
  return res;
1073
1189
  }
1074
1190
 
1075
- bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1191
+ llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
1076
1192
  const uint32_t n_layer = layers.size();
1077
1193
 
1078
1194
  const uint32_t n_kv = cells.used_max_p1();
@@ -1093,14 +1209,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1093
1209
  const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
1094
1210
 
1095
1211
  // determine which KV cells to move where
1096
- //
1097
- // cell i moves to ids[i]
1098
- //
1099
- // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
1100
- //
1101
- auto & ids = defrag_info.ids;
1212
+ defrag_info res;
1213
+ auto & ids = res.ids;
1102
1214
 
1103
- ids.clear();
1104
1215
  ids.resize(n_kv, n_kv);
1105
1216
 
1106
1217
  for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1164,11 +1275,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1164
1275
  // this cell goes to (i0 + nf)
1165
1276
  ids[i1] = i0 + nf;
1166
1277
 
1167
- // move the cell meta data
1168
- cells.mv(i1, i0 + nf);
1169
-
1170
- head = n_used;
1171
-
1172
1278
  if (!cont) {
1173
1279
  n_moves++;
1174
1280
  cont = true;
@@ -1191,14 +1297,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1191
1297
  }
1192
1298
 
1193
1299
  if (n_moves == 0) {
1194
- return false;
1300
+ return {};
1195
1301
  }
1196
1302
 
1197
1303
  LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
1198
1304
 
1199
1305
  LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
1200
1306
 
1201
- return true;
1307
+ return res;
1202
1308
  }
1203
1309
 
1204
1310
  bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@@ -1276,7 +1382,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
1276
1382
 
1277
1383
  if (!res) {
1278
1384
  if (seq_id == -1) {
1279
- clear();
1385
+ clear(true);
1280
1386
  } else {
1281
1387
  seq_rm(seq_id, -1, -1);
1282
1388
  }
@@ -1324,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1324
1430
  for (const auto & layer : layers) {
1325
1431
  const uint32_t il = layer.il;
1326
1432
 
1327
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1433
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1328
1434
 
1329
1435
  // Write key type
1330
1436
  const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1346,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1346
1452
  for (const auto & layer : layers) {
1347
1453
  const uint32_t il = layer.il;
1348
1454
 
1349
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1455
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1350
1456
 
1351
1457
  // Write value type
1352
1458
  const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1370,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1370
1476
  for (const auto & layer : layers) {
1371
1477
  const uint32_t il = layer.il;
1372
1478
 
1373
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1479
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1374
1480
 
1375
1481
  // Write value type
1376
1482
  const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1404,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1404
1510
  seq_rm(dest_seq_id, -1, -1);
1405
1511
 
1406
1512
  llama_sbatch sbatch;
1407
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1513
+ llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1408
1514
 
1409
- batch.n_tokens = cell_count;
1515
+ ubatch.n_tokens = cell_count;
1516
+ ubatch.n_seq_tokens = cell_count;
1517
+ ubatch.n_seqs = 1;
1410
1518
 
1411
1519
  for (uint32_t i = 0; i < cell_count; ++i) {
1412
1520
  llama_pos pos;
@@ -1426,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1426
1534
  io.read_to(&seq_id, sizeof(seq_id));
1427
1535
  }
1428
1536
 
1429
- batch.pos[i] = pos;
1430
- batch.n_seq_id[i] = n_seq_id;
1431
- batch.seq_id[i] = &dest_seq_id;
1537
+ ubatch.pos[i] = pos;
1538
+ ubatch.n_seq_id[i] = n_seq_id;
1539
+ ubatch.seq_id[i] = &dest_seq_id;
1432
1540
  }
1433
1541
 
1434
- const auto head_cur = find_slot(batch);
1542
+ const auto head_cur = find_slot(ubatch);
1435
1543
  if (head_cur < 0) {
1436
1544
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1437
1545
  return false;
1438
1546
  }
1439
1547
 
1440
- apply_ubatch(head_cur, batch);
1548
+ apply_ubatch(head_cur, ubatch);
1441
1549
 
1442
1550
  // keep the head at the old position because we will read the KV data into it in state_read_data()
1443
1551
  head = head_cur;
@@ -1445,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1445
1553
  // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
1446
1554
  // Assume that this is one contiguous block of cells
1447
1555
  GGML_ASSERT(head_cur + cell_count <= cells.size());
1448
- GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
1449
- GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
1556
+ GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
1557
+ GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
1450
1558
  GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
1451
1559
  GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
1452
1560
  } else {
@@ -1457,7 +1565,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1457
1565
  return false;
1458
1566
  }
1459
1567
 
1460
- clear();
1568
+ clear(true);
1461
1569
 
1462
1570
  for (uint32_t i = 0; i < cell_count; ++i) {
1463
1571
  llama_pos pos;
@@ -1513,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1513
1621
  for (const auto & layer : layers) {
1514
1622
  const uint32_t il = layer.il;
1515
1623
 
1516
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1624
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1517
1625
 
1518
1626
  // Read type of key
1519
1627
  int32_t k_type_i_ref;
@@ -1543,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1543
1651
  for (const auto & layer : layers) {
1544
1652
  const uint32_t il = layer.il;
1545
1653
 
1546
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1654
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1547
1655
 
1548
1656
  // Read type of value
1549
1657
  int32_t v_type_i_ref;
@@ -1573,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1573
1681
  for (const auto & layer : layers) {
1574
1682
  const uint32_t il = layer.il;
1575
1683
 
1576
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1684
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1577
1685
 
1578
1686
  // Read type of value
1579
1687
  int32_t v_type_i_ref;
@@ -1621,24 +1729,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1621
1729
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
1622
1730
 
1623
1731
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1624
- llama_memory_status status,
1625
- llama_kv_cache_unified * kv) : status(status), kv(kv) {
1626
- n_kv = kv->get_size();
1627
- head = 0;
1628
- }
1732
+ llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1733
+ n_kv = kv->get_size();
1734
+ head = 0;
1735
+ }
1629
1736
 
1630
1737
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1631
- llama_memory_status status,
1632
- llama_kv_cache_unified * kv,
1633
- llama_sbatch sbatch,
1634
- std::vector<uint32_t> heads,
1635
- std::vector<llama_ubatch> ubatches)
1636
- : status(status),
1637
- kv(kv),
1638
- sbatch(std::move(sbatch)),
1639
- heads(std::move(heads)),
1640
- ubatches(std::move(ubatches)) {
1738
+ llama_kv_cache_unified * kv,
1739
+ llama_context * lctx,
1740
+ bool do_shift,
1741
+ defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
1742
+ if (!do_shift && this->dinfo.empty()) {
1743
+ status = LLAMA_MEMORY_STATUS_NO_UPDATE;
1641
1744
  }
1745
+ }
1746
+
1747
+ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1748
+ llama_kv_cache_unified * kv,
1749
+ llama_sbatch sbatch,
1750
+ llama_kv_cache_unified::ubatch_heads heads,
1751
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1752
+ }
1642
1753
 
1643
1754
  llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
1644
1755
 
@@ -1655,6 +1766,13 @@ bool llama_kv_cache_unified_state::next() {
1655
1766
  bool llama_kv_cache_unified_state::apply() {
1656
1767
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1657
1768
 
1769
+ // no ubatches -> this is a KV cache update
1770
+ if (ubatches.empty()) {
1771
+ kv->update(lctx, do_shift, dinfo);
1772
+
1773
+ return true;
1774
+ }
1775
+
1658
1776
  kv->apply_ubatch(heads[i_next], ubatches[i_next]);
1659
1777
 
1660
1778
  n_kv = kv->get_n_kv();