@fugood/llama.node 0.3.16 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
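The largest single change in the vendored llama.cpp sources is the KV-cache rework in package/src/llama.cpp/src/llama-kv-cache.cpp (+1604 -543, item 170 above); the diff excerpt that follows appears to be taken from that file. One constraint it introduces is that the cache size must be a multiple of a backend-dependent padding (256 cells with flash attention, 32 otherwise), which the new constructor asserts. A minimal sketch of that sizing rule, using illustrative helper names that are not part of the package API:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative only: mirrors llama_kv_cache_unified::get_padding() and the
// GGML_ASSERT(kv_size % padding == 0) added in the diff below.
static uint32_t kv_padding(bool flash_attn) {
    // the FA kernels require padding to avoid extra runtime boundary checks
    return flash_attn ? 256u : 32u;
}

static uint32_t pad_kv_size(uint32_t requested, bool flash_attn) {
    const uint32_t pad = kv_padding(flash_attn);
    return ((requested + pad - 1) / pad) * pad; // round up, e.g. 4000 -> 4096 with flash attention
}

int main() {
    const uint32_t kv_size = pad_kv_size(4000, /*flash_attn=*/true);
    assert(kv_size == 4096 && kv_size % kv_padding(true) == 0);
    return 0;
}
```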
@@ -4,35 +4,41 @@
  #include "llama-batch.h"
  #include "llama-cparams.h"
  #include "llama-model.h"
+ #include "llama-context.h"

  #include <algorithm>
  #include <cassert>
+ #include <cmath>
  #include <limits>
  #include <map>
  #include <stdexcept>

- static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
+ //
+ // llama_kv_cache_unified
+ //

- llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
+ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
+ // the FA kernels require padding to avoid extra runtime boundary checks
+ return cparams.flash_attn ? 256u : 32u;
  }

- bool llama_kv_cache_unified::init(
+ llama_kv_cache_unified::llama_kv_cache_unified(
  const llama_model & model,
- const llama_cparams & cparams,
  ggml_type type_k,
  ggml_type type_v,
+ bool v_trans,
+ bool offload,
  uint32_t kv_size,
- bool offload) {
+ uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
  const int32_t n_layer = hparams.n_layer;

  has_shift = false;
+ can_shift = true;

- recurrent = llama_model_is_recurrent(&model);
- v_trans = !recurrent && !cparams.flash_attn;
- can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+ LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
+ __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);

- LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
- __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
+ GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");

  head = 0;
  size = kv_size;
@@ -78,23 +84,20 @@ bool llama_kv_cache_unified::init(

  const char * dev_name = "CPU";

- ggml_backend_buffer_type_t buft;
+ ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
  if (offload) {
  auto * dev = model.dev_layer(i);
  buft = ggml_backend_dev_buffer_type(dev);

  dev_name = ggml_backend_dev_name(dev);
- } else {
- buft = ggml_backend_cpu_buffer_type();
  }

- LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
- i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+ LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name);

  ggml_context * ctx = ctx_for_buft(buft);
  if (!ctx) {
- LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
- return false;
+ throw std::runtime_error("failed to create ggml context for kv cache");
  }

  ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
@@ -112,55 +115,28 @@ bool llama_kv_cache_unified::init(

  ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
  if (!buf) {
- LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
- return false;
+ throw std::runtime_error("failed to allocate buffer for kv cache");
  }
  ggml_backend_buffer_clear(buf, 0);
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
  bufs.emplace_back(buf);
  }

- return true;
- }
-
- int32_t llama_kv_cache_unified::get_n_tokens() const {
- int32_t result = 0;
-
- for (uint32_t i = 0; i < size; i++) {
- result += cells[i].seq_id.size();
- }
-
- return result;
- }
-
- uint32_t llama_kv_cache_unified::get_used_cells() const {
- return used;
- }
-
- size_t llama_kv_cache_unified::total_size() const {
- size_t size = 0;
- for (const auto & buf : bufs) {
- size += ggml_backend_buffer_get_size(buf.get());
- }
-
- return size;
- }
+ {
+ const size_t memory_size_k = size_k_bytes();
+ const size_t memory_size_v = size_v_bytes();

- llama_pos llama_kv_cache_unified::pos_max() const {
- llama_pos pos_max = -1;
- for (const auto & cell : cells) {
- pos_max = std::max(pos_max, cell.pos);
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  }
-
- return pos_max;
  }

  void llama_kv_cache_unified::clear() {
  for (int32_t i = 0; i < (int32_t) size; ++i) {
  cells[i].pos = -1;
  cells[i].seq_id.clear();
- cells[i].src = -1;
- cells[i].tail = -1;
  }
  head = 0;
  used = 0;
@@ -181,33 +157,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
  p1 = std::numeric_limits<llama_pos>::max();
  }

- // models like Mamba or RWKV can't have a state partially erased
- if (recurrent) {
- if (seq_id >= (int64_t) size) {
- // could be fatal
- return false;
- }
- if (0 <= seq_id) {
- int32_t & tail_id = cells[seq_id].tail;
- if (tail_id >= 0) {
- const llama_kv_cell & cell = cells[tail_id];
- // partial intersection is invalid
- if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
- return false;
- }
- // invalidate tails which will be cleared
- if (p0 <= cell.pos && cell.pos < p1) {
- tail_id = -1;
- }
- }
- } else {
- // seq_id is negative, then the range should include everything or nothing
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
- return false;
- }
- }
- }
-
  for (uint32_t i = 0; i < size; ++i) {
  if (cells[i].pos >= p0 && cells[i].pos < p1) {
  if (seq_id < 0) {
@@ -224,7 +173,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
  }

  cells[i].pos = -1;
- cells[i].src = -1;

  if (new_head == size) {
  new_head = i;
@@ -254,34 +202,6 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
  p1 = std::numeric_limits<llama_pos>::max();
  }

- if (recurrent) {
- if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
- llama_kv_cell & tail_src = cells[seq_id_src];
- llama_kv_cell & tail_dst = cells[seq_id_dst];
- if (tail_dst.tail >= 0) {
- // clear destination seq_id if it wasn't empty
- llama_kv_cell & cell_dst = cells[tail_dst.tail];
-
- cell_dst.seq_id.erase(seq_id_dst);
- tail_dst.tail = -1;
- if (cell_dst.seq_id.empty()) {
- cell_dst.pos = -1;
- cell_dst.delta = -1;
- cell_dst.src = -1;
- used -= 1;
- }
- }
- if (tail_src.tail >= 0) {
- llama_kv_cell & cell_src = cells[tail_src.tail];
-
- cell_src.seq_id.insert(seq_id_dst);
- tail_dst.tail = tail_src.tail;
- }
- }
-
- return;
- }
-
  // otherwise, this is the KV of a Transformer-like model
  head = 0;

@@ -296,17 +216,12 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
  uint32_t new_head = size;

  for (uint32_t i = 0; i < size; ++i) {
- if (recurrent && (llama_seq_id) i != seq_id) {
- cells[i].tail = -1;
- }
-
  if (!cells[i].has_seq_id(seq_id)) {
  if (cells[i].pos >= 0) {
  used--;
  }

  cells[i].pos = -1;
- cells[i].src = -1;
  cells[i].seq_id.clear();

  if (new_head == size){
@@ -344,20 +259,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
  return;
  }

- if (recurrent) {
- // for Mamba-like or RWKV models, only the pos needs to be shifted
- if (0 <= seq_id && seq_id < (int64_t) size) {
- const int32_t tail_id = cells[seq_id].tail;
- if (tail_id >= 0) {
- llama_kv_cell & cell = cells[tail_id];
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
- cell.pos += delta;
- }
- }
- }
- return;
- }
-
  for (uint32_t i = 0; i < size; ++i) {
  if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
  has_shift = true;
@@ -400,21 +301,6 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
  return;
  }

- if (recurrent) {
- // for Mamba-like or RWKV models, only the pos needs to be changed
- if (0 <= seq_id && seq_id < (int64_t) size) {
- const int32_t tail_id = cells[seq_id].tail;
- if (tail_id >= 0) {
- llama_kv_cell & cell = cells[tail_id];
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
- cell.pos /= d;
- }
- }
- }
-
- return;
- }
-
  for (uint32_t i = 0; i < size; ++i) {
  if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
  has_shift = true;
@@ -428,7 +314,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
  }
  }

- llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
+ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
  llama_pos result = 0;

  for (uint32_t i = 0; i < size; ++i) {
@@ -440,190 +326,161 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
  return result;
  }

- void llama_kv_cache_unified::defrag() {
- if (!recurrent) {
- do_defrag = true;
+ void llama_kv_cache_unified::restore() {
+ if (pending.ranges.empty()) {
+ return;
+ }
+
+ uint32_t new_head = size;
+
+ for (auto & range : pending.ranges) {
+ for (uint32_t i = range.c0; i < range.c1; ++i) {
+ cells[i].seq_id.clear();
+
+ // keep count of the number of used cells
+ if (cells[i].pos >= 0) {
+ used--;
+ }
+
+ cells[i].pos = -1;
+ }
+
+ new_head = std::min(new_head, range.c0);
+ }
+
+ if (new_head != size && new_head < head) {
+ head = new_head;
  }
  }

- bool llama_kv_cache_unified::get_can_shift() const {
- return can_shift;
+ void llama_kv_cache_unified::commit() {
+ if (pending.ranges.empty()) {
+ LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+ return;
+ }
+
+ pending.ranges.clear();
  }

- llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
- const llama_ubatch & ubatch) {
- const uint32_t n_tokens = ubatch.n_tokens;
- const uint32_t n_seqs = ubatch.n_seqs;
- const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
+ bool llama_kv_cache_unified::update(llama_context & lctx) {
+ bool need_reserve = false;

- if (recurrent) {
- // For recurrent state architectures (like Mamba or RWKV),
- // each cache cell can store the state for a whole sequence.
- // A slot should be always be contiguous.
+ auto * sched = lctx.get_sched();

- // can only process batches with an equal number of new tokens in each sequence
- GGML_ASSERT(ubatch.equal_seqs);
+ if (has_shift) {
+ if (!get_can_shift()) {
+ GGML_ABORT("The current KV cache / model configuration does not support K-shift");
+ }

- int32_t min = size - 1;
- int32_t max = 0;
+ LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);

- // everything should fit if all seq_ids are smaller than the max
- for (uint32_t s = 0; s < n_seqs; ++s) {
- const uint32_t n_seq_id = ubatch.n_seq_id[s];
- for (uint32_t j = 0; j < n_seq_id; ++j) {
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
+ // apply K-shift if needed
+ if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+ ggml_backend_sched_reset(sched);

- if (seq_id < 0 || (uint32_t) seq_id >= size) {
- // too big seq_id
- // TODO: would it be possible to resize the cache instead?
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
- return llama_kv_cache_slot_info_failed;
- }
- if (j > 0) {
- llama_kv_cell & seq = cells[seq_id];
- if (seq.tail >= 0) {
- llama_kv_cell & cell = cells[seq.tail];
- // clear cells from seq_ids that become shared
- // (should not normally happen, but let's handle it anyway)
- cell.seq_id.erase(seq_id);
- seq.tail = -1;
- if (cell.seq_id.empty()) {
- cell.pos = -1;
- cell.src = -1;
- used -= 1;
- }
- }
- }
- }
+ auto * gf = lctx.graph_init();
+
+ auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+
+ ggml_backend_sched_alloc_graph(sched, gf);
+
+ res->set_inputs(nullptr);
+
+ lctx.graph_compute(gf, false);
+
+ need_reserve = true;
  }

- #ifndef NDEBUG
  {
- std::vector<int32_t> tails_verif;
- tails_verif.assign(size, -1);
- for (uint32_t i = 0; i < size; ++i) {
- llama_kv_cell & cell = cells[i];
- for (llama_seq_id seq_id : cell.seq_id) {
- if (tails_verif[seq_id] != -1) {
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
- }
- tails_verif[seq_id] = i;
- }
- }
+ has_shift = false;
+
  for (uint32_t i = 0; i < size; ++i) {
- if (tails_verif[i] != cells[i].tail) {
- LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
- }
+ cells[i].delta = 0;
  }
  }
- #endif
+ }

- // find next empty cell
- uint32_t next_empty_cell = head;
+ if (do_defrag) {
+ LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

- for (uint32_t i = 0; i < size; ++i) {
- if (next_empty_cell >= size) { next_empty_cell -= size; }
- llama_kv_cell & cell = cells[next_empty_cell];
- if (cell.is_empty()) { break; }
- next_empty_cell += 1;
- }
+ if (defrag_prepare(lctx.graph_max_nodes())) {
+ ggml_backend_sched_reset(sched);

- // find usable cell range
- for (uint32_t s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
- llama_kv_cell & seq_meta = cells[seq_id];
- bool has_cell = false;
- if (seq_meta.tail >= 0) {
- llama_kv_cell & cell = cells[seq_meta.tail];
- GGML_ASSERT(cell.has_seq_id(seq_id));
- // does this seq_id "own" the cell?
- if (cell.seq_id.size() == 1) { has_cell = true; }
- }
- if (!has_cell) {
- llama_kv_cell & empty_cell = cells[next_empty_cell];
- GGML_ASSERT(empty_cell.is_empty());
- // copy old tail into the empty cell
- if (seq_meta.tail >= 0) {
- llama_kv_cell & orig_cell = cells[seq_meta.tail];
- empty_cell.pos = orig_cell.pos;
- empty_cell.src = orig_cell.src;
- orig_cell.seq_id.erase(seq_id);
- empty_cell.seq_id.insert(seq_id); // will be overwritten
- }
- seq_meta.tail = next_empty_cell;
- // find next empty cell
- if (s + 1 < n_seqs) {
- next_empty_cell += 1;
- for (uint32_t i = 0; i < size; ++i) {
- if (next_empty_cell >= size) { next_empty_cell -= size; }
- llama_kv_cell & cell = cells[next_empty_cell];
- if (cell.is_empty()) { break; }
- next_empty_cell += 1;
- }
- }
- }
- if (min > seq_meta.tail) { min = seq_meta.tail; }
- if (max < seq_meta.tail) { max = seq_meta.tail; }
- }
+ auto * gf = lctx.graph_init();

- // gather and re-order
- for (uint32_t s = 0; s < n_seqs; ++s) {
- int32_t dst_id = s + min;
- int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
- if (dst_id != src_id) {
- llama_kv_cell & dst_cell = cells[dst_id];
- llama_kv_cell & src_cell = cells[src_id];
+ auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);

- std::swap(dst_cell.pos, src_cell.pos);
- std::swap(dst_cell.src, src_cell.src);
- std::swap(dst_cell.seq_id, src_cell.seq_id);
+ ggml_backend_sched_alloc_graph(sched, gf);

- // swap tails (assuming they NEVER overlap)
- for (const llama_seq_id seq_id : src_cell.seq_id) {
- cells[seq_id].tail = src_id;
- }
- for (const llama_seq_id seq_id : dst_cell.seq_id) {
- cells[seq_id].tail = dst_id;
- }
- }
- }
+ res->set_inputs(nullptr);

- // update the pos of the used seqs
- for (uint32_t s = 0; s < n_seqs; ++s) {
- const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
- int32_t cell_id = s + min;
- llama_kv_cell & cell = cells[cell_id];
+ lctx.graph_compute(gf, false);

- if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
- // What should happen when the pos backtracks or skips a value?
- // Clearing the state mid-batch would require special-casing which isn't done.
- LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
- __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
- }
- cell.pos = last_pos;
- cell.seq_id.clear();
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
- cell.seq_id.insert(seq_id);
- cells[seq_id].tail = cell_id;
- }
+ need_reserve = true;
  }

- // allow getting the range of used cells, from head to head + n
- head = min;
- n = max - min + 1;
- used = std::count_if(cells.begin(), cells.end(),
- [](const llama_kv_cell& cell){ return !cell.is_empty(); });
+ do_defrag = false;
+ }
+
+ return need_reserve;
+ }
+
+ void llama_kv_cache_unified::defrag_sched(float thold) {
+ // - do not defrag small contexts (i.e. < 2048 tokens)
+ // - count the padding towards the number of used tokens
+ const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+
+ // queue defragmentation for next llama_kv_cache_update
+ if (fragmentation > thold) {
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+ do_defrag = true;
+ }
+ }
+
+ void llama_kv_cache_unified::set_full() {
+ n = size;
+
+ // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
+ // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
+ // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
+ // setting it to 0 is the simplest way to achieve that
+ // ref: https://github.com/ggml-org/llama.cpp/issues/13359
+ head = 0;
+ }
+
+ llama_sbatch llama_kv_cache_unified::sbatch_init(
+ const llama_batch & batch,
+ bool logits_all) {
+ return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+ }
+
+ llama_ubatch llama_kv_cache_unified::ubatch_next(
+ llama_sbatch & sbatch,
+ uint32_t n_ubatch,
+ bool embd_pooled) const {
+ GGML_UNUSED(embd_pooled);
+ return sbatch.split_simple(n_ubatch);
+ }
+
+ bool llama_kv_cache_unified::find_slot(
+ const llama_ubatch & ubatch) {
+ const uint32_t n_tokens = ubatch.n_tokens;
+ const uint32_t n_seqs = ubatch.n_seqs;
+ const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

- // sanity check
- return llama_kv_cache_slot_info(n >= n_seqs);
+ // if we have enough unused cells before the current head ->
+ // better to start searching from the beginning of the cache, hoping to fill it
+ if (head > used + 2*ubatch.n_tokens) {
+ head = 0;
  }

  // otherwise, one cell per token.

  if (n_tokens > size) {
  LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
- return llama_kv_cache_slot_info_failed;
+ return false;
  }

  uint32_t n_tested = 0;
@@ -651,7 +508,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(

  if (n_tested >= size) {
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
- return llama_kv_cache_slot_info_failed;
+ return false;
  }
  }

@@ -668,184 +525,1502 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
668
525
 
669
526
  used += n_tokens;
670
527
 
671
- return llama_kv_cache_slot_info(head, head + n_tokens);
528
+ pending.ranges.push_back({head, head + n_tokens});
529
+
530
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
531
+ // after enough generations, the benefit from this heuristic disappears
532
+ // if we start defragmenting the cache, the benefit from this will be more important
533
+ n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
534
+
535
+ //printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
536
+
537
+ return true;
672
538
  }
673
539
 
674
- uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
675
- // the FA kernels require padding to avoid extra runtime boundary checks
676
- return cparams.flash_attn ? 256u : 32u;
540
+ int32_t llama_kv_cache_unified::get_n_tokens() const {
541
+ int32_t result = 0;
542
+
543
+ for (uint32_t i = 0; i < size; i++) {
544
+ result += cells[i].seq_id.size();
545
+ }
546
+
547
+ return result;
677
548
  }
678
549
 
679
- uint32_t llama_kv_cache_unified::cell_max() const {
680
- for (uint32_t i = size; i > 0; --i) {
681
- const llama_kv_cell & cell = cells[i - 1];
550
+ int32_t llama_kv_cache_unified::get_used_cells() const {
551
+ return used;
552
+ }
682
553
 
683
- if (cell.pos >= 0 && !cell.is_empty()) {
684
- return i;
554
+ bool llama_kv_cache_unified::get_can_shift() const {
555
+ return can_shift;
556
+ }
557
+
558
+ llama_pos llama_kv_cache_unified::get_pos_max() const {
559
+ llama_pos pos_max = -1;
560
+ for (const auto & cell : cells) {
561
+ pos_max = std::max(pos_max, cell.pos);
562
+ }
563
+
564
+ return pos_max;
565
+ }
566
+
567
+ size_t llama_kv_cache_unified::total_size() const {
568
+ size_t size = 0;
569
+ for (const auto & buf : bufs) {
570
+ size += ggml_backend_buffer_get_size(buf.get());
571
+ }
572
+
573
+ return size;
574
+ }
575
+
576
+ size_t llama_kv_cache_unified::size_k_bytes() const {
577
+ size_t size_k_bytes = 0;
578
+
579
+ for (const auto & k : k_l) {
580
+ size_k_bytes += ggml_nbytes(k);
581
+ }
582
+
583
+ return size_k_bytes;
584
+ }
585
+
586
+ size_t llama_kv_cache_unified::size_v_bytes() const {
587
+ size_t size_v_bytes = 0;
588
+
589
+ for (const auto & v : v_l) {
590
+ size_v_bytes += ggml_nbytes(v);
591
+ }
592
+
593
+ return size_v_bytes;
594
+ }
595
+
596
+ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
597
+ const llama_cparams & cparams,
598
+ ggml_context * ctx,
599
+ ggml_tensor * cur,
600
+ ggml_tensor * shift,
601
+ ggml_tensor * factors,
602
+ float freq_base,
603
+ float freq_scale) const {
604
+ const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
605
+
606
+ const auto & yarn_ext_factor = cparams.yarn_ext_factor;
607
+ const auto & yarn_beta_fast = cparams.yarn_beta_fast;
608
+ const auto & yarn_beta_slow = cparams.yarn_beta_slow;
609
+
610
+ const auto & n_rot = hparams.n_rot;
611
+ const auto & rope_type = hparams.rope_type;
612
+
613
+ // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
614
+ // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
615
+ const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
616
+
617
+ ggml_tensor * tmp;
618
+
619
+ if (ggml_is_quantized(cur->type)) {
620
+ // dequantize to f32 -> RoPE -> quantize back
621
+ tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
622
+
623
+ tmp = ggml_rope_ext(ctx, tmp,
624
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
625
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
626
+
627
+ tmp = ggml_cpy(ctx, tmp, cur);
628
+ } else {
629
+ // we rotate only the first n_rot dimensions
630
+ tmp = ggml_rope_ext_inplace(ctx, cur,
631
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
632
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
633
+ }
634
+
635
+ return tmp;
636
+ }
637
+
638
+ class llm_graph_input_k_shift : public llm_graph_input_i {
639
+ public:
640
+ llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
641
+ virtual ~llm_graph_input_k_shift() = default;
642
+
643
+ void set_input(const llama_ubatch * ubatch) override;
644
+
645
+ ggml_tensor * k_shift; // I32 [kv_size]
646
+
647
+ const llama_kv_cache_unified * kv_self;
648
+ };
649
+
650
+ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
651
+ GGML_UNUSED(ubatch);
652
+
653
+ if (k_shift) {
654
+ assert(ggml_backend_buffer_is_host(k_shift->buffer));
655
+
656
+ int32_t * data = (int32_t *) k_shift->data;
657
+
658
+ for (uint32_t i = 0; i < kv_self->size; ++i) {
659
+ data[i] = kv_self->cells[i].delta;
685
660
  }
686
661
  }
662
+ }
663
+
664
+ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
665
+ const llama_cparams & cparams,
666
+ ggml_context * ctx,
667
+ ggml_cgraph * gf) const {
668
+ auto res = std::make_unique<llm_graph_result>();
669
+
670
+ const auto & n_layer = hparams.n_layer;
671
+
672
+ const auto & n_embd_head_k = hparams.n_embd_head_k;
673
+ //const auto & n_embd_head_v = hparams.n_embd_head_v;
674
+
675
+ const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
676
+
677
+ //GGML_ASSERT(kv_self->size == n_ctx);
678
+
679
+ auto inp = std::make_unique<llm_graph_input_k_shift>(this);
680
+
681
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
682
+ ggml_set_input(inp->k_shift);
683
+
684
+ for (uint32_t il = 0; il < n_layer; ++il) {
685
+ const int64_t n_head_kv = hparams.n_head_kv(il);
686
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
687
+
688
+ const bool is_swa = hparams.is_swa(il);
689
+
690
+ // note: the swa rope params could become part of the cparams in the future
691
+ // if we decide to make them configurable, like the non-sliding ones
692
+ const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
693
+ const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
694
+
695
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
696
+
697
+ ggml_tensor * k =
698
+ ggml_view_3d(ctx, k_l[il],
699
+ n_embd_head_k, n_head_kv, size,
700
+ ggml_row_size(k_l[il]->type, n_embd_head_k),
701
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa),
702
+ 0);
703
+
704
+ ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
705
+
706
+ ggml_build_forward_expand(gf, cur);
707
+ }
708
+
709
+ res->add_input(std::move(inp));
710
+
711
+ return res;
712
+ }
713
+
714
+ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
715
+ const llama_cparams & cparams,
716
+ ggml_context * ctx,
717
+ ggml_cgraph * gf) const {
718
+ auto res = std::make_unique<llm_graph_result>();
719
+
720
+ const auto & ids = defrag_info.ids;
721
+
722
+ #if 0
723
+ // CPU defrag
724
+ //
725
+ // TODO: optimizations are possible:
726
+ // - multiple threads
727
+ // - avoid copying to the host memory when already there
728
+ //
729
+ // likely not worth the effort, as we have ggml_graph based defrag
730
+ //
731
+
732
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
733
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
734
+
735
+ const uint32_t kv_size = size;
736
+
737
+ std::vector<uint8_t> buf_k;
738
+ std::vector<uint8_t> buf_v;
739
+
740
+ for (uint32_t il = 0; il < n_layer; ++il) {
741
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
742
+ const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
743
+
744
+ const size_t v_size_el = ggml_type_size(v_l[il]->type);
745
+ const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
746
+
747
+ buf_k.resize(k_size);
748
+ buf_v.resize(v_size);
749
+
750
+ ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
751
+ ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
752
+
753
+ // batch move [i, i+nm) to [id, id+nm)
754
+ // note: cells can move only to a lower index
755
+ for (uint32_t i = 0; i < n_kv; ++i) {
756
+ const uint32_t id = ids[i];
757
+
758
+ if (i == id || id == n_kv) {
759
+ continue;
760
+ }
761
+
762
+ uint32_t nm = 1;
763
+
764
+ while (i + nm < n_kv && ids[i + nm] == id + nm) {
765
+ nm++;
766
+ }
767
+
768
+ // move keys
769
+ {
770
+ const int64_t os = i*k_size_row;
771
+ const int64_t od = id*k_size_row;
772
+
773
+ memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
774
+ }
775
+
776
+ // move values (note: they are transposed)
777
+ {
778
+ const int64_t os = i;
779
+ const int64_t od = id;
780
+
781
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
782
+ memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
783
+ }
784
+ }
785
+
786
+ i += nm - 1;
787
+ }
788
+
789
+ ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
790
+ ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
791
+ }
792
+ #else
793
+ for (uint32_t i = 0; i < ids.size(); ++i) {
794
+ const uint32_t id = ids[i];
795
+
796
+ if (i == id || id == ids.size()) {
797
+ continue;
798
+ }
799
+
800
+ uint32_t nm = 1;
801
+
802
+ while (i + nm < ids.size() && ids[i + nm] == id + nm) {
803
+ nm++;
804
+ }
805
+
806
+ for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
807
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
808
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
809
+
810
+ ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
811
+ n_embd_k_gqa, nm,
812
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa),
813
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
814
+
815
+ ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
816
+ n_embd_k_gqa, nm,
817
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa),
818
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
819
+
820
+ ggml_tensor * view_v_src;
821
+ ggml_tensor * view_v_dst;
822
+
823
+ if (cparams.flash_attn) {
824
+ // NOTE: the V cache is not transposed when using flash attention
825
+ view_v_src = ggml_view_2d(ctx, v_l[il],
826
+ n_embd_v_gqa, nm,
827
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa),
828
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
829
+
830
+ view_v_dst = ggml_view_2d(ctx, v_l[il],
831
+ n_embd_v_gqa, nm,
832
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa),
833
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
834
+ } else {
835
+ view_v_src = ggml_view_2d(ctx, v_l[il],
836
+ nm, n_embd_v_gqa,
837
+ ggml_row_size(v_l[il]->type, size),
838
+ ggml_row_size(v_l[il]->type, i));
839
+
840
+ view_v_dst = ggml_view_2d(ctx, v_l[il],
841
+ nm, n_embd_v_gqa,
842
+ ggml_row_size(v_l[il]->type, size),
843
+ ggml_row_size(v_l[il]->type, id));
844
+ }
845
+
846
+ ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
847
+ ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
848
+ }
849
+
850
+ i += nm - 1;
851
+ }
852
+
853
+ //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
854
+ #endif
855
+
856
+ return res;
857
+ }
858
+
859
+ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
860
+ const uint32_t n_layer = hparams.n_layer;
861
+
862
+ const uint32_t n_kv = cell_max();
863
+ const uint32_t n_used = used;
864
+
865
+ assert(n_used <= n_kv);
866
+
867
+ //const int64_t t_start = ggml_time_us();
868
+
869
+ // number of cells moved
870
+ uint32_t n_moves = 0;
871
+
872
+ // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
873
+ // - source view, destination view, copy operation
874
+ // - x2 for keys and values
875
+ //const uint32_t max_moves = max_nodes()/(6*n_layer);
876
+ // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
877
+ const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
878
+
879
+ // determine which KV cells to move where
880
+ //
881
+ // cell i moves to ids[i]
882
+ //
883
+ // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
884
+ //
885
+ auto & ids = defrag_info.ids;
886
+
887
+ ids.clear();
888
+ ids.resize(n_kv, n_kv);
889
+
890
+ for (uint32_t i0 = 0; i0 < n_used; ++i0) {
891
+ const auto & cell0 = cells[i0];
892
+
893
+ if (!cell0.is_empty()) {
894
+ ids[i0] = i0;
895
+
896
+ continue;
897
+ }
898
+
899
+ // found a hole - fill it with data from the end of the cache
900
+
901
+ uint32_t nh = 1;
902
+
903
+ // determine the size of the hole
904
+ while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
905
+ nh++;
906
+ }
907
+
908
+ uint32_t nf = 0;
909
+ uint32_t is = n_kv - 1;
910
+
911
+ // starting from the end, find nh non-empty cells
912
+ for (; is > i0; --is) {
913
+ const auto & cell1 = cells[is];
914
+
915
+ if (cell1.is_empty() || ids[is] != n_kv) {
916
+ continue;
917
+ }
918
+
919
+ // non-empty cell which is not yet moved
920
+ nf++;
921
+
922
+ if (nf == nh) {
923
+ break;
924
+ }
925
+ }
926
+
927
+ // this can only happen if `n_used` is not accurate, which would be a bug
928
+ GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
929
+
930
+ nf = 0;
931
+
932
+ uint32_t i1 = is;
933
+
934
+ // are we moving a continuous block of memory?
935
+ bool cont = false;
936
+
937
+ // should we stop searching for the next move?
938
+ bool stop = false;
939
+
940
+ // go back and move the nf cells to the hole
941
+ for (; i1 < n_kv; ++i1) {
942
+ auto & cell1 = cells[i1];
943
+
944
+ if (cell1.is_empty() || ids[i1] != n_kv) {
945
+ if (n_moves == max_moves) {
946
+ stop = true;
947
+ break;
948
+ }
949
+
950
+ cont = false;
951
+ continue;
952
+ }
953
+
954
+ // this cell goes to (i0 + nf)
955
+ ids[i1] = i0 + nf;
956
+
957
+ // move the cell meta data
958
+ cells[i0 + nf] = cell1;
959
+
960
+ // clear the old cell and move the head there
961
+ cell1 = kv_cell();
962
+ head = n_used;
963
+
964
+ if (!cont) {
965
+ n_moves++;
966
+ cont = true;
967
+ }
968
+
969
+ nf++;
970
+
971
+ if (nf == nh) {
972
+ break;
973
+ }
974
+ }
975
+
976
+ if (stop || n_moves == max_moves) {
977
+ break;
978
+ }
979
+
980
+ //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
981
+
982
+ i0 += nh - 1;
983
+ }
984
+
985
+ if (n_moves == 0) {
986
+ return false;
987
+ }
988
+
989
+ LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
990
+
991
+ LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
992
+
993
+ return true;
994
+ }
995
+
996
+ uint32_t llama_kv_cache_unified::cell_max() const {
997
+ for (uint32_t i = size; i > 0; --i) {
998
+ const kv_cell & cell = cells[i - 1];
999
+
1000
+ if (cell.pos >= 0 && !cell.is_empty()) {
1001
+ return i;
1002
+ }
1003
+ }
1004
+
1005
+ return 0;
1006
+ }
1007
+
1008
+ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1009
+ std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1010
+ uint32_t cell_count = 0;
1011
+
1012
+ // Count the number of cells with the specified seq_id
1013
+ // Find all the ranges of cells with this seq id (or all, when -1)
1014
+ uint32_t cell_range_begin = size;
1015
+ for (uint32_t i = 0; i < size; ++i) {
1016
+ const auto & cell = cells[i];
1017
+ if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
1018
+ ++cell_count;
1019
+ if (cell_range_begin == size) {
1020
+ cell_range_begin = i;
1021
+ }
1022
+ } else {
1023
+ if (cell_range_begin != size) {
1024
+ cell_ranges.emplace_back(cell_range_begin, i);
1025
+ cell_range_begin = size;
1026
+ }
1027
+ }
1028
+ }
1029
+ if (cell_range_begin != size) {
1030
+ cell_ranges.emplace_back(cell_range_begin, size);
1031
+ }
1032
+
1033
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1034
+ uint32_t cell_count_check = 0;
1035
+ for (const auto & range : cell_ranges) {
1036
+ cell_count_check += range.second - range.first;
1037
+ }
1038
+ GGML_ASSERT(cell_count == cell_count_check);
1039
+
1040
+ io.write(&cell_count, sizeof(cell_count));
1041
+
1042
+ state_write_meta(io, cell_ranges, seq_id);
1043
+ state_write_data(io, cell_ranges);
1044
+ }
1045
+
1046
+ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1047
+ uint32_t cell_count;
1048
+ io.read_to(&cell_count, sizeof(cell_count));
1049
+
1050
+ bool res = true;
1051
+ res = res && state_read_meta(io, cell_count, seq_id);
1052
+ res = res && state_read_data(io, cell_count);
1053
+
1054
+ if (!res) {
1055
+ if (seq_id == -1) {
1056
+ clear();
1057
+ } else {
1058
+ seq_rm(seq_id, -1, -1);
1059
+ }
1060
+ throw std::runtime_error("failed to restore kv cache");
1061
+ }
1062
+ }
1063
+
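For illustration only (not part of the diff): state_read() above first rolls back the affected sequence (or clears the whole cache) and then throws, so a hypothetical caller restoring a saved session would wrap it like this, assuming `cache` is a llama_kv_cache_unified and `io` a llama_io_read_i:

// Caller-side sketch, not actual llama.cpp usage code.
try {
    cache.state_read(io, seq_id);   // seq_id == -1 restores the whole cache
} catch (const std::exception & err) {
    // at this point the cache has already been cleared (seq_id == -1)
    // or the sequence removed (seq_id != -1)
    LLAMA_LOG_ERROR("session restore failed: %s\n", err.what());
}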
1064
+ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
1065
+ for (const auto & range : cell_ranges) {
1066
+ for (uint32_t i = range.first; i < range.second; ++i) {
1067
+ const auto & cell = cells[i];
1068
+ const llama_pos pos = cell.pos;
1069
+ const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
1070
+
1071
+ io.write(&pos, sizeof(pos));
1072
+ io.write(&n_seq_id, sizeof(n_seq_id));
1073
+
1074
+ if (n_seq_id) {
1075
+ for (auto seq_id : cell.seq_id) {
1076
+ io.write(&seq_id, sizeof(seq_id));
1077
+ }
1078
+ }
1079
+ }
1080
+ }
1081
+ }
1082
+
1083
+ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
1084
+ const uint32_t v_trans = this->v_trans ? 1 : 0;
1085
+ const uint32_t n_layer = hparams.n_layer;
1086
+
1087
+ io.write(&v_trans, sizeof(v_trans));
1088
+ io.write(&n_layer, sizeof(n_layer));
1089
+
1090
+ std::vector<uint8_t> tmp_buf;
1091
+
1092
+ // Iterate and write all the keys first, each row is a cell
1093
+ // Get whole range at a time
1094
+ for (uint32_t il = 0; il < n_layer; ++il) {
1095
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1096
+
1097
+ // Write key type
1098
+ const int32_t k_type_i = (int32_t)k_l[il]->type;
1099
+ io.write(&k_type_i, sizeof(k_type_i));
1100
+
1101
+ // Write row size of key
1102
+ const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
1103
+ io.write(&k_size_row, sizeof(k_size_row));
1104
+
1105
+ // Read each range of cells of k_size length each into tmp_buf and write out
1106
+ for (const auto & range : cell_ranges) {
1107
+ const size_t range_size = range.second - range.first;
1108
+ const size_t buf_size = range_size * k_size_row;
1109
+ io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
1110
+ }
1111
+ }
1112
+
1113
+ if (!v_trans) {
1114
+ for (uint32_t il = 0; il < n_layer; ++il) {
1115
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1116
+
1117
+ // Write value type
1118
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1119
+ io.write(&v_type_i, sizeof(v_type_i));
1120
+
1121
+ // Write row size of value
1122
+ const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
1123
+ io.write(&v_size_row, sizeof(v_size_row));
1124
+
1125
+ // Read each range of cells of v_size length each into tmp_buf and write out
1126
+ for (const auto & range : cell_ranges) {
1127
+ const size_t range_size = range.second - range.first;
1128
+ const size_t buf_size = range_size * v_size_row;
1129
+ io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
1130
+ }
1131
+ }
1132
+ } else {
1133
+ // When v is transposed, we also need the element size and get the element ranges from each row
1134
+ const uint32_t kv_size = size;
1135
+ for (uint32_t il = 0; il < n_layer; ++il) {
1136
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1137
+
1138
+ // Write value type
1139
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1140
+ io.write(&v_type_i, sizeof(v_type_i));
1141
+
1142
+ // Write element size
1143
+ const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
1144
+ io.write(&v_size_el, sizeof(v_size_el));
1145
+
1146
+ // Write GQA embedding size
1147
+ io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
1148
+
1149
+ // For each row, we get the element values of each cell
1150
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1151
+ // Read each range of cells of v_size_el length each into tmp_buf and write out
1152
+ for (const auto & range : cell_ranges) {
1153
+ const size_t range_size = range.second - range.first;
1154
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1155
+ const size_t buf_size = range_size * v_size_el;
1156
+ io.write_tensor(v_l[il], src_offset, buf_size);
1157
+ }
1158
+ }
1159
+ }
1160
+ }
1161
+ }
1162
+
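For orientation (not part of the diff): the data section written by state_write_data() above uses one of two value layouts, and the transposed case addresses elements with (cell + channel * kv_size) * v_size_el. A layout summary and a small illustrative helper with a worked example:

// Layout sketch of the data section, as written above:
//   uint32_t v_trans, uint32_t n_layer
//   per layer: int32_t k_type, uint64_t k_size_row, then one key row per cell in each range
//   per layer (v_trans == 0): int32_t v_type, uint64_t v_size_row, then one value row per cell
//   per layer (v_trans == 1): int32_t v_type, uint32_t v_size_el, uint32_t n_embd_v_gqa,
//                             then, for each channel j, the cells of each range
//
// Illustrative only: source offset of element (cell c, channel j) of a transposed V tensor,
// matching the (range.first + j * kv_size) * v_size_el arithmetic above.
static size_t v_trans_offset(uint32_t c, uint32_t j, uint32_t kv_size, size_t v_size_el) {
    return ((size_t) c + (size_t) j * kv_size) * v_size_el;
}
// e.g. kv_size = 8, v_size_el = 2 bytes, c = 3, j = 5  ->  (3 + 5*8) * 2 = 86 bytes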
1163
+ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
1164
+ if (dest_seq_id != -1) {
1165
+ // single sequence
1166
+
1167
+ seq_rm(dest_seq_id, -1, -1);
1168
+
1169
+ llama_sbatch sbatch;
1170
+ llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1171
+
1172
+ batch.n_tokens = cell_count;
1173
+ batch.n_seq_tokens = cell_count;
1174
+ batch.n_seqs = 1;
1175
+
1176
+ for (uint32_t i = 0; i < cell_count; ++i) {
1177
+ llama_pos pos;
1178
+ uint32_t n_seq_id;
1179
+
1180
+ io.read_to(&pos, sizeof(pos));
1181
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
1182
+
1183
+ if (n_seq_id != 0) {
1184
+ LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1185
+ return false;
1186
+ }
1187
+
1188
+ batch.pos[i] = pos;
1189
+ }
1190
+ batch.n_seq_id[0] = 1;
1191
+ batch.seq_id[0] = &dest_seq_id;
1192
+ if (!find_slot(batch)) {
1193
+ LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1194
+ return false;
1195
+ }
1196
+ commit();
1197
+
1198
+ // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1199
+ // Assume that this is one contiguous block of cells
1200
+ GGML_ASSERT(head + cell_count <= size);
1201
+ GGML_ASSERT(cells[head].pos == batch.pos[0]);
1202
+ GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
1203
+ GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
1204
+ GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
1205
+ } else {
1206
+ // whole KV cache restore
1207
+
1208
+ if (cell_count > size) {
1209
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1210
+ return false;
1211
+ }
1212
+
1213
+ clear();
1214
+
1215
+ for (uint32_t i = 0; i < cell_count; ++i) {
1216
+ kv_cell & cell = cells[i];
1217
+
1218
+ llama_pos pos;
1219
+ uint32_t n_seq_id;
1220
+
1221
+ io.read_to(&pos, sizeof(pos));
1222
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
1223
+
1224
+ cell.pos = pos;
1225
+
1226
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
1227
+ llama_seq_id seq_id;
1228
+ io.read_to(&seq_id, sizeof(seq_id));
1229
+
1230
+ // TODO: llama_kv_cache_unified should have a notion of max sequences
1231
+ //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
1232
+ if (seq_id < 0) {
1233
+ //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
1234
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
1235
+ return false;
1236
+ }
1237
+
1238
+ cell.seq_id.insert(seq_id);
1239
+ }
1240
+ }
1241
+
1242
+ head = 0;
1243
+ used = cell_count;
1244
+ }
1245
+
1246
+ return true;
1247
+ }
1248
+
1249
+ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1250
+ uint32_t v_trans;
1251
+ uint32_t n_layer;
1252
+ io.read_to(&v_trans, sizeof(v_trans));
1253
+ io.read_to(&n_layer, sizeof(n_layer));
1254
+
1255
+ if (n_layer != hparams.n_layer) {
1256
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
1257
+ return false;
1258
+ }
1259
+ if (cell_count > size) {
1260
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
1261
+ return false;
1262
+ }
1263
+ if (this->v_trans != (bool) v_trans) {
1264
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1265
+ return false;
1266
+ }
1267
+
1268
+ // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1269
+ for (uint32_t il = 0; il < n_layer; ++il) {
1270
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1271
+
1272
+ // Read type of key
1273
+ int32_t k_type_i_ref;
1274
+ io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1275
+ const int32_t k_type_i = (int32_t) k_l[il]->type;
1276
+ if (k_type_i != k_type_i_ref) {
1277
+ LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1278
+ return false;
1279
+ }
1280
+
1281
+ // Read row size of key
1282
+ uint64_t k_size_row_ref;
1283
+ io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1284
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
1285
+ if (k_size_row != k_size_row_ref) {
1286
+ LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1287
+ return false;
1288
+ }
1289
+
1290
+ if (cell_count) {
1291
+ // Read and set the keys for the whole cell range
1292
+ ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1293
+ }
1294
+ }
1295
+
1296
+ if (!this->v_trans) {
1297
+ for (uint32_t il = 0; il < n_layer; ++il) {
1298
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1299
+
1300
+ // Read type of value
1301
+ int32_t v_type_i_ref;
1302
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1303
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1304
+ if (v_type_i != v_type_i_ref) {
1305
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1306
+ return false;
1307
+ }
1308
+
1309
+ // Read row size of value
1310
+ uint64_t v_size_row_ref;
1311
+ io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1312
+ const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
1313
+ if (v_size_row != v_size_row_ref) {
1314
+ LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1315
+ return false;
1316
+ }
1317
+
1318
+ if (cell_count) {
1319
+ // Read and set the values for the whole cell range
1320
+ ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1321
+ }
1322
+ }
1323
+ } else {
1324
+ // For each layer, read the values for each cell (transposed)
1325
+ for (uint32_t il = 0; il < n_layer; ++il) {
1326
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1327
+
1328
+ // Read type of value
1329
+ int32_t v_type_i_ref;
1330
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1331
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1332
+ if (v_type_i != v_type_i_ref) {
1333
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1334
+ return false;
1335
+ }
1336
+
1337
+ // Read element size of value
1338
+ uint32_t v_size_el_ref;
1339
+ io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1340
+ const size_t v_size_el = ggml_type_size(v_l[il]->type);
1341
+ if (v_size_el != v_size_el_ref) {
1342
+ LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1343
+ return false;
1344
+ }
1345
+
1346
+ // Read GQA embedding size
1347
+ uint32_t n_embd_v_gqa_ref;
1348
+ io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1349
+ if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1350
+ LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1351
+ return false;
1352
+ }
1353
+
1354
+ if (cell_count) {
1355
+ // For each row in the transposed matrix, read the values for the whole cell range
1356
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1357
+ const size_t dst_offset = (head + j * size) * v_size_el;
1358
+ ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1359
+ }
1360
+ }
1361
+ }
1362
+ }
1363
+
1364
+ return true;
1365
+ }
1366
+
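For orientation (not part of the diff): before copying any tensor data, state_read_data() above rejects a blob whose header does not match the running cache. A compact checklist of what must agree:

// Checks performed by state_read_data() before any ggml_backend_tensor_set() call:
//   n_layer    == hparams.n_layer
//   cell_count <= size                  (enough cells in the running cache)
//   v_trans    == this->v_trans         (same V layout)
//   per layer: key type and key row size must match exactly
//   per layer: value type and value row size (non-transposed case), or value element
//              size and n_embd_v_gqa (transposed case), must match exactly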
1367
+ //
1368
+ // llama_kv_cache_recurrent
1369
+ //
1370
+
1371
+ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
1372
+ const llama_model & model,
1373
+ ggml_type type_k,
1374
+ ggml_type type_v,
1375
+ bool offload,
1376
+ uint32_t kv_size) : hparams(model.hparams) {
1377
+ const int32_t n_layer = hparams.n_layer;
1378
+
1379
+ LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
1380
+ __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
1381
+
1382
+ head = 0;
1383
+ size = kv_size;
1384
+ used = 0;
1385
+
1386
+ this->type_k = type_k;
1387
+ this->type_v = type_v;
1388
+
1389
+ cells.clear();
1390
+ cells.resize(kv_size);
1391
+
1392
+ // create a context for each buffer type
1393
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
1394
+ auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
1395
+ auto it = ctx_map.find(buft);
1396
+ if (it == ctx_map.end()) {
1397
+ ggml_init_params params = {
1398
+ /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
1399
+ /*.mem_buffer =*/ NULL,
1400
+ /*.no_alloc =*/ true,
1401
+ };
1402
+
1403
+ ggml_context * ctx = ggml_init(params);
1404
+ if (!ctx) {
1405
+ return nullptr;
1406
+ }
1407
+
1408
+ ctx_map[buft] = ctx;
1409
+ ctxs.emplace_back(ctx);
1410
+
1411
+ return ctx;
1412
+ }
1413
+
1414
+ return it->second;
1415
+ };
1416
+
1417
+ k_l.reserve(n_layer);
1418
+ v_l.reserve(n_layer);
1419
+
1420
+ for (int i = 0; i < n_layer; i++) {
1421
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
1422
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
1423
+
1424
+ const char * dev_name = "CPU";
1425
+
1426
+ ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
1427
+
1428
+ if (offload) {
1429
+ auto * dev = model.dev_layer(i);
1430
+ buft = ggml_backend_dev_buffer_type(dev);
1431
+
1432
+ dev_name = ggml_backend_dev_name(dev);
1433
+ }
1434
+
1435
+ LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
1436
+
1437
+ ggml_context * ctx = ctx_for_buft(buft);
1438
+ if (!ctx) {
1439
+ throw std::runtime_error("failed to create ggml context for kv cache");
1440
+ }
1441
+
1442
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
1443
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
1444
+ ggml_format_name(k, "cache_k_l%d", i);
1445
+ ggml_format_name(v, "cache_v_l%d", i);
1446
+ k_l.push_back(k);
1447
+ v_l.push_back(v);
1448
+ }
1449
+
1450
+ // allocate tensors and initialize the buffers to avoid NaNs in the padding
1451
+ for (auto it : ctx_map) {
1452
+ auto * buft = it.first;
1453
+ auto * ctx = it.second;
1454
+
1455
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
1456
+ if (!buf) {
1457
+ throw std::runtime_error("failed to allocate buffer for kv cache");
1458
+ }
1459
+ ggml_backend_buffer_clear(buf, 0);
1460
+ LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
1461
+ bufs.emplace_back(buf);
1462
+ }
1463
+
1464
+ {
1465
+ const size_t memory_size_k = size_k_bytes();
1466
+ const size_t memory_size_v = size_v_bytes();
1467
+
1468
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
1469
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
1470
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
1471
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
1472
+ }
1473
+ }
1474
+
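For illustration only (not part of the diff): the ctx_for_buft lambda above is a lazy get-or-create over a map keyed by buffer type, sized for two tensors (K and V) per layer because a no_alloc context only stores tensor metadata (ggml_tensor_overhead() each). The same idiom in isolation, under those assumptions:

#include <map>

// Sketch of the get-or-create pattern used for the per-buffer-type ggml contexts.
template <typename Key, typename Value, typename Factory>
Value & get_or_create(std::map<Key, Value> & cache, const Key & key, Factory make) {
    auto it = cache.find(key);
    if (it == cache.end()) {
        it = cache.emplace(key, make()).first; // created once, reused afterwards
    }
    return it->second;
}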
1475
+ void llama_kv_cache_recurrent::clear() {
1476
+ for (int32_t i = 0; i < (int32_t) size; ++i) {
1477
+ cells[i].pos = -1;
1478
+ cells[i].seq_id.clear();
1479
+ cells[i].src = -1;
1480
+ cells[i].tail = -1;
1481
+ }
1482
+ head = 0;
1483
+ used = 0;
1484
+
1485
+ for (auto & buf : bufs) {
1486
+ ggml_backend_buffer_clear(buf.get(), 0);
1487
+ }
1488
+ }
1489
+
1490
+ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1491
+ uint32_t new_head = size;
1492
+
1493
+ if (p0 < 0) {
1494
+ p0 = 0;
1495
+ }
1496
+
1497
+ if (p1 < 0) {
1498
+ p1 = std::numeric_limits<llama_pos>::max();
1499
+ }
1500
+
1501
+ // models like Mamba or RWKV can't have a state partially erased
1502
+ if (seq_id >= (int64_t) size) {
1503
+ // could be fatal
1504
+ return false;
1505
+ }
1506
+ if (0 <= seq_id) {
1507
+ int32_t & tail_id = cells[seq_id].tail;
1508
+ if (tail_id >= 0) {
1509
+ const kv_cell & cell = cells[tail_id];
1510
+ // partial intersection is invalid
1511
+ if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1512
+ return false;
1513
+ }
1514
+ // invalidate tails which will be cleared
1515
+ if (p0 <= cell.pos && cell.pos < p1) {
1516
+ tail_id = -1;
1517
+ }
1518
+ }
1519
+ } else {
1520
+ // when seq_id is negative, the range should include everything or nothing
1521
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1522
+ return false;
1523
+ }
1524
+ }
1525
+
1526
+ for (uint32_t i = 0; i < size; ++i) {
1527
+ if (cells[i].pos >= p0 && cells[i].pos < p1) {
1528
+ if (seq_id < 0) {
1529
+ cells[i].seq_id.clear();
1530
+ } else if (cells[i].has_seq_id(seq_id)) {
1531
+ cells[i].seq_id.erase(seq_id);
1532
+ } else {
1533
+ continue;
1534
+ }
1535
+ if (cells[i].is_empty()) {
1536
+ // keep count of the number of used cells
1537
+ if (cells[i].pos >= 0) {
1538
+ used--;
1539
+ }
1540
+ cells[i].pos = -1;
1541
+ cells[i].src = -1;
1542
+ if (new_head == size) {
1543
+ new_head = i;
1544
+ }
1545
+ }
1546
+ }
1547
+ }
1548
+
1549
+ // If we freed up a slot, set head to it so searching can start there.
1550
+ if (new_head != size && new_head < head) {
1551
+ head = new_head;
1552
+ }
1553
+
1554
+ return true;
1555
+ }
1556
+
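For illustration only (not part of the diff): because a recurrent cell stores one rolled-up state per sequence, seq_rm() above can only drop a sequence's state entirely; a bounded range that would cut through the stored position is rejected. Assuming a populated `cache` object, a hypothetical caller sees:

// Calling-contract sketch, not actual llama.cpp usage code.
bool whole = cache.seq_rm(/*seq_id=*/0, /*p0=*/-1, /*p1=*/-1); // whole sequence: supported
bool part  = cache.seq_rm(/*seq_id=*/0, /*p0=*/10, /*p1=*/20); // bounded range: returns false once it
                                                               // reaches the stored tail position, since
                                                               // the state cannot be partially erased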
1557
+ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
1558
+ if (seq_id_src == seq_id_dst) {
1559
+ return;
1560
+ }
1561
+
1562
+ if (p0 < 0) {
1563
+ p0 = 0;
1564
+ }
1565
+
1566
+ if (p1 < 0) {
1567
+ p1 = std::numeric_limits<llama_pos>::max();
1568
+ }
1569
+
1570
+ if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
1571
+ kv_cell & tail_src = cells[seq_id_src];
1572
+ kv_cell & tail_dst = cells[seq_id_dst];
1573
+ if (tail_dst.tail >= 0) {
1574
+ // clear destination seq_id if it wasn't empty
1575
+ kv_cell & cell_dst = cells[tail_dst.tail];
1576
+
1577
+ cell_dst.seq_id.erase(seq_id_dst);
1578
+ tail_dst.tail = -1;
1579
+ if (cell_dst.seq_id.empty()) {
1580
+ cell_dst.pos = -1;
1581
+ cell_dst.src = -1;
1582
+ used -= 1;
1583
+ }
1584
+ }
1585
+ if (tail_src.tail >= 0) {
1586
+ kv_cell & cell_src = cells[tail_src.tail];
1587
+
1588
+ cell_src.seq_id.insert(seq_id_dst);
1589
+ tail_dst.tail = tail_src.tail;
1590
+ }
1591
+ }
1592
+ }
1593
+
1594
+ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
1595
+ uint32_t new_head = size;
1596
+
1597
+ for (uint32_t i = 0; i < size; ++i) {
1598
+ if ((llama_seq_id) i != seq_id) {
1599
+ cells[i].tail = -1;
1600
+ }
1601
+
1602
+ if (!cells[i].has_seq_id(seq_id)) {
1603
+ if (cells[i].pos >= 0) {
1604
+ used--;
1605
+ }
1606
+
1607
+ cells[i].pos = -1;
1608
+ cells[i].src = -1;
1609
+ cells[i].seq_id.clear();
1610
+
1611
+ if (new_head == size) {
1612
+ new_head = i;
1613
+ }
1614
+ } else {
1615
+ cells[i].seq_id.clear();
1616
+ cells[i].seq_id.insert(seq_id);
1617
+ }
1618
+ }
1619
+
1620
+ // If we freed up a slot, set head to it so searching can start there.
1621
+ if (new_head != size && new_head < head) {
1622
+ head = new_head;
1623
+ }
1624
+ }
1625
+
1626
+ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
1627
+ if (delta == 0) {
1628
+ return;
1629
+ }
1630
+
1631
+ if (p0 < 0) {
1632
+ p0 = 0;
1633
+ }
1634
+
1635
+ if (p1 < 0) {
1636
+ p1 = std::numeric_limits<llama_pos>::max();
1637
+ }
1638
+
1639
+ // If there is no range then return early to avoid looping over the cache.
1640
+ if (p0 == p1) {
1641
+ return;
1642
+ }
1643
+
1644
+ // for Mamba-like or RWKV models, only the pos needs to be shifted
1645
+ if (0 <= seq_id && seq_id < (int64_t) size) {
1646
+ const int32_t tail_id = cells[seq_id].tail;
1647
+ if (tail_id >= 0) {
1648
+ kv_cell & cell = cells[tail_id];
1649
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
1650
+ cell.pos += delta;
1651
+ }
1652
+ }
1653
+ }
1654
+ }
1655
+
1656
+ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
1657
+ if (d == 1) {
1658
+ return;
1659
+ }
1660
+
1661
+ if (p0 < 0) {
1662
+ p0 = 0;
1663
+ }
1664
+
1665
+ if (p1 < 0) {
1666
+ p1 = std::numeric_limits<llama_pos>::max();
1667
+ }
1668
+
1669
+ // If there is no range then return early to avoid looping over the cache.
1670
+ if (p0 == p1) {
1671
+ return;
1672
+ }
1673
+
1674
+ // for Mamba-like or RWKV models, only the pos needs to be changed
1675
+ if (0 <= seq_id && seq_id < (int64_t) size) {
1676
+ const int32_t tail_id = cells[seq_id].tail;
1677
+ if (tail_id >= 0) {
1678
+ kv_cell & cell = cells[tail_id];
1679
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
1680
+ cell.pos /= d;
1681
+ }
1682
+ }
1683
+ }
1684
+ }
1685
+
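For illustration only (not part of the diff): position edits on this cache (seq_add and seq_div above) touch only the tail cell of the sequence, since that single cell carries the whole recurrent state. A hypothetical context-shift call, with n_keep and n_discard chosen by the caller, therefore reduces to:

// Sketch: shifting a recurrent sequence left by n_discard positions only rewrites
// the tail cell's pos; there are no per-token cells to move.
cache.seq_add(/*seq_id=*/0, /*p0=*/n_keep, /*p1=*/-1, /*delta=*/-n_discard);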
1686
+ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
1687
+ llama_pos result = 0;
1688
+
1689
+ for (uint32_t i = 0; i < size; ++i) {
1690
+ if (cells[i].has_seq_id(seq_id)) {
1691
+ result = std::max(result, cells[i].pos);
1692
+ }
1693
+ }
1694
+
1695
+ return result;
1696
+ }
1697
+
1698
+ void llama_kv_cache_recurrent::restore() {
1699
+ if (pending.ranges.empty()) {
1700
+ return;
1701
+ }
1702
+
1703
+ seq_rm(-1, -1, -1);
1704
+ }
687
1705
 
688
- return 0;
1706
+ void llama_kv_cache_recurrent::commit() {
1707
+ pending.ranges.clear();
689
1708
  }
690
1709
 
691
- size_t llama_kv_cache_unified::size_k_bytes() const {
692
- size_t size_k_bytes = 0;
1710
+ bool llama_kv_cache_recurrent::update(llama_context & lctx) {
1711
+ GGML_UNUSED(lctx);
1712
+ return false;
1713
+ }
693
1714
 
694
- for (const auto & k : k_l) {
695
- size_k_bytes += ggml_nbytes(k);
696
- }
1715
+ void llama_kv_cache_recurrent::defrag_sched(float thold) {
1716
+ GGML_UNUSED(thold);
1717
+ // noop
1718
+ }
697
1719
 
698
- return size_k_bytes;
1720
+ void llama_kv_cache_recurrent::set_full() {
1721
+ n = size;
1722
+ head = 0;
699
1723
  }
700
1724
 
701
- size_t llama_kv_cache_unified::size_v_bytes() const {
702
- size_t size_v_bytes = 0;
1725
+ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
1726
+ const llama_batch & batch,
1727
+ bool logits_all) {
1728
+ return llama_sbatch(batch, hparams.n_embd, false, logits_all);
1729
+ }
703
1730
 
704
- for (const auto & v : v_l) {
705
- size_v_bytes += ggml_nbytes(v);
1731
+ llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1732
+ if (embd_pooled) {
1733
+ // Pooled embeddings cannot be split across ubatches (yet)
1734
+ return sbatch.split_seq(n_ubatch);
706
1735
  }
707
1736
 
708
- return size_v_bytes;
1737
+ return sbatch.split_equal(n_ubatch);
709
1738
  }
710
1739
 
711
- bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
712
- const uint32_t n_layer = hparams.n_layer;
1740
+ bool llama_kv_cache_recurrent::find_slot(
1741
+ const llama_ubatch & ubatch) {
1742
+ const uint32_t n_tokens = ubatch.n_tokens;
1743
+ const uint32_t n_seqs = ubatch.n_seqs;
713
1744
 
714
- const uint32_t n_kv = cell_max();
715
- const uint32_t n_used = used;
1745
+ const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
716
1746
 
717
- assert(n_used <= n_kv);
1747
+ // if we have enough unused cells before the current head ->
1748
+ // better to start searching from the beginning of the cache, hoping to fill it
1749
+ if (head > used + 2*n_tokens) {
1750
+ head = 0;
1751
+ }
718
1752
 
719
- //const int64_t t_start = ggml_time_us();
1753
+ // For recurrent state architectures (like Mamba or RWKV),
1754
+ // each cache cell can store the state for a whole sequence.
1755
+ // A slot should always be contiguous.
720
1756
 
721
- // number of cells moved
722
- uint32_t n_moves = 0;
1757
+ // can only process batches with an equal number of new tokens in each sequence
1758
+ GGML_ASSERT(ubatch.equal_seqs);
723
1759
 
724
- // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
725
- // - source view, destination view, copy operation
726
- // - x2 for keys and values
727
- //const uint32_t max_moves = max_nodes()/(6*n_layer);
728
- // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
729
- const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
1760
+ int32_t min = size - 1;
1761
+ int32_t max = 0;
730
1762
 
731
- // determine which KV cells to move where
732
- //
733
- // cell i moves to ids[i]
734
- //
735
- // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
736
- //
737
- auto & ids = defrag_info.ids;
1763
+ // everything should fit if all seq_ids are smaller than the max
1764
+ for (uint32_t s = 0; s < n_seqs; ++s) {
1765
+ const uint32_t n_seq_id = ubatch.n_seq_id[s];
1766
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
1767
+ const llama_seq_id seq_id = ubatch.seq_id[s][j];
738
1768
 
739
- ids.clear();
740
- ids.resize(n_kv, n_kv);
1769
+ if (seq_id < 0 || (uint32_t) seq_id >= size) {
1770
+ // too big seq_id
1771
+ // TODO: would it be possible to resize the cache instead?
1772
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
1773
+ return false;
1774
+ }
1775
+ if (j > 0) {
1776
+ kv_cell & seq = cells[seq_id];
1777
+ if (seq.tail >= 0) {
1778
+ kv_cell & cell = cells[seq.tail];
1779
+ // clear cells from seq_ids that become shared
1780
+ // (should not normally happen, but let's handle it anyway)
1781
+ cell.seq_id.erase(seq_id);
1782
+ seq.tail = -1;
1783
+ if (cell.seq_id.empty()) {
1784
+ cell.pos = -1;
1785
+ cell.src = -1;
1786
+ used -= 1;
1787
+ }
1788
+ }
1789
+ }
1790
+ }
1791
+ }
741
1792
 
742
- for (uint32_t i0 = 0; i0 < n_used; ++i0) {
743
- const auto & cell0 = cells[i0];
1793
+ #ifndef NDEBUG
1794
+ {
1795
+ std::vector<int32_t> tails_verif;
1796
+ tails_verif.assign(size, -1);
1797
+ for (uint32_t i = 0; i < size; ++i) {
1798
+ kv_cell & cell = cells[i];
1799
+ for (llama_seq_id seq_id : cell.seq_id) {
1800
+ if (tails_verif[seq_id] != -1) {
1801
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
1802
+ }
1803
+ tails_verif[seq_id] = i;
1804
+ }
1805
+ }
1806
+ for (uint32_t i = 0; i < size; ++i) {
1807
+ if (tails_verif[i] != cells[i].tail) {
1808
+ LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
1809
+ }
1810
+ }
1811
+ }
1812
+ #endif
744
1813
 
745
- if (!cell0.is_empty()) {
746
- ids[i0] = i0;
1814
+ // find next empty cell
1815
+ uint32_t next_empty_cell = head;
747
1816
 
748
- continue;
1817
+ for (uint32_t i = 0; i < size; ++i) {
1818
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
1819
+ kv_cell & cell = cells[next_empty_cell];
1820
+ if (cell.is_empty()) { break; }
1821
+ next_empty_cell += 1;
1822
+ }
1823
+
1824
+ // find usable cell range
1825
+ for (uint32_t s = 0; s < n_seqs; ++s) {
1826
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
1827
+ kv_cell & seq_meta = cells[seq_id];
1828
+ bool has_cell = false;
1829
+ if (seq_meta.tail >= 0) {
1830
+ kv_cell & cell = cells[seq_meta.tail];
1831
+ GGML_ASSERT(cell.has_seq_id(seq_id));
1832
+ // does this seq_id "own" the cell?
1833
+ if (cell.seq_id.size() == 1) { has_cell = true; }
1834
+ }
1835
+ if (!has_cell) {
1836
+ kv_cell & empty_cell = cells[next_empty_cell];
1837
+ GGML_ASSERT(empty_cell.is_empty());
1838
+ // copy old tail into the empty cell
1839
+ if (seq_meta.tail >= 0) {
1840
+ kv_cell & orig_cell = cells[seq_meta.tail];
1841
+ empty_cell.pos = orig_cell.pos;
1842
+ empty_cell.src = orig_cell.src;
1843
+ orig_cell.seq_id.erase(seq_id);
1844
+ empty_cell.seq_id.insert(seq_id); // will be overwritten
1845
+ }
1846
+ seq_meta.tail = next_empty_cell;
1847
+ // find next empty cell
1848
+ if (s + 1 < n_seqs) {
1849
+ next_empty_cell += 1;
1850
+ for (uint32_t i = 0; i < size; ++i) {
1851
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
1852
+ kv_cell & cell = cells[next_empty_cell];
1853
+ if (cell.is_empty()) { break; }
1854
+ next_empty_cell += 1;
1855
+ }
1856
+ }
749
1857
  }
1858
+ if (min > seq_meta.tail) { min = seq_meta.tail; }
1859
+ if (max < seq_meta.tail) { max = seq_meta.tail; }
1860
+ }
750
1861
 
751
- // found a hole - fill it with data from the end of the cache
1862
+ // gather and re-order
1863
+ for (uint32_t s = 0; s < n_seqs; ++s) {
1864
+ int32_t dst_id = s + min;
1865
+ int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
1866
+ if (dst_id != src_id) {
1867
+ kv_cell & dst_cell = cells[dst_id];
1868
+ kv_cell & src_cell = cells[src_id];
752
1869
 
753
- uint32_t nh = 1;
1870
+ std::swap(dst_cell.pos, src_cell.pos);
1871
+ std::swap(dst_cell.src, src_cell.src);
1872
+ std::swap(dst_cell.seq_id, src_cell.seq_id);
754
1873
 
755
- // determine the size of the hole
756
- while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
757
- nh++;
1874
+ // swap tails (assuming they NEVER overlap)
1875
+ for (const llama_seq_id seq_id : src_cell.seq_id) {
1876
+ cells[seq_id].tail = src_id;
1877
+ }
1878
+ for (const llama_seq_id seq_id : dst_cell.seq_id) {
1879
+ cells[seq_id].tail = dst_id;
1880
+ }
758
1881
  }
1882
+ }
759
1883
 
760
- uint32_t nf = 0;
761
- uint32_t is = n_kv - 1;
1884
+ // update the pos of the used seqs
1885
+ for (uint32_t s = 0; s < n_seqs; ++s) {
1886
+ const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
1887
+ int32_t cell_id = s + min;
1888
+ kv_cell & cell = cells[cell_id];
762
1889
 
763
- // starting from the end, find nh non-empty cells
764
- for (; is > i0; --is) {
765
- const auto & cell1 = cells[is];
1890
+ if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
1891
+ // What should happen when the pos backtracks or skips a value?
1892
+ // Clearing the state mid-batch would require special-casing which isn't done.
1893
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
1894
+ __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
1895
+ }
1896
+ cell.pos = last_pos;
1897
+ cell.seq_id.clear();
1898
+ for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
1899
+ const llama_seq_id seq_id = ubatch.seq_id[s][j];
1900
+ cell.seq_id.insert(seq_id);
1901
+ cells[seq_id].tail = cell_id;
1902
+ }
1903
+ }
766
1904
 
767
- if (cell1.is_empty() || ids[is] != n_kv) {
768
- continue;
769
- }
1905
+ // allow getting the range of used cells, from head to head + n
1906
+ head = min;
1907
+ n = max - min + 1;
1908
+ used = std::count_if(cells.begin(), cells.end(),
1909
+ [](const kv_cell & cell){ return !cell.is_empty(); });
770
1910
 
771
- // non-empty cell which is not yet moved
772
- nf++;
1911
+ // sanity check
1912
+ return n >= n_seqs;
1913
+ }
773
1914
 
774
- if (nf == nh) {
775
- break;
776
- }
777
- }
1915
+ int32_t llama_kv_cache_recurrent::get_n_tokens() const {
1916
+ int32_t result = 0;
778
1917
 
779
- // this can only happen if `n_used` is not accurate, which would be a bug
780
- GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
1918
+ for (uint32_t i = 0; i < size; i++) {
1919
+ result += cells[i].seq_id.size();
1920
+ }
781
1921
 
782
- nf = 0;
1922
+ return result;
1923
+ }
783
1924
 
784
- uint32_t i1 = is;
1925
+ int32_t llama_kv_cache_recurrent::get_used_cells() const {
1926
+ return used;
1927
+ }
785
1928
 
786
- // are we moving a continuous block of memory?
787
- bool cont = false;
1929
+ llama_pos llama_kv_cache_recurrent::get_pos_max() const {
1930
+ llama_pos pos_max = -1;
1931
+ for (const auto & cell : cells) {
1932
+ pos_max = std::max(pos_max, cell.pos);
1933
+ }
788
1934
 
789
- // should we stop searching for the next move?
790
- bool stop = false;
1935
+ return pos_max;
1936
+ }
791
1937
 
792
- // go back and move the nf cells to the hole
793
- for (; i1 < n_kv; ++i1) {
794
- auto & cell1 = cells[i1];
1938
+ bool llama_kv_cache_recurrent::get_can_shift() const {
1939
+ return false;
1940
+ }
795
1941
 
796
- if (cell1.is_empty() || ids[i1] != n_kv) {
797
- if (n_moves == max_moves) {
798
- stop = true;
799
- break;
800
- }
1942
+ int32_t llama_kv_cache_recurrent::s_copy(int i) const {
1943
+ const uint32_t cell_id = i + head;
801
1944
 
802
- cont = false;
803
- continue;
804
- }
1945
+ //////////////////////////////////////////////
1946
+ // TODO: this should not mutate the KV cache !
1947
+ kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
805
1948
 
806
- // this cell goes to (i0 + nf)
807
- ids[i1] = i0 + nf;
1949
+ // prevent out-of-bound sources
1950
+ if (cell.src < 0 || (uint32_t) cell.src >= size) {
1951
+ cell.src = cell_id;
1952
+ }
808
1953
 
809
- // move the cell meta data
810
- cells[i0 + nf] = cell1;
1954
+ int32_t res = cell.src;
811
1955
 
812
- // clear the old cell and move the head there
813
- cell1 = llama_kv_cell();
814
- head = n_used;
1956
+ // TODO: do not mutate the KV cache
1957
+ // ensure copy only happens once
1958
+ if (cell.src != (int32_t) cell_id) {
1959
+ cell.src = cell_id;
1960
+ }
815
1961
 
816
- if (!cont) {
817
- n_moves++;
818
- cont = true;
819
- }
1962
+ return res;
1963
+ }
820
1964
 
821
- nf++;
1965
+ float llama_kv_cache_recurrent::s_mask(int i) const {
1966
+ const uint32_t cell_id = i + head;
822
1967
 
823
- if (nf == nh) {
824
- break;
825
- }
826
- }
1968
+ //////////////////////////////////////////////
1969
+ // TODO: this should not mutate the KV cache !
1970
+ kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
827
1971
 
828
- if (stop || n_moves == max_moves) {
829
- break;
1972
+ float res = (float) (cell.src >= 0);
1973
+
1974
+ // only clear once
1975
+ if (cell.src < 0) {
1976
+ cell.src = cell_id;
1977
+ }
1978
+
1979
+ return res;
1980
+ }
1981
+
1982
+ uint32_t llama_kv_cache_recurrent::cell_max() const {
1983
+ for (uint32_t i = size; i > 0; --i) {
1984
+ const kv_cell & cell = cells[i - 1];
1985
+
1986
+ if (cell.pos >= 0 && !cell.is_empty()) {
1987
+ return i;
830
1988
  }
1989
+ }
831
1990
 
832
- //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
1991
+ return 0;
1992
+ }
833
1993
 
834
- i0 += nh - 1;
1994
+ size_t llama_kv_cache_recurrent::total_size() const {
1995
+ size_t size = 0;
1996
+ for (const auto & buf : bufs) {
1997
+ size += ggml_backend_buffer_get_size(buf.get());
835
1998
  }
836
1999
 
837
- if (n_moves == 0) {
838
- return false;
2000
+ return size;
2001
+ }
2002
+
2003
+ size_t llama_kv_cache_recurrent::size_k_bytes() const {
2004
+ size_t size_k_bytes = 0;
2005
+
2006
+ for (const auto & k : k_l) {
2007
+ size_k_bytes += ggml_nbytes(k);
839
2008
  }
840
2009
 
841
- LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
2010
+ return size_k_bytes;
2011
+ }
2012
+
2013
+ size_t llama_kv_cache_recurrent::size_v_bytes() const {
2014
+ size_t size_v_bytes = 0;
842
2015
 
843
- LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer);
2016
+ for (const auto & v : v_l) {
2017
+ size_v_bytes += ggml_nbytes(v);
2018
+ }
844
2019
 
845
- return true;
2020
+ return size_v_bytes;
846
2021
  }
847
2022
 
848
- void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
2023
+ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
849
2024
  std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
850
2025
  uint32_t cell_count = 0;
851
2026
 
@@ -883,7 +2058,7 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
883
2058
  state_write_data(io, cell_ranges);
884
2059
  }
885
2060
 
886
- void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
2061
+ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
887
2062
  uint32_t cell_count;
888
2063
  io.read_to(&cell_count, sizeof(cell_count));
889
2064
 
@@ -901,7 +2076,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
901
2076
  }
902
2077
  }
903
2078
 
904
- void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
2079
+ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
905
2080
  for (const auto & range : cell_ranges) {
906
2081
  for (uint32_t i = range.first; i < range.second; ++i) {
907
2082
  const auto & cell = cells[i];
@@ -920,8 +2095,8 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
920
2095
  }
921
2096
  }
922
2097
 
923
- void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
924
- const uint32_t v_trans = this->v_trans ? 1 : 0;
2098
+ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
2099
+ const uint32_t v_trans = 0;
925
2100
  const uint32_t n_layer = hparams.n_layer;
926
2101
 
927
2102
  io.write(&v_trans, sizeof(v_trans));
@@ -1000,7 +2175,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1000
2175
  }
1001
2176
  }
1002
2177
 
1003
- bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
2178
+ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
1004
2179
  if (dest_seq_id != -1) {
1005
2180
  // single sequence
1006
2181
 
@@ -1033,6 +2208,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1033
2208
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1034
2209
  return false;
1035
2210
  }
2211
+ commit();
1036
2212
 
1037
2213
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1038
2214
  // Assume that this is one contiguous block of cells
@@ -1052,7 +2228,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1052
2228
  clear();
1053
2229
 
1054
2230
  for (uint32_t i = 0; i < cell_count; ++i) {
1055
- llama_kv_cell & cell = cells[i];
2231
+ kv_cell & cell = cells[i];
1056
2232
 
1057
2233
  llama_pos pos;
1058
2234
  uint32_t n_seq_id;
@@ -1066,7 +2242,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1066
2242
  llama_seq_id seq_id;
1067
2243
  io.read_to(&seq_id, sizeof(seq_id));
1068
2244
 
1069
- // TODO: llama_kv_cache_unified should have a notion of max sequences
2245
+ // TODO: llama_kv_cache_recurrent should have a notion of max sequences
1070
2246
  //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
1071
2247
  if (seq_id < 0) {
1072
2248
  //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -1076,14 +2252,12 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1076
2252
 
1077
2253
  cell.seq_id.insert(seq_id);
1078
2254
 
1079
- if (recurrent) {
1080
- int32_t & tail = cells[seq_id].tail;
1081
- if (tail != -1) {
1082
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
1083
- return false;
1084
- }
1085
- tail = i;
2255
+ int32_t & tail = cells[seq_id].tail;
2256
+ if (tail != -1) {
2257
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
2258
+ return false;
1086
2259
  }
2260
+ tail = i;
1087
2261
  }
1088
2262
  }
1089
2263
 
@@ -1091,18 +2265,16 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1091
2265
  used = cell_count;
1092
2266
  }
1093
2267
 
1094
- if (recurrent) {
1095
- for (uint32_t i = 0; i < cell_count; ++i) {
1096
- uint32_t cell_id = head + i;
1097
- // make sure the recurrent states will keep their restored state
1098
- cells[cell_id].src = cell_id;
1099
- }
2268
+ for (uint32_t i = 0; i < cell_count; ++i) {
2269
+ uint32_t cell_id = head + i;
2270
+ // make sure the recurrent states will keep their restored state
2271
+ cells[cell_id].src = cell_id;
1100
2272
  }
1101
2273
 
1102
2274
  return true;
1103
2275
  }
1104
2276
 
1105
- bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
2277
+ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1106
2278
  uint32_t v_trans;
1107
2279
  uint32_t n_layer;
1108
2280
  io.read_to(&v_trans, sizeof(v_trans));
@@ -1116,7 +2288,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1116
2288
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
1117
2289
  return false;
1118
2290
  }
1119
- if (v_trans != (bool) v_trans) {
2291
+ if (false != (bool) v_trans) {
1120
2292
  LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1121
2293
  return false;
1122
2294
  }
@@ -1220,117 +2392,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1220
2392
  return true;
1221
2393
  }
1222
2394
 
1223
- //
1224
- // interface implementation
1225
- //
1226
-
1227
- int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
1228
- if (!kv) {
1229
- return 0;
1230
- }
1231
-
1232
- return kv->get_n_tokens();
1233
- }
1234
-
1235
- int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
1236
- if (!kv) {
1237
- return 0;
1238
- }
1239
-
1240
- return kv->get_used_cells();
1241
- }
1242
-
1243
- void llama_kv_cache_clear(llama_kv_cache * kv) {
1244
- if (!kv) {
1245
- return;
1246
- }
1247
-
1248
- kv->clear();
1249
- }
1250
-
1251
- bool llama_kv_cache_seq_rm(
1252
- llama_kv_cache * kv,
1253
- llama_seq_id seq_id,
1254
- llama_pos p0,
1255
- llama_pos p1) {
1256
- if (!kv) {
1257
- return true;
1258
- }
1259
-
1260
- return kv->seq_rm(seq_id, p0, p1);
1261
- }
1262
-
1263
- void llama_kv_cache_seq_cp(
1264
- llama_kv_cache * kv,
1265
- llama_seq_id seq_id_src,
1266
- llama_seq_id seq_id_dst,
1267
- llama_pos p0,
1268
- llama_pos p1) {
1269
- if (!kv) {
1270
- return;
1271
- }
1272
-
1273
- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
1274
- }
1275
-
1276
- void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
1277
- if (!kv) {
1278
- return;
1279
- }
1280
-
1281
- kv->seq_keep(seq_id);
1282
- }
1283
-
1284
- void llama_kv_cache_seq_add(
1285
- llama_kv_cache * kv,
1286
- llama_seq_id seq_id,
1287
- llama_pos p0,
1288
- llama_pos p1,
1289
- llama_pos delta) {
1290
- if (!kv) {
1291
- return;
1292
- }
1293
-
1294
- kv->seq_add(seq_id, p0, p1, delta);
1295
- }
1296
-
1297
- void llama_kv_cache_seq_div(
1298
- llama_kv_cache * kv,
1299
- llama_seq_id seq_id,
1300
- llama_pos p0,
1301
- llama_pos p1,
1302
- int d) {
1303
- if (!kv) {
1304
- return;
1305
- }
1306
-
1307
- kv->seq_div(seq_id, p0, p1, d);
1308
- }
1309
-
1310
- llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
1311
- if (!kv) {
1312
- return 0;
1313
- }
1314
-
1315
- return kv->seq_pos_max(seq_id);
1316
- }
1317
-
1318
- void llama_kv_cache_defrag(llama_kv_cache * kv) {
1319
- if (!kv) {
1320
- return;
1321
- }
1322
-
1323
- kv->defrag();
1324
- }
1325
-
1326
- bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
1327
- if (!kv) {
1328
- return false;
1329
- }
1330
-
1331
- return kv->get_can_shift();
1332
- }
1333
-
1334
2395
  //
1335
2396
  // kv cache view
1336
2397
  //
@@ -1340,7 +2401,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t
1340
2401
  /*.n_cells = */ 0,
1341
2402
  /*.n_seq_max = */ n_seq_max,
1342
2403
  /*.token_count = */ 0,
1343
- /*.used_cells = */ llama_kv_cache_used_cells(&kv),
2404
+ /*.used_cells = */ kv.get_used_cells(),
1344
2405
  /*.max_contiguous = */ 0,
1345
2406
  /*.max_contiguous_idx = */ -1,
1346
2407
  /*.cells = */ nullptr,
@@ -1379,7 +2440,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache
1379
2440
  view->cells_sequences = (llama_seq_id *)p;
1380
2441
  }
1381
2442
 
1382
- const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
2443
+ const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
1383
2444
  llama_kv_cache_view_cell * c_curr = view->cells;
1384
2445
  llama_seq_id * cs_curr = view->cells_sequences;
1385
2446
  int32_t used_cells = 0;