whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,8 +1,7 @@
1
1
  #include "llama-kv-cache.h"
2
2
 
3
3
  #include "llama-impl.h"
4
- #include "llama-batch.h"
5
- #include "llama-cparams.h"
4
+ #include "llama-io.h"
6
5
  #include "llama-model.h"
7
6
  #include "llama-context.h"
8
7
 
@@ -14,38 +13,37 @@
14
13
  #include <stdexcept>
15
14
 
16
15
  //
17
- // llama_kv_cache_unified
16
+ // llama_kv_cache
18
17
  //
19
18
 
20
- uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
21
- // the FA kernels require padding to avoid extra runtime boundary checks
22
- return cparams.flash_attn ? 256u : 32u;
23
- }
24
-
25
- llama_kv_cache_unified::llama_kv_cache_unified(
26
- const llama_model & model,
27
- layer_filter_cb && filter,
28
- ggml_type type_k,
29
- ggml_type type_v,
30
- bool v_trans,
31
- bool offload,
32
- uint32_t kv_size,
33
- uint32_t n_seq_max,
34
- uint32_t n_pad,
35
- uint32_t n_swa,
36
- llama_swa_type swa_type) :
19
+ llama_kv_cache::llama_kv_cache(
20
+ const llama_model & model,
21
+ ggml_type type_k,
22
+ ggml_type type_v,
23
+ bool v_trans,
24
+ bool offload,
25
+ bool unified,
26
+ uint32_t kv_size,
27
+ uint32_t n_seq_max,
28
+ uint32_t n_pad,
29
+ uint32_t n_swa,
30
+ llama_swa_type swa_type,
31
+ const layer_filter_cb & filter,
32
+ const layer_reuse_cb & reuse) :
37
33
  model(model), hparams(model.hparams), v_trans(v_trans),
38
- n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
34
+ n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
39
35
 
40
36
  GGML_ASSERT(kv_size % n_pad == 0);
41
37
 
38
+ const uint32_t n_layer_kv = hparams.n_layer_kv();
39
+
42
40
  // create a context for each buffer type
43
41
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
44
42
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
45
43
  auto it = ctx_map.find(buft);
46
44
  if (it == ctx_map.end()) {
47
45
  ggml_init_params params = {
48
- /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
46
+ /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
49
47
  /*.mem_buffer =*/ NULL,
50
48
  /*.no_alloc =*/ true,
51
49
  };
@@ -64,18 +62,48 @@ llama_kv_cache_unified::llama_kv_cache_unified(
64
62
  return it->second;
65
63
  };
66
64
 
67
- head = 0;
65
+ GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
66
+
67
+ v_heads.resize(n_stream);
68
+ for (uint32_t s = 0; s < n_stream; ++s) {
69
+ v_heads[s] = 0;
70
+ }
71
+
72
+ v_cells.resize(n_stream);
73
+ for (uint32_t s = 0; s < n_stream; ++s) {
74
+ v_cells[s].resize(kv_size);
75
+ }
76
+
77
+ // by default, all sequence ids are mapped to the 0th stream
78
+ seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
79
+
80
+ if (n_stream > 1) {
81
+ seq_to_stream.resize(n_stream, 0);
82
+ for (uint32_t s = 0; s < n_stream; ++s) {
83
+ seq_to_stream[s] = s;
84
+ }
85
+ }
68
86
 
69
- cells.resize(kv_size);
87
+ // [TAG_V_CACHE_VARIABLE]
88
+ if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
89
+ LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
90
+ __func__, hparams.n_embd_v_gqa_max());
91
+ }
70
92
 
71
93
  for (uint32_t il = 0; il < hparams.n_layer; il++) {
94
+ if (!hparams.has_kv(il)) {
95
+ LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
96
+ continue;
97
+ }
98
+
72
99
  if (filter && !filter(il)) {
73
- LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
100
+ LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
74
101
  continue;
75
102
  }
76
103
 
77
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
78
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
104
+ // [TAG_V_CACHE_VARIABLE]
105
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
106
+ const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
79
107
 
80
108
  const char * dev_name = "CPU";
81
109
 
@@ -98,14 +126,47 @@ llama_kv_cache_unified::llama_kv_cache_unified(
98
126
  ggml_tensor * k;
99
127
  ggml_tensor * v;
100
128
 
101
- k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
102
- v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
129
+ k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
130
+ v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
103
131
 
104
132
  ggml_format_name(k, "cache_k_l%d", il);
105
133
  ggml_format_name(v, "cache_v_l%d", il);
106
134
 
135
+ std::vector<ggml_tensor *> k_stream;
136
+ std::vector<ggml_tensor *> v_stream;
137
+
138
+ for (uint32_t s = 0; s < n_stream; ++s) {
139
+ k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
140
+ v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
141
+ }
142
+
107
143
  map_layer_ids[il] = layers.size();
108
- layers.push_back({ il, k, v });
144
+
145
+ layers.push_back({ il, k, v, k_stream, v_stream, });
146
+ }
147
+
148
+ if (reuse) {
149
+ LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
150
+
151
+ for (uint32_t il = 0; il < hparams.n_layer; il++) {
152
+ const int32_t il_reuse = reuse(il);
153
+
154
+ if (il_reuse < 0) {
155
+ LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
156
+ continue;
157
+ }
158
+
159
+ if (filter && !filter(il)) {
160
+ LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
161
+ continue;
162
+ }
163
+
164
+ GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
165
+
166
+ map_layer_ids[il] = map_layer_ids[il_reuse];
167
+
168
+ LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
169
+ }
109
170
  }
110
171
 
111
172
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -128,25 +189,31 @@ llama_kv_cache_unified::llama_kv_cache_unified(
128
189
  const size_t memory_size_k = size_k_bytes();
129
190
  const size_t memory_size_v = size_v_bytes();
130
191
 
131
- LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
132
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
192
+ LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
193
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
133
194
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
134
195
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
135
196
  }
136
- }
137
197
 
138
- void llama_kv_cache_unified::clear() {
139
- cells.reset();
198
+ const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
199
+ debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
200
+ }
140
201
 
141
- head = 0;
202
+ void llama_kv_cache::clear(bool data) {
203
+ for (uint32_t s = 0; s < n_stream; ++s) {
204
+ v_cells[s].reset();
205
+ v_heads[s] = 0;
206
+ }
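+    // note: when data == false, only the cell metadata and the heads are reset;
+    //       the tensor buffers below are left untouched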
142
207
 
143
- for (auto & buf : bufs) {
144
- ggml_backend_buffer_clear(buf.get(), 0);
208
+ if (data) {
209
+ for (auto & buf : bufs) {
210
+ ggml_backend_buffer_clear(buf.get(), 0);
211
+ }
145
212
  }
146
213
  }
147
214
 
148
- bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
149
- uint32_t new_head = cells.size();
215
+ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
216
+ GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
150
217
 
151
218
  if (p0 < 0) {
152
219
  p0 = 0;
@@ -156,51 +223,147 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
156
223
  p1 = std::numeric_limits<llama_pos>::max();
157
224
  }
158
225
 
159
- for (uint32_t i = 0; i < cells.size(); ++i) {
160
- if (!cells.pos_in(i, p0, p1)) {
161
- continue;
162
- }
226
+ if (seq_id >= 0) {
227
+ auto & cells = v_cells[seq_to_stream[seq_id]];
228
+ auto & head = v_heads[seq_to_stream[seq_id]];
163
229
 
164
- if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
165
- if (new_head == cells.size()) {
166
- new_head = i;
230
+ uint32_t new_head = cells.size();
231
+
232
+ for (uint32_t i = 0; i < cells.size(); ++i) {
233
+ if (!cells.pos_in(i, p0, p1)) {
234
+ continue;
235
+ }
236
+
237
+ if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
238
+ if (new_head == cells.size()) {
239
+ new_head = i;
240
+ }
167
241
  }
168
242
  }
169
- }
170
243
 
171
- // If we freed up a slot, set head to it so searching can start there.
172
- if (new_head != cells.size() && new_head < head) {
173
- head = new_head;
244
+ // If we freed up a slot, set head to it so searching can start there.
245
+ if (new_head != cells.size() && new_head < head) {
246
+ head = new_head;
247
+ }
248
+ } else {
249
+ // match any sequence
250
+ for (uint32_t s = 0; s < n_stream; ++s) {
251
+ auto & cells = v_cells[s];
252
+ auto & head = v_heads[s];
253
+
254
+ uint32_t new_head = cells.size();
255
+
256
+ for (uint32_t i = 0; i < cells.size(); ++i) {
257
+ if (!cells.pos_in(i, p0, p1)) {
258
+ continue;
259
+ }
260
+
261
+ cells.rm(i);
262
+
263
+ if (new_head == cells.size()) {
264
+ new_head = i;
265
+ }
266
+ }
267
+
268
+ // If we freed up a slot, set head to it so searching can start there.
269
+ if (new_head != cells.size() && new_head < head) {
270
+ head = new_head;
271
+ }
272
+ }
174
273
  }
175
274
 
176
275
  return true;
177
276
  }
178
277
 
179
- void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
180
- if (seq_id_src == seq_id_dst) {
278
+ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
279
+ GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
280
+ GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
281
+
282
+ const auto s0 = seq_to_stream[seq_id_src];
283
+ const auto s1 = seq_to_stream[seq_id_dst];
284
+
285
+ if (s0 == s1) {
286
+ // since both sequences are in the same stream, no data copy is necessary
287
+        // we just have to update the cells metadata
288
+
289
+ auto & cells = v_cells[s0];
290
+
291
+ if (seq_id_src == seq_id_dst) {
292
+ return;
293
+ }
294
+
295
+ if (p0 < 0) {
296
+ p0 = 0;
297
+ }
298
+
299
+ if (p1 < 0) {
300
+ p1 = std::numeric_limits<llama_pos>::max();
301
+ }
302
+
303
+ for (uint32_t i = 0; i < cells.size(); ++i) {
304
+ if (!cells.pos_in(i, p0, p1)) {
305
+ continue;
306
+ }
307
+
308
+ if (cells.seq_has(i, seq_id_src)) {
309
+ cells.seq_add(i, seq_id_dst);
310
+ }
311
+ }
312
+
181
313
  return;
182
314
  }
183
315
 
184
- if (p0 < 0) {
185
- p0 = 0;
316
+    // cross-stream sequence copies require copying the actual buffer data
317
+
318
+ bool is_full = true;
319
+
320
+ if (p0 > 0 && p0 + 1 < (int) get_size()) {
321
+ is_full = false;
186
322
  }
187
323
 
188
- if (p1 < 0) {
189
- p1 = std::numeric_limits<llama_pos>::max();
324
+ if (p1 > 0 && p1 + 1 < (int) get_size()) {
325
+ is_full = false;
190
326
  }
191
327
 
192
- for (uint32_t i = 0; i < cells.size(); ++i) {
193
- if (!cells.pos_in(i, p0, p1)) {
194
- continue;
195
- }
328
+ GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
329
+
330
+ // enqueue the copy operation - the buffer copy will be performed during the next update
331
+ sc_info.ssrc.push_back(s0);
332
+ sc_info.sdst.push_back(s1);
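+    // update the destination stream's cell metadata right away (positions, pending shifts and the new seq id),
+    // while the actual tensor data is copied later, during update()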
196
333
 
197
- if (cells.seq_has(i, seq_id_src)) {
198
- cells.seq_add(i, seq_id_dst);
334
+ v_cells[s1].reset();
335
+ for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
336
+ if (v_cells[s0].seq_has(i, seq_id_src)) {
337
+ llama_pos pos = v_cells[s0].pos_get(i);
338
+ llama_pos shift = v_cells[s0].get_shift(i);
339
+
340
+ if (shift != 0) {
341
+ pos -= shift;
342
+ assert(pos >= 0);
343
+ }
344
+
345
+ v_cells[s1].pos_set(i, pos);
346
+ v_cells[s1].seq_add(i, seq_id_dst);
347
+
348
+ if (shift != 0) {
349
+ v_cells[s1].pos_add(i, shift);
350
+ }
199
351
  }
200
352
  }
353
+
354
+ v_heads[s1] = v_heads[s0];
355
+
356
+ //for (uint32_t s = 0; s < n_stream; ++s) {
357
+ // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
358
+ //}
201
359
  }
202
360
 
203
- void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
361
+ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
362
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
363
+
364
+ auto & cells = v_cells[seq_to_stream[seq_id]];
365
+ auto & head = v_heads[seq_to_stream[seq_id]];
366
+
204
367
  uint32_t new_head = cells.size();
205
368
 
206
369
  for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -217,7 +380,12 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
217
380
  }
218
381
  }
219
382
 
220
- void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
383
+ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
384
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
385
+
386
+ auto & cells = v_cells[seq_to_stream[seq_id]];
387
+ auto & head = v_heads[seq_to_stream[seq_id]];
388
+
221
389
  if (shift == 0) {
222
390
  return;
223
391
  }
@@ -256,7 +424,11 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
256
424
  head = new_head != cells.size() ? new_head : 0;
257
425
  }
258
426
 
259
- void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
427
+ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
428
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
429
+
430
+ auto & cells = v_cells[seq_to_stream[seq_id]];
431
+
260
432
  if (d == 1) {
261
433
  return;
262
434
  }
@@ -285,2165 +457,1115 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
285
457
  }
286
458
  }
287
459
 
288
- llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
460
+ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
461
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
462
+
463
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
464
+
289
465
  return cells.seq_pos_min(seq_id);
290
466
  }
291
467
 
292
- llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
468
+ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
469
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
470
+
471
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
472
+
293
473
  return cells.seq_pos_max(seq_id);
294
474
  }
295
475
 
296
- void llama_kv_cache_unified::restore() {
297
- for (auto & state : recovery.states) {
298
- cells.set(state.i, state.cells);
476
+ std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
477
+ std::map<ggml_backend_buffer_type_t, size_t> ret;
478
+ for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
479
+ ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
299
480
  }
300
-
301
- recovery.clear();
481
+ return ret;
302
482
  }
303
483
 
304
- void llama_kv_cache_unified::commit() {
305
- if (recovery.states.empty()) {
306
- LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
307
- __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
308
- return;
309
- }
484
+ llama_memory_context_ptr llama_kv_cache::init_batch(
485
+ llama_batch_allocr & balloc,
486
+ uint32_t n_ubatch,
487
+ bool embd_all) {
488
+ GGML_UNUSED(embd_all);
310
489
 
311
- recovery.clear();
312
- }
490
+ do {
491
+ balloc.split_reset();
313
492
 
314
- bool llama_kv_cache_unified::update(llama_context & lctx) {
315
- bool need_reserve = false;
493
+ std::vector<llama_ubatch> ubatches;
494
+ while (true) {
495
+ auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
316
496
 
317
- auto * sched = lctx.get_sched();
497
+ if (ubatch.n_tokens == 0) {
498
+ break;
499
+ }
318
500
 
319
- if (cells.get_has_shift()) {
320
- if (!get_can_shift()) {
321
- GGML_ABORT("The current KV cache / model configuration does not support K-shift");
501
+ ubatches.push_back(std::move(ubatch)); // NOLINT
322
502
  }
323
503
 
324
- LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
504
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
505
+ // failed to find a suitable split
506
+ break;
507
+ }
325
508
 
326
- // apply K-shift if needed
327
- if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
328
- ggml_backend_sched_reset(sched);
509
+ auto sinfos = prepare(ubatches);
510
+ if (sinfos.empty()) {
511
+ break;
512
+ }
329
513
 
330
- auto * gf = lctx.graph_init();
514
+ return std::make_unique<llama_kv_cache_context>(
515
+ this, std::move(sinfos), std::move(ubatches));
516
+ } while (false);
331
517
 
332
- auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
518
+ return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
519
+ }
333
520
 
334
- ggml_backend_sched_alloc_graph(sched, gf);
521
+ llama_memory_context_ptr llama_kv_cache::init_full() {
522
+ return std::make_unique<llama_kv_cache_context>(this);
523
+ }
335
524
 
336
- res->set_inputs(nullptr);
525
+ llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
526
+ GGML_UNUSED(optimize);
337
527
 
338
- lctx.graph_compute(gf, false);
528
+ bool do_shift = get_has_shift();
339
529
 
340
- need_reserve = true;
341
- }
530
+ return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
531
+ }
342
532
 
343
- cells.reset_shift();
344
- }
533
+ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
534
+ llama_kv_cache::slot_info_vec_t res;
345
535
 
346
- if (do_defrag) {
347
- LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
536
+ struct state_t {
537
+ slot_info sinfo; // slot info for the ubatch
348
538
 
349
- if (defrag_prepare(lctx.graph_max_nodes())) {
350
- ggml_backend_sched_reset(sched);
539
+ std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
351
540
 
352
- auto * gf = lctx.graph_init();
541
+ std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
542
+ };
353
543
 
354
- auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
544
+ // remember the old state of the cells so we can restore it in the end
545
+ std::vector<state_t> states;
355
546
 
356
- ggml_backend_sched_alloc_graph(sched, gf);
547
+ bool success = true;
357
548
 
358
- res->set_inputs(nullptr);
549
+ for (const auto & ubatch : ubatches) {
550
+ // only find a suitable slot for the ubatch. don't modify the cells yet
551
+ const auto sinfo_new = find_slot(ubatch, false);
552
+ if (sinfo_new.empty()) {
553
+ success = false;
554
+ break;
555
+ }
359
556
 
360
- lctx.graph_compute(gf, false);
557
+        // remember the position that we found
558
+ res.push_back(sinfo_new);
559
+
560
+ // store the old state of the cells in the recovery stack
561
+ {
562
+ state_t state = { sinfo_new, v_heads, {} };
563
+
564
+ for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
565
+ auto & cells = v_cells[sinfo_new.strm[s]];
566
+
567
+ state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
568
+ }
361
569
 
362
- need_reserve = true;
570
+ states.push_back(std::move(state));
363
571
  }
364
572
 
365
- do_defrag = false;
573
+ // now emplace the ubatch
574
+ apply_ubatch(sinfo_new, ubatch);
366
575
  }
367
576
 
368
- return need_reserve;
369
- }
577
+ GGML_ASSERT(!states.empty() || !success);
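+    // note: the ubatches above were applied only temporarily, so that each slot search accounts for the
+    //       earlier ones; the original cell state is restored below and the slot infos are applied for real later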
370
578
 
371
- void llama_kv_cache_unified::defrag_sched(float thold) {
372
- // - do not defrag small contexts (i.e. < 2048 tokens)
373
- // - count the padding towards the number of used tokens
374
- const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
579
+ // iterate backwards and restore the cells to their original state
580
+ for (auto it = states.rbegin(); it != states.rend(); ++it) {
581
+ const auto & sinfo = it->sinfo;
375
582
 
376
- // queue defragmentation for next llama_kv_cache_update
377
- if (fragmentation > thold) {
378
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
583
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
584
+ auto & cells = v_cells[sinfo.strm[s]];
585
+ auto & head = v_heads[sinfo.strm[s]];
379
586
 
380
- do_defrag = true;
587
+ cells.set(sinfo.idxs[s], it->v_cells[s]);
588
+ head = it->v_heads_old[s];
589
+ }
381
590
  }
382
- }
383
591
 
384
- void llama_kv_cache_unified::set_full() {
385
- n = cells.size();
592
+ if (!success) {
593
+ return {};
594
+ }
386
595
 
387
- // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
388
- // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
389
- // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
390
- // setting it to 0 is the simplest way to achieve that
391
- // ref: https://github.com/ggml-org/llama.cpp/issues/13359
392
- head = 0;
596
+ return res;
393
597
  }
394
598
 
395
- llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
396
- return llama_sbatch(batch, hparams.n_embd, true, logits_all);
397
- }
599
+ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
600
+ bool updated = false;
398
601
 
399
- llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
400
- GGML_UNUSED(embd_pooled);
401
- return sbatch.split_simple(n_ubatch);
402
- }
602
+ auto * sched = lctx->get_sched();
403
603
 
404
- bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
405
- const uint32_t n_tokens = ubatch.n_tokens;
604
+ if (!sc_info.empty()) {
605
+ assert(n_stream > 1 && "stream copy should never happen with a single stream");
406
606
 
407
- // if we have enough unused cells before the current head ->
408
- // better to start searching from the beginning of the cache, hoping to fill it
409
- if (head > cells.get_used() + 2*ubatch.n_tokens) {
410
- head = 0;
411
- }
607
+ llama_synchronize(lctx);
412
608
 
413
- // otherwise, one cell per token.
609
+ const size_t n_copy = sc_info.ssrc.size();
414
610
 
415
- if (n_tokens > cells.size()) {
416
- LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
417
- return false;
418
- }
611
+ for (size_t i = 0; i < n_copy; ++i) {
612
+ const auto ssrc = sc_info.ssrc[i];
613
+ const auto sdst = sc_info.sdst[i];
419
614
 
420
- //#define FIND_SLOT_DEBUG 1
421
- #if FIND_SLOT_DEBUG
422
- LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
615
+ assert(ssrc < n_stream);
616
+ assert(sdst < n_stream);
423
617
 
424
- // for debugging
425
- {
426
- std::string ss;
427
- if (n_swa > 0) {
428
- for (uint32_t i = 0; i < size; ++i) {
429
- if (cells.is_empty(i)) {
430
- ss += '.';
431
- } else {
432
- ss += 'x';
433
- }
434
- if (i%256 == 255) {
435
- ss += '\n';
436
- }
437
- }
438
- }
439
- LLAMA_LOG_WARN("\n%s\n", ss.c_str());
440
- }
441
- #endif
618
+ LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
442
619
 
443
- uint32_t n_tested = 0;
620
+ assert(ssrc != sdst);
444
621
 
445
- while (true) {
446
- if (head + n_tokens > cells.size()) {
447
- n_tested += cells.size() - head;
448
- head = 0;
449
- continue;
450
- }
622
+ for (uint32_t il = 0; il < layers.size(); ++il) {
623
+ const auto & layer = layers[il];
451
624
 
452
- bool found = true;
453
- for (uint32_t i = 0; i < n_tokens; i++) {
454
- // TODO: improve to accept cells that are masked by the SWA
455
- if (!cells.is_empty(head + i)) {
456
- found = false;
457
- head += i + 1;
458
- n_tested += i + 1;
459
- break;
625
+ ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
626
+ ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
460
627
  }
461
628
  }
462
-
463
- if (found) {
464
- break;
465
- }
466
-
467
- if (n_tested >= cells.size()) {
468
- //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
469
- return false;
470
- }
471
629
  }
472
630
 
473
- // store the old state of the cells in the recovery stack
474
- recovery.states.push_back({head, cells.cp(head, n_tokens)});
475
-
476
- for (uint32_t i = 0; i < n_tokens; ++i) {
477
- cells.pos_set(head + i, ubatch.pos[i]);
478
-
479
- for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
480
- cells.seq_add(head + i, ubatch.seq_id[i][j]);
631
+ if (do_shift) {
632
+ if (!get_can_shift()) {
633
+ GGML_ABORT("The current KV cache / model configuration does not support K-shift");
481
634
  }
482
- }
483
-
484
- // a heuristic, to avoid attending the full cache if it is not yet utilized
485
- // after enough generations, the benefit from this heuristic disappears
486
- // if we start defragmenting the cache, the benefit from this will be more important
487
- n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
488
635
 
489
- #ifdef FIND_SLOT_DEBUG
490
- LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
491
- #endif
492
-
493
- return true;
494
- }
636
+ LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
495
637
 
496
- bool llama_kv_cache_unified::get_can_shift() const {
497
- return true;
498
- }
638
+ // apply K-shift if needed
639
+ if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
640
+ ggml_backend_sched_reset(sched);
499
641
 
500
- uint32_t llama_kv_cache_unified::get_n() const {
501
- return n;
502
- }
642
+ auto * res = lctx->get_gf_res_reserve();
503
643
 
504
- uint32_t llama_kv_cache_unified::get_size() const {
505
- return cells.size();
506
- }
644
+ res->reset();
507
645
 
508
- ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
509
- const int32_t ikv = map_layer_ids.at(il);
646
+ auto * gf = build_graph_shift(res, lctx);
647
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
648
+ LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
649
+ return updated;
650
+ }
510
651
 
511
- auto * k = layers[ikv].k;
652
+ res->set_inputs(nullptr);
512
653
 
513
- return ggml_view_3d(ctx, k,
514
- hparams.n_embd_head_k, hparams.n_head_kv(il), n,
515
- ggml_row_size(k->type, hparams.n_embd_head_k),
516
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
517
- 0);
518
- }
654
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
655
+ LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
656
+ return updated;
657
+ }
519
658
 
520
- ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
521
- const int32_t ikv = map_layer_ids.at(il);
659
+ updated = true;
660
+ }
522
661
 
523
- auto * v = layers[ikv].v;
662
+ for (uint32_t s = 0; s < n_stream; ++s) {
663
+ auto & cells = v_cells[s];
524
664
 
525
- if (!v_trans) {
526
- // note: v->nb[1] <= v->nb[2]
527
- return ggml_view_3d(ctx, v,
528
- hparams.n_embd_head_v, hparams.n_head_kv(il), n,
529
- ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
530
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
531
- 0);
665
+ cells.reset_shift();
666
+ }
532
667
  }
533
668
 
534
- // note: v->nb[1] > v->nb[2]
535
- return ggml_view_3d(ctx, v,
536
- n, hparams.n_head_kv(il), hparams.n_embd_head_v,
537
- ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
538
- ggml_row_size(v->type, v->ne[1]), // v->nb[2]
539
- 0);
669
+ return updated;
540
670
  }
541
671
 
542
- ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
543
- const int32_t ikv = map_layer_ids.at(il);
544
-
545
- auto * k = layers[ikv].k;
672
+ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
546
673
 
547
- const int64_t n_tokens = k_cur->ne[2];
674
+ if (debug > 0) {
675
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
676
+ const auto seq_id = ubatch.seq_id_unq[s];
677
+ const auto stream_id = seq_to_stream[seq_id];
678
+ const auto & cells = v_cells[stream_id];
679
+ const uint32_t head_cur = v_heads[stream_id];
548
680
 
549
- ggml_tensor * k_view = ggml_view_1d(ctx, k,
550
- n_tokens*hparams.n_embd_k_gqa(il),
551
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
681
+ LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
682
+ __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
552
683
 
553
- return ggml_cpy(ctx, k_cur, k_view);
554
- }
684
+ if ((debug == 2 && n_swa > 0) || debug > 2) {
685
+ std::string ss;
686
+ for (uint32_t i = 0; i < cells.size(); ++i) {
687
+ if (cells.is_empty(i)) {
688
+ ss += '.';
689
+ } else {
690
+ assert(cells.seq_count(i) >= 1);
555
691
 
556
- ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
557
- const int32_t ikv = map_layer_ids.at(il);
692
+ if (cells.seq_count(i) == 1) {
693
+ ss += std::to_string(cells.seq_get(i));
694
+ } else {
695
+ ss += 'M';
696
+ }
697
+ }
698
+ if (i%256 == 255) {
699
+ ss += " *";
700
+ ss += '\n';
701
+ }
702
+ }
703
+ LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
704
+ }
558
705
 
559
- auto * v = layers[ikv].v;
706
+ if ((debug == 2 && n_swa > 0) || debug > 2) {
707
+ std::string ss;
708
+ for (uint32_t i = 0; i < cells.size(); ++i) {
709
+ std::string cur;
710
+ if (cells.is_empty(i)) {
711
+ cur = '.';
712
+ } else {
713
+ cur = std::to_string(cells.pos_get(i));
714
+ }
715
+ const int n = cur.size();
716
+ for (int j = 0; j < 5 - n; ++j) {
717
+ cur += ' ';
718
+ }
719
+ ss += cur;
720
+ if (i%256 == 255) {
721
+ ss += " *";
722
+ }
723
+ if (i%64 == 63) {
724
+ ss += '\n';
725
+ }
726
+ }
727
+ LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
728
+ }
560
729
 
561
- const int64_t n_tokens = v_cur->ne[2];
730
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
731
+ if (cells.seq_pos_min(s) < 0) {
732
+ continue;
733
+ }
562
734
 
563
- v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
735
+ LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
736
+ }
737
+ }
738
+ }
564
739
 
565
- ggml_tensor * v_view = nullptr;
740
+ uint32_t n_tokens = ubatch.n_tokens;
741
+ uint32_t n_seqs = 1;
566
742
 
567
- if (!v_trans) {
568
- v_view = ggml_view_1d(ctx, v,
569
- n_tokens*hparams.n_embd_v_gqa(il),
570
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
571
- } else {
572
- // note: the V cache is transposed when not using flash attention
573
- v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
574
- (v->ne[1])*ggml_element_size(v),
575
- ( head)*ggml_element_size(v));
743
+ if (n_stream > 1) {
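+        // with multiple streams, the ubatch is split equally across the unique sequences,
+        // so each sequence contributes n_tokens/n_seqs tokens and gets a slot in its own stream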
744
+ GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
576
745
 
577
- v_cur = ggml_transpose(ctx, v_cur);
746
+ n_seqs = ubatch.n_seqs_unq;
747
+ n_tokens = n_tokens / n_seqs;
578
748
  }
579
749
 
580
- return ggml_cpy(ctx, v_cur, v_view);
581
- }
750
+ slot_info res = {
751
+ /*.s0 =*/ LLAMA_MAX_SEQ,
752
+ /*.s1 =*/ 0,
753
+ /*.strm =*/ { },
754
+ /*.idxs =*/ { },
755
+ };
582
756
 
583
- void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
584
- // no pruning is needed when the cache does not use SWA
585
- GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
757
+ res.resize(n_seqs);
586
758
 
587
- int n_attended = 0;
759
+ for (uint32_t s = 0; s < n_seqs; ++s) {
760
+ const auto seq_id = ubatch.seq_id_unq[s];
588
761
 
589
- for (uint32_t i = 0; i < cells.size(); ++i) {
590
- if (!cells.seq_has(i, seq_id)) {
591
- continue;
762
+ if (n_stream > 1) {
763
+ GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
764
+ GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
592
765
  }
593
766
 
594
- const llama_pos p0 = cells.pos_get(i);
767
+ res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
768
+ res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
769
+
770
+ res.strm[s] = seq_to_stream[seq_id];
771
+ res.idxs[s].reserve(n_tokens);
772
+
773
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
595
774
 
596
- if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
597
- n_attended++;
775
+ uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
776
+
777
+ // if we have enough unused cells before the current head ->
778
+ // better to start searching from the beginning of the cache, hoping to fill it
779
+ if (head_cur > cells.get_used() + 2*n_tokens) {
780
+ head_cur = 0;
598
781
  }
599
782
 
600
- if (is_masked_swa(p0, pmax)) {
601
- cells.seq_rm(i, seq_id);
783
+ if (n_tokens > cells.size()) {
784
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
785
+ return { };
602
786
  }
603
- }
604
787
 
605
- if (n_attended < std::min<int>(n_swa, pmin)) {
606
- LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
607
- }
608
- }
788
+ uint32_t n_tested = 0;
609
789
 
610
- void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
611
- const int64_t n_tokens = ubatch->n_tokens;
612
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
613
- const int64_t n_seqs = ubatch->n_seqs;
790
+ // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
791
+ // for non-continuous slots, we test the tokens one by one
792
+ const uint32_t n_test = cont ? n_tokens : 1;
614
793
 
615
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
616
- float * data = (float *) dst->data;
794
+ while (true) {
795
+ if (head_cur + n_test > cells.size()) {
796
+ n_tested += cells.size() - head_cur;
797
+ head_cur = 0;
798
+ continue;
799
+ }
617
800
 
618
- const int64_t n_kv = n;
801
+ for (uint32_t i = 0; i < n_test; i++) {
802
+ const auto idx = head_cur;
619
803
 
620
- // Use only the previous KV cells of the correct sequence for each token of the ubatch.
621
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
622
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
623
- // Causal mask:
624
- // xxx-------
625
- // xxxx------
626
- // xxxxx-----
627
- // Non-causal mask:
628
- // xxxxx-----
629
- // xxxxx-----
630
- // xxxxx-----
631
- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
632
- for (int h = 0; h < 1; ++h) {
633
- for (int s = 0; s < n_seqs; ++s) {
634
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
635
-
636
- for (int j = 0; j < n_seq_tokens; ++j) {
637
- const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
804
+ head_cur++;
805
+ n_tested++;
638
806
 
639
- for (int i = 0; i < n_kv; ++i) {
640
- float f = 0.0f;
807
+ //const llama_pos pos = ubatch.pos[i];
808
+ //const llama_seq_id seq_id = ubatch.seq_id[i][0];
641
809
 
642
- bool masked = false;
643
-
644
- if (cells.is_empty(i)) {
645
- masked = true;
646
- } else {
647
- const llama_pos p0 = cells.pos_get(i);
810
+ // can we use this cell? either:
811
+ // - the cell is empty
812
+ // - the cell is occupied only by one sequence:
813
+ // - (disabled) mask causally, if the sequence is the same as the one we are inserting
814
+ // - mask SWA, using current max pos for that sequence in the cache
815
+ // always insert in the cell with minimum pos
816
+ bool can_use = cells.is_empty(idx);
648
817
 
649
- // mask the token if not the same sequence
650
- masked = masked || (!cells.seq_has(i, seq_id));
818
+ if (!can_use && cells.seq_count(idx) == 1) {
819
+ const llama_pos pos_cell = cells.pos_get(idx);
651
820
 
652
- // mask future tokens
653
- masked = masked || (causal_attn && p0 > p1);
821
+ // (disabled) causal mask
822
+ // note: it's better to purge any "future" tokens beforehand
823
+ //if (cells.seq_has(idx, seq_id)) {
824
+ // can_use = pos_cell >= pos;
825
+ //}
654
826
 
655
- // apply SWA if any
656
- masked = masked || (is_masked_swa(p0, p1));
827
+ if (!can_use) {
828
+ const llama_seq_id seq_id_cell = cells.seq_get(idx);
657
829
 
658
- if (!masked && hparams.use_alibi) {
659
- f = -std::abs(p0 - p1);
830
+ // SWA mask
831
+ if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
832
+ can_use = true;
660
833
  }
661
834
  }
835
+ }
662
836
 
663
- if (masked) {
664
- f = -INFINITY;
837
+ if (can_use) {
838
+ res.idxs[s].push_back(idx);
839
+ } else {
840
+ if (cont) {
841
+ break;
665
842
  }
666
-
667
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
668
843
  }
669
844
  }
670
- }
671
845
 
672
- // mask padded tokens
673
- if (data) {
674
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
675
- for (int j = 0; j < n_kv; ++j) {
676
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
677
- }
846
+ if (res.idxs[s].size() == n_tokens) {
847
+ break;
848
+ }
849
+
850
+ if (cont) {
851
+ res.idxs[s].clear();
852
+ }
853
+
854
+ if (n_tested >= cells.size()) {
855
+ //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
856
+ return { };
678
857
  }
679
858
  }
859
+
860
+ // we didn't find a suitable slot - return empty result
861
+ if (res.idxs[s].size() < n_tokens) {
862
+ return { };
863
+ }
680
864
  }
681
- }
682
865
 
683
- void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
684
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
866
+ assert(res.s1 >= res.s0);
685
867
 
686
- int32_t * data = (int32_t *) dst->data;
868
+ return res;
869
+ }
687
870
 
688
- for (uint32_t i = 0; i < cells.size(); ++i) {
689
- data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
871
+ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
872
+ // keep track of the max sequence position that we would overwrite with this ubatch
873
+    // for a non-SWA cache, this would always be empty
874
+ llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
875
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
876
+ seq_pos_max_rm[s] = -1;
690
877
  }
691
- }
692
878
 
693
- void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
694
- const int64_t n_tokens = ubatch->n_tokens;
879
+ assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
695
880
 
696
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
697
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
881
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
882
+ for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
883
+ const uint32_t i = s*sinfo.size() + ii;
698
884
 
699
- int32_t * data = (int32_t *) dst->data;
885
+ auto & cells = v_cells[sinfo.strm[s]];
700
886
 
701
- const int64_t n_kv = n;
887
+ const auto idx = sinfo.idxs[s][ii];
702
888
 
703
- for (int h = 0; h < 1; ++h) {
704
- for (int j = 0; j < n_tokens; ++j) {
705
- for (int i = 0; i < n_kv; ++i) {
706
- // the position when the cells is empty is irrelevant - it will be masked out later in the attention
707
- const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
889
+ if (!cells.is_empty(idx)) {
890
+ assert(cells.seq_count(idx) == 1);
891
+
892
+ const llama_seq_id seq_id = cells.seq_get(idx);
893
+ const llama_pos pos = cells.pos_get(idx);
894
+
895
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
896
+
897
+ cells.rm(idx);
898
+ }
899
+
900
+ cells.pos_set(idx, ubatch.pos[i]);
708
901
 
709
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
902
+ for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
903
+ cells.seq_add(idx, ubatch.seq_id[i][s]);
710
904
  }
711
905
  }
712
906
  }
713
- }
714
907
 
715
- size_t llama_kv_cache_unified::total_size() const {
716
- size_t size = 0;
908
+ // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
909
+ // will be present in the cache. so we have to purge any position which is less than those we would overwrite
910
+ // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
911
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
912
+ if (seq_pos_max_rm[s] == -1) {
913
+ continue;
914
+ }
717
915
 
718
- for (const auto & buf : bufs) {
719
- size += ggml_backend_buffer_get_size(buf.get());
720
- }
916
+ GGML_ASSERT(s < seq_to_stream.size());
721
917
 
722
- return size;
723
- }
918
+ auto & cells = v_cells[seq_to_stream[s]];
724
919
 
725
- size_t llama_kv_cache_unified::size_k_bytes() const {
726
- size_t size_k_bytes = 0;
920
+ if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
921
+ LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
922
+ __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
727
923
 
728
- for (const auto & layer : layers) {
729
- size_k_bytes += ggml_nbytes(layer.k);
924
+ seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
925
+ }
730
926
  }
731
927
 
732
- return size_k_bytes;
733
- }
734
-
735
- size_t llama_kv_cache_unified::size_v_bytes() const {
736
- size_t size_v_bytes = 0;
928
+ // move the head at the end of the slot
929
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
930
+ auto & head = v_heads[sinfo.strm[s]];
737
931
 
738
- for (const auto & layer : layers) {
739
- size_v_bytes += ggml_nbytes(layer.v);
932
+ head = sinfo.idxs[s].back() + 1;
740
933
  }
741
-
742
- return size_v_bytes;
743
934
  }
744
935
 
745
- ggml_tensor * llama_kv_cache_unified::build_rope_shift(
746
- const llama_cparams & cparams,
747
- ggml_context * ctx,
748
- ggml_tensor * cur,
749
- ggml_tensor * shift,
750
- ggml_tensor * factors,
751
- float freq_base,
752
- float freq_scale) const {
753
- const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
754
-
755
- const auto & yarn_ext_factor = cparams.yarn_ext_factor;
756
- const auto & yarn_beta_fast = cparams.yarn_beta_fast;
757
- const auto & yarn_beta_slow = cparams.yarn_beta_slow;
758
-
759
- const auto & n_rot = hparams.n_rot;
760
- const auto & rope_type = hparams.rope_type;
936
+ bool llama_kv_cache::get_can_shift() const {
937
+ return true;
938
+ }
761
939
 
762
- // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
763
- // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
764
- const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
940
+ uint32_t llama_kv_cache::get_size() const {
941
+ const auto & cells = v_cells[seq_to_stream[0]];
765
942
 
766
- ggml_tensor * tmp;
943
+ return cells.size();
944
+ }
767
945
 
768
- if (ggml_is_quantized(cur->type)) {
769
- // dequantize to f32 -> RoPE -> quantize back
770
- tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
946
+ uint32_t llama_kv_cache::get_n_stream() const {
947
+ return n_stream;
948
+ }
771
949
 
772
- tmp = ggml_rope_ext(ctx, tmp,
773
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
774
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
950
+ bool llama_kv_cache::get_has_shift() const {
951
+ bool result = false;
775
952
 
776
- tmp = ggml_cpy(ctx, tmp, cur);
777
- } else {
778
- // we rotate only the first n_rot dimensions
779
- tmp = ggml_rope_ext_inplace(ctx, cur,
780
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
781
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
953
+ for (uint32_t s = 0; s < n_stream; ++s) {
954
+ result |= v_cells[s].get_has_shift();
782
955
  }
783
956
 
784
- return tmp;
957
+ return result;
785
958
  }
786
959
 
787
- class llm_graph_input_k_shift : public llm_graph_input_i {
788
- public:
789
- llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
790
- virtual ~llm_graph_input_k_shift() = default;
791
-
792
- void set_input(const llama_ubatch * ubatch) override;
793
-
794
- ggml_tensor * k_shift; // I32 [kv_size]
795
-
796
- const llama_kv_cache_unified * kv_self;
797
- };
960
+ uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
961
+ uint32_t result = 0;
798
962
 
799
- void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
800
- GGML_UNUSED(ubatch);
963
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
964
+ const auto & cells = v_cells[sinfo.strm[s]];
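+        // heuristic: attend only up to the last used cell, padded to a multiple of n_pad,
+        //            instead of the full cache size; take the maximum over the streams touched by this ubatch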
801
965
 
802
- if (k_shift) {
803
- kv_self->set_input_k_shift(k_shift);
966
+ result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
804
967
  }
968
+
969
+ return result;
805
970
  }
806
971
 
807
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
808
- const llama_cparams & cparams,
809
- ggml_context * ctx,
810
- ggml_cgraph * gf) const {
811
- auto res = std::make_unique<llm_graph_result>();
972
+ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
973
+ const int32_t ikv = map_layer_ids.at(il);
812
974
 
813
- const auto & n_embd_head_k = hparams.n_embd_head_k;
814
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
975
+ auto * k = layers[ikv].k;
815
976
 
816
- //GGML_ASSERT(kv_self->size == n_ctx);
977
+ const uint64_t kv_size = get_size();
978
+ const uint64_t n_embd_k_gqa = k->ne[0];
817
979
 
818
- auto inp = std::make_unique<llm_graph_input_k_shift>(this);
980
+ assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
819
981
 
820
- inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
821
- ggml_set_input(inp->k_shift);
982
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
822
983
 
823
- for (const auto & layer : layers) {
824
- const uint32_t il = layer.il;
984
+ return ggml_view_4d(ctx, k,
985
+ hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
986
+ ggml_row_size(k->type, hparams.n_embd_head_k),
987
+ ggml_row_size(k->type, n_embd_k_gqa),
988
+ ggml_row_size(k->type, n_embd_k_gqa*kv_size),
989
+ ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
990
+ }
825
991
 
826
- const int64_t n_head_kv = hparams.n_head_kv(il);
827
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
992
+ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
993
+ const int32_t ikv = map_layer_ids.at(il);
828
994
 
829
- const float freq_base_l = model.get_rope_freq_base (cparams, il);
830
- const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
995
+ auto * v = layers[ikv].v;
831
996
 
832
- ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
997
+ const uint64_t kv_size = get_size();
998
+ const uint64_t n_embd_v_gqa = v->ne[0];
833
999
 
834
- ggml_tensor * k =
835
- ggml_view_3d(ctx, layer.k,
836
- n_embd_head_k, n_head_kv, cells.size(),
837
- ggml_row_size(layer.k->type, n_embd_head_k),
838
- ggml_row_size(layer.k->type, n_embd_k_gqa),
839
- 0);
1000
+ // [TAG_V_CACHE_VARIABLE]
1001
+ assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
840
1002
 
841
- ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
1003
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
842
1004
 
843
- ggml_build_forward_expand(gf, cur);
1005
+ if (!v_trans) {
1006
+ // note: v->nb[1] <= v->nb[2]
1007
+ return ggml_view_4d(ctx, v,
1008
+ hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
1009
+ ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
1010
+ ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
1011
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
1012
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
844
1013
  }
845
1014
 
846
- res->add_input(std::move(inp));
847
-
848
- return res;
1015
+ // note: v->nb[1] > v->nb[2]
1016
+ return ggml_view_4d(ctx, v,
1017
+ n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
1018
+ ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
1019
+ ggml_row_size(v->type, kv_size), // v->nb[2]
1020
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
1021
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
849
1022
  }
850
1023
 
851
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
852
- const llama_cparams & cparams,
853
- ggml_context * ctx,
854
- ggml_cgraph * gf) const {
855
- auto res = std::make_unique<llm_graph_result>();
1024
+ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
1025
+ GGML_UNUSED(sinfo);
856
1026
 
857
- const auto & ids = defrag_info.ids;
1027
+ const int32_t ikv = map_layer_ids.at(il);
858
1028
 
859
- #if 0
860
- // CPU defrag
861
- //
862
- // TODO: optimizations are possible:
863
- // - multiple threads
864
- // - avoid copying to the host memory when already there
865
- //
866
- // likely not worth the effort, as we have ggml_graph based defrag
867
- //
1029
+ ggml_tensor * k = layers[ikv].k;
868
1030
 
869
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
870
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
1031
+ const int64_t n_embd_head = k_cur->ne[0];
1032
+ const int64_t n_head = k_cur->ne[1];
1033
+ const int64_t n_tokens = k_cur->ne[2];
871
1034
 
872
- const uint32_t kv_size = size;
1035
+ const int64_t n_embd_gqa = n_embd_head*n_head;
873
1036
 
874
- std::vector<uint8_t> buf_k;
875
- std::vector<uint8_t> buf_v;
1037
+ // we can merge dims 0 and 1
1038
+ // TODO: add ggml helper function for this?
1039
+ GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
876
1040
 
877
- for (uint32_t il = 0; il < n_layer; ++il) {
878
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
879
- const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
1041
+ k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
880
1042
 
881
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
882
- const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
1043
+ const int64_t n_stream = k->ne[2];
883
1044
 
884
- buf_k.resize(k_size);
885
- buf_v.resize(v_size);
1045
+ if (n_stream > 1) {
1046
+ const int64_t kv_size = get_size();
886
1047
 
887
- ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
888
- ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
1048
+ assert(n_embd_gqa == k->ne[0]);
1049
+ assert(kv_size == k->ne[1]);
889
1050
 
890
- // batch move [i, i+nm) to [id, id+nm)
891
- // note: cells can move only to a lower index
892
- for (uint32_t i = 0; i < n_kv; ++i) {
893
- const uint32_t id = ids[i];
1051
+ // merge the buffer across all streams because the idxs are global
1052
+ k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
1053
+ }
894
1054
 
895
- if (i == id || id == n_kv) {
896
- continue;
897
- }
1055
+ // store the current K values into the cache
1056
+ return ggml_set_rows(ctx, k, k_cur, k_idxs);
1057
+ }
898
1058
 
899
- uint32_t nm = 1;
1059
+ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
1060
+ GGML_UNUSED(sinfo);
900
1061
 
901
- while (i + nm < n_kv && ids[i + nm] == id + nm) {
902
- nm++;
903
- }
1062
+ const int32_t ikv = map_layer_ids.at(il);
904
1063
 
905
- // move keys
906
- {
907
- const int64_t os = i*k_size_row;
908
- const int64_t od = id*k_size_row;
1064
+ auto * v = layers[ikv].v;
909
1065
 
910
- memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
911
- }
1066
+ const int64_t n_embd_head = v_cur->ne[0];
1067
+ const int64_t n_head = v_cur->ne[1];
1068
+ const int64_t n_tokens = v_cur->ne[2];
912
1069
 
913
- // move values (note: they are transposed)
914
- {
915
- const int64_t os = i;
916
- const int64_t od = id;
1070
+ const int64_t n_embd_gqa = n_embd_head*n_head;
917
1071
 
918
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
919
- memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
920
- }
921
- }
1072
+ // we can merge dims 0 and 1
1073
+ GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
922
1074
 
923
- i += nm - 1;
924
- }
1075
+ const int64_t n_stream = v->ne[2];
925
1076
 
926
- ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
927
- ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
928
- }
929
- #else
930
- for (uint32_t i = 0; i < ids.size(); ++i) {
931
- const uint32_t id = ids[i];
1077
+ // take this branch when FA is enabled (the V cache is not transposed)
1078
+ if (!v_trans) {
1079
+ v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
932
1080
 
933
- if (i == id || id == ids.size()) {
934
- continue;
935
- }
1081
+ if (n_stream > 1) {
1082
+ const int64_t kv_size = get_size();
936
1083
 
937
- uint32_t nm = 1;
1084
+ assert(n_embd_gqa == v->ne[0]);
1085
+ assert(kv_size == v->ne[1]);
938
1086
 
939
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
940
- nm++;
1087
+ // merge the buffer across all streams because the idxs are global
1088
+ v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
941
1089
  }
942
1090
 
943
- for (const auto & layer : layers) {
944
- const uint32_t il = layer.il;
945
-
946
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
947
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
948
-
949
- ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
950
- n_embd_k_gqa, nm,
951
- ggml_row_size(layer.k->type, n_embd_k_gqa),
952
- ggml_row_size(layer.k->type, n_embd_k_gqa*i));
953
-
954
- ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
955
- n_embd_k_gqa, nm,
956
- ggml_row_size(layer.k->type, n_embd_k_gqa),
957
- ggml_row_size(layer.k->type, n_embd_k_gqa*id));
958
-
959
- ggml_tensor * view_v_src;
960
- ggml_tensor * view_v_dst;
961
-
962
- if (cparams.flash_attn) {
963
- // NOTE: the V cache is not transposed when using flash attention
964
- view_v_src = ggml_view_2d(ctx, layer.v,
965
- n_embd_v_gqa, nm,
966
- ggml_row_size(layer.v->type, n_embd_v_gqa),
967
- ggml_row_size(layer.v->type, n_embd_v_gqa*i));
968
-
969
- view_v_dst = ggml_view_2d(ctx, layer.v,
970
- n_embd_v_gqa, nm,
971
- ggml_row_size(layer.v->type, n_embd_v_gqa),
972
- ggml_row_size(layer.v->type, n_embd_v_gqa*id));
973
- } else {
974
- view_v_src = ggml_view_2d(ctx, layer.v,
975
- nm, n_embd_v_gqa,
976
- ggml_row_size(layer.v->type, cells.size()),
977
- ggml_row_size(layer.v->type, i));
978
-
979
- view_v_dst = ggml_view_2d(ctx, layer.v,
980
- nm, n_embd_v_gqa,
981
- ggml_row_size(layer.v->type, cells.size()),
982
- ggml_row_size(layer.v->type, id));
983
- }
1091
+ return ggml_set_rows(ctx, v, v_cur, v_idxs);
1092
+ }
984
1093
 
985
- ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
986
- ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
987
- }
1094
+ if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
1095
+ // we can merge dims 0, 1 and 2
1096
+ v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
1097
+ } else {
1098
+ // otherwise -> make a copy to get contiguous data
1099
+ v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
1100
+ }
988
1101
 
989
- i += nm - 1;
1102
+ // [TAG_V_CACHE_VARIABLE]
1103
+ if (n_embd_gqa < v->ne[0]) {
1104
+ v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
990
1105
  }
991
1106
 
992
- //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
993
- #endif
1107
+ // in this branch the v_idxs are constructed in such a way that each row is a single head element
1108
+ ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
994
1109
 
995
- return res;
1110
+ v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
1111
+
1112
+ return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
996
1113
  }
997
1114
 
998
- bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
999
- const uint32_t n_layer = layers.size();
1115
+ ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1116
+ const uint32_t n_tokens = ubatch.n_tokens;
1000
1117
 
1001
- const uint32_t n_kv = cells.used_max_p1();
1002
- const uint32_t n_used = cells.get_used();
1118
+ ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
1003
1119
 
1004
- assert(n_used <= n_kv);
1120
+ ggml_set_input(k_idxs);
1005
1121
 
1006
- //const int64_t t_start = ggml_time_us();
1122
+ return k_idxs;
1123
+ }
1007
1124
 
1008
- // number of cells moved
1009
- uint32_t n_moves = 0;
1125
+ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1126
+ const uint32_t n_tokens = ubatch.n_tokens;
1010
1127
 
1011
- // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
1012
- // - source view, destination view, copy operation
1013
- // - x2 for keys and values
1014
- //const uint32_t max_moves = max_nodes()/(6*n_layer);
1015
- // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
1016
- const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
1128
+ ggml_tensor * v_idxs;
1017
1129
 
1018
- // determine which KV cells to move where
1019
- //
1020
- // cell i moves to ids[i]
1021
- //
1022
- // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
1023
- //
1024
- auto & ids = defrag_info.ids;
1130
+ if (!v_trans) {
1131
+ v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
1132
+ } else {
1133
+ v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
1134
+ }
1025
1135
 
1026
- ids.clear();
1027
- ids.resize(n_kv, n_kv);
1136
+ ggml_set_input(v_idxs);
1028
1137
 
1029
- for (uint32_t i0 = 0; i0 < n_used; ++i0) {
1030
- if (!cells.is_empty(i0)) {
1031
- ids[i0] = i0;
1138
+ return v_idxs;
1139
+ }
1032
1140
 
1033
- continue;
1034
- }
1141
+ void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
1142
+ const uint32_t n_tokens = ubatch->n_tokens;
1143
+ GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
1035
1144
 
1036
- // found a hole - fill it with data from the end of the cache
1145
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1146
+ int64_t * data = (int64_t *) dst->data;
1037
1147
 
1038
- uint32_t nh = 1;
1148
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1149
+ const int64_t offs = sinfo.strm[s]*get_size();
1039
1150
 
1040
- // determine the size of the hole
1041
- while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
1042
- nh++;
1151
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1152
+ data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1043
1153
  }
1154
+ }
1155
+ }
1044
1156
 
1045
- uint32_t nf = 0;
1046
- uint32_t is = n_kv - 1;
1157
+ void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
1158
+ const uint32_t n_tokens = ubatch->n_tokens;
1159
+ GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
1047
1160
 
1048
- // starting from the end, find nh non-empty cells
1049
- for (; is > i0; --is) {
1050
- if (cells.is_empty(is) || ids[is] != n_kv) {
1051
- continue;
1052
- }
1161
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1162
+ int64_t * data = (int64_t *) dst->data;
1053
1163
 
1054
- // non-empty cell which is not yet moved
1055
- nf++;
1164
+ if (!v_trans) {
1165
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1166
+ const int64_t offs = sinfo.strm[s]*get_size();
1056
1167
 
1057
- if (nf == nh) {
1058
- break;
1168
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1169
+ data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1059
1170
  }
1060
1171
  }
1172
+ } else {
1173
+ // note: the V cache is transposed when not using flash attention
1174
+ const int64_t kv_size = get_size();
1061
1175
 
1062
- // this can only happen if `n_used` is not accurate, which would be a bug
1063
- GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
1064
-
1065
- nf = 0;
1066
-
1067
- uint32_t i1 = is;
1068
-
1069
- // are we moving a continuous block of memory?
1070
- bool cont = false;
1176
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
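+        // in the transposed layout each index addresses a single element: embedding component j of the token
+        // stored in cell idx maps to flat position j*kv_size + idx within its stream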
1071
1177
 
1072
- // should we stop searching for the next move?
1073
- bool stop = false;
1178
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1179
+ const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
1074
1180
 
1075
- // go back and move the nf cells to the hole
1076
- for (; i1 < n_kv; ++i1) {
1077
- if (cells.is_empty(i1) || ids[i1] != n_kv) {
1078
- if (n_moves == max_moves) {
1079
- stop = true;
1080
- break;
1181
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1182
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1183
+ data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
1081
1184
  }
1082
-
1083
- cont = false;
1084
- continue;
1085
1185
  }
1186
+ }
1187
+ }
1188
+ }
1086
1189
 
1087
- // this cell goes to (i0 + nf)
1088
- ids[i1] = i0 + nf;
1190
+ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
1191
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1089
1192
 
1090
- // move the cell meta data
1091
- cells.mv(i1, i0 + nf);
1193
+ int32_t * data = (int32_t *) dst->data;
1092
1194
 
1093
- head = n_used;
1195
+ for (uint32_t s = 0; s < n_stream; ++s) {
1196
+ const auto & cells = v_cells[s];
1094
1197
 
1095
- if (!cont) {
1096
- n_moves++;
1097
- cont = true;
1098
- }
1198
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1199
+ data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
1200
+ }
1201
+ }
1202
+ }
1099
1203
 
1100
- nf++;
1204
+ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1205
+ const uint32_t n_tokens = ubatch->n_tokens;
1101
1206
 
1102
- if (nf == nh) {
1103
- break;
1104
- }
1105
- }
1207
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1208
+ float * data = (float *) dst->data;
1106
1209
 
1107
- if (stop || n_moves == max_moves) {
1108
- break;
1109
- }
1210
+ const int64_t n_kv = dst->ne[0];
1211
+ const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
1110
1212
 
1111
- //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
1213
+ GGML_ASSERT(n_tokens%n_stream == 0);
1112
1214
 
1113
- i0 += nh - 1;
1114
- }
1215
+ // n_tps == n_tokens_per_stream
1216
+ const int64_t n_tps = n_tokens/n_stream;
1217
+ const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
1115
1218
 
1116
- if (n_moves == 0) {
1117
- return false;
1118
- }
1219
+ std::fill(data, data + ggml_nelements(dst), -INFINITY);
1119
1220
 
1120
- LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
1221
+ // Use only the previous KV cells of the correct sequence for each token of the ubatch.
1222
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
1223
+ // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
1224
+ // Causal mask:
1225
+ // xxx-------
1226
+ // xxxx------
1227
+ // xxxxx-----
1228
+ // Non-causal mask:
1229
+ // xxxxx-----
1230
+ // xxxxx-----
1231
+ // xxxxx-----
1232
+ // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
1233
+ // TODO: optimize this section
1234
+ for (uint32_t h = 0; h < 1; ++h) {
1235
+ for (uint32_t s = 0; s < n_stream; ++s) {
1236
+ for (uint32_t ii = 0; ii < n_tps; ++ii) {
1237
+ const uint32_t i = s*n_tps + ii;
1121
1238
 
1122
- LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
1239
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
1123
1240
 
1124
- return true;
1125
- }
1241
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
1126
1242
 
1127
- bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
1128
- assert(p0 >= 0 && p1 >= 0);
1243
+ const llama_pos p1 = ubatch->pos[i];
1129
1244
 
1130
- switch (swa_type) {
1131
- case LLAMA_SWA_TYPE_NONE:
1132
- {
1133
- } break;
1134
- case LLAMA_SWA_TYPE_STANDARD:
1135
- {
1136
- if (p1 - p0 >= (int32_t) n_swa) {
1137
- return true;
1138
- }
1139
- } break;
1140
- case LLAMA_SWA_TYPE_CHUNKED:
1141
- {
1142
- const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
1245
+ const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
1143
1246
 
1144
- if (p0 < pos_chunk_start) {
1145
- return true;
1146
- }
1147
- } break;
1148
- }
1247
+ for (uint32_t j = 0; j < n_kv; ++j) {
1248
+ if (cells.is_empty(j)) {
1249
+ continue;
1250
+ }
1149
1251
 
1150
- return false;
1151
- }
1252
+ // mask the token if not the same sequence
1253
+ if (!cells.seq_has(j, seq_id)) {
1254
+ continue;
1255
+ }
1152
1256
 
1153
- void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1154
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1155
- uint32_t cell_count = 0;
1257
+ const llama_pos p0 = cells.pos_get(j);
1156
1258
 
1157
- // Count the number of cells with the specified seq_id
1158
- // Find all the ranges of cells with this seq id (or all, when -1)
1159
- uint32_t cell_range_begin = cells.size();
1259
+ // mask future tokens
1260
+ if (causal_attn && p0 > p1) {
1261
+ continue;
1262
+ }
1160
1263
 
1161
- for (uint32_t i = 0; i < cells.size(); ++i) {
1162
- if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1163
- ++cell_count;
1164
- if (cell_range_begin == cells.size()) {
1165
- cell_range_begin = i;
1166
- }
1167
- } else {
1168
- if (cell_range_begin != cells.size()) {
1169
- cell_ranges.emplace_back(cell_range_begin, i);
1170
- cell_range_begin = cells.size();
1264
+ // apply SWA if any
1265
+ if (is_masked_swa(p0, p1)) {
1266
+ continue;
1267
+ }
1268
+
1269
+ data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
1270
+ }
1171
1271
  }
1172
1272
  }
1173
1273
  }
1274
+ }
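// [Editor's illustrative sketch - not part of the diff.]
// Reproduces the tiny example from the comment above (10-cell cache, 2 cached
// tokens, 3 batch tokens) and prints the causal / non-causal masks as 'x'/'-'.
// Plain C++ only; it does not touch any llama.cpp/ggml data structures.
#include <cstdio>

int main() {
    const int n_kv        = 10;               // toy cache size
    const int cache_pos[] = {0, 1, 2, 3, 4};  // cells 0..4 are populated once the batch is placed
    const int n_cells     = 5;
    const int batch_pos[] = {2, 3, 4};        // positions of the 3 tokens in the current batch

    for (int causal = 1; causal >= 0; --causal) {
        printf("%s mask:\n", causal ? "Causal" : "Non-causal");
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < n_kv; ++j) {
                bool visible = j < n_cells;                 // empty cells stay masked
                if (visible && causal) {
                    visible = cache_pos[j] <= batch_pos[i]; // mask future tokens
                }
                putchar(visible ? 'x' : '-');
            }
            putchar('\n');
        }
    }
    return 0;
}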
1174
1275
 
1175
- if (cell_range_begin != cells.size()) {
1176
- cell_ranges.emplace_back(cell_range_begin, cells.size());
1177
- }
1178
-
1179
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1180
- uint32_t cell_count_check = 0;
1181
- for (const auto & range : cell_ranges) {
1182
- cell_count_check += range.second - range.first;
1183
- }
1184
- GGML_ASSERT(cell_count == cell_count_check);
1185
-
1186
- io.write(&cell_count, sizeof(cell_count));
1187
-
1188
- state_write_meta(io, cell_ranges, seq_id);
1189
- state_write_data(io, cell_ranges);
1190
- }
1191
-
1192
- void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1193
- uint32_t cell_count;
1194
- io.read_to(&cell_count, sizeof(cell_count));
1195
-
1196
- bool res = true;
1197
- res = res && state_read_meta(io, cell_count, seq_id);
1198
- res = res && state_read_data(io, cell_count);
1199
-
1200
- if (!res) {
1201
- if (seq_id == -1) {
1202
- clear();
1203
- } else {
1204
- seq_rm(seq_id, -1, -1);
1205
- }
1206
- throw std::runtime_error("failed to restore kv cache");
1207
- }
1208
- }
1209
-
1210
- void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
1211
- for (const auto & range : cell_ranges) {
1212
- for (uint32_t i = range.first; i < range.second; ++i) {
1213
- std::vector<llama_seq_id> seq_ids;
1214
-
1215
- for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
1216
- if (cur == seq_id || seq_id == -1) {
1217
- if (cells.seq_has(i, cur)) {
1218
- seq_ids.push_back(cur);
1219
- }
1220
- }
1221
- }
1222
-
1223
- const llama_pos pos = cells.pos_get(i);
1224
- const uint32_t n_seq_id = seq_ids.size();
1225
-
1226
- io.write(&pos, sizeof(pos));
1227
- io.write(&n_seq_id, sizeof(n_seq_id));
1228
-
1229
- for (const auto & seq_id : seq_ids) {
1230
- io.write(&seq_id, sizeof(seq_id));
1231
- }
1232
- }
1233
- }
1234
- }
1235
-
1236
- void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
1237
- const uint32_t v_trans = this->v_trans ? 1 : 0;
1238
- const uint32_t n_layer = layers.size();
1239
-
1240
- io.write(&v_trans, sizeof(v_trans));
1241
- io.write(&n_layer, sizeof(n_layer));
1242
-
1243
- std::vector<uint8_t> tmp_buf;
1244
-
1245
- // Iterate and write all the keys first, each row is a cell
1246
- // Get whole range at a time
1247
- for (const auto & layer : layers) {
1248
- const uint32_t il = layer.il;
1249
-
1250
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1251
-
1252
- // Write key type
1253
- const int32_t k_type_i = (int32_t)layer.k->type;
1254
- io.write(&k_type_i, sizeof(k_type_i));
1255
-
1256
- // Write row size of key
1257
- const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1258
- io.write(&k_size_row, sizeof(k_size_row));
1259
-
1260
- // Read each range of cells of k_size length each into tmp_buf and write out
1261
- for (const auto & range : cell_ranges) {
1262
- const size_t range_size = range.second - range.first;
1263
- const size_t buf_size = range_size * k_size_row;
1264
- io.write_tensor(layer.k, range.first * k_size_row, buf_size);
1265
- }
1266
- }
1267
-
1268
- if (!v_trans) {
1269
- for (const auto & layer : layers) {
1270
- const uint32_t il = layer.il;
1271
-
1272
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1273
-
1274
- // Write value type
1275
- const int32_t v_type_i = (int32_t)layer.v->type;
1276
- io.write(&v_type_i, sizeof(v_type_i));
1277
-
1278
- // Write row size of value
1279
- const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1280
- io.write(&v_size_row, sizeof(v_size_row));
1281
-
1282
- // Read each range of cells of v_size length each into tmp_buf and write out
1283
- for (const auto & range : cell_ranges) {
1284
- const size_t range_size = range.second - range.first;
1285
- const size_t buf_size = range_size * v_size_row;
1286
- io.write_tensor(layer.v, range.first * v_size_row, buf_size);
1287
- }
1288
- }
1289
- } else {
1290
- // When v is transposed, we also need the element size and get the element ranges from each row
1291
- const uint32_t kv_size = cells.size();
1292
-
1293
- for (const auto & layer : layers) {
1294
- const uint32_t il = layer.il;
1295
-
1296
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1297
-
1298
- // Write value type
1299
- const int32_t v_type_i = (int32_t)layer.v->type;
1300
- io.write(&v_type_i, sizeof(v_type_i));
1301
-
1302
- // Write element size
1303
- const uint32_t v_size_el = ggml_type_size(layer.v->type);
1304
- io.write(&v_size_el, sizeof(v_size_el));
1305
-
1306
- // Write GQA embedding size
1307
- io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
1308
-
1309
- // For each row, we get the element values of each cell
1310
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1311
- // Read each range of cells of v_size_el length each into tmp_buf and write out
1312
- for (const auto & range : cell_ranges) {
1313
- const size_t range_size = range.second - range.first;
1314
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1315
- const size_t buf_size = range_size * v_size_el;
1316
- io.write_tensor(layer.v, src_offset, buf_size);
1317
- }
1318
- }
1319
- }
1320
- }
1321
- }
1322
-
1323
- bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
1324
- if (dest_seq_id != -1) {
1325
- // single sequence
1326
-
1327
- seq_rm(dest_seq_id, -1, -1);
1328
-
1329
- llama_sbatch sbatch;
1330
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1331
-
1332
- batch.n_tokens = cell_count;
1333
-
1334
- for (uint32_t i = 0; i < cell_count; ++i) {
1335
- llama_pos pos;
1336
- uint32_t n_seq_id;
1337
-
1338
- io.read_to(&pos, sizeof(pos));
1339
- io.read_to(&n_seq_id, sizeof(n_seq_id));
1340
-
1341
- if (n_seq_id != 1) {
1342
- LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1343
- return false;
1344
- }
1345
-
1346
- // read the sequence id, but directly discard it - we will use dest_seq_id instead
1347
- {
1348
- llama_seq_id seq_id;
1349
- io.read_to(&seq_id, sizeof(seq_id));
1350
- }
1351
-
1352
- batch.pos[i] = pos;
1353
- batch.n_seq_id[i] = n_seq_id;
1354
- batch.seq_id[i] = &dest_seq_id;
1355
- }
1356
-
1357
- if (!find_slot(batch)) {
1358
- LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1359
- return false;
1360
- }
1361
-
1362
- commit();
1363
-
1364
- // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1365
- // Assume that this is one contiguous block of cells
1366
- GGML_ASSERT(head + cell_count <= cells.size());
1367
- GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
1368
- GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
1369
- GGML_ASSERT(cells.seq_has(head, dest_seq_id));
1370
- GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
1371
- } else {
1372
- // whole KV cache restore
1373
-
1374
- if (cell_count > cells.size()) {
1375
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1376
- return false;
1377
- }
1378
-
1379
- clear();
1380
-
1381
- for (uint32_t i = 0; i < cell_count; ++i) {
1382
- llama_pos pos;
1383
- uint32_t n_seq_id;
1384
-
1385
- io.read_to(&pos, sizeof(pos));
1386
- io.read_to(&n_seq_id, sizeof(n_seq_id));
1387
-
1388
- cells.pos_set(i, pos);
1389
-
1390
- for (uint32_t j = 0; j < n_seq_id; ++j) {
1391
- llama_seq_id seq_id;
1392
- io.read_to(&seq_id, sizeof(seq_id));
1393
-
1394
- if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
1395
- LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
1396
- return false;
1397
- }
1398
-
1399
- cells.seq_add(i, seq_id);
1400
- }
1401
- }
1402
-
1403
- head = 0;
1404
- }
1405
-
1406
- return true;
1407
- }
1408
-
1409
- bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1410
- uint32_t v_trans;
1411
- uint32_t n_layer;
1412
-
1413
- io.read_to(&v_trans, sizeof(v_trans));
1414
- io.read_to(&n_layer, sizeof(n_layer));
1415
-
1416
- if (n_layer != layers.size()) {
1417
- LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
1418
- return false;
1419
- }
1420
- if (cell_count > cells.size()) {
1421
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
1422
- return false;
1423
- }
1424
- if (this->v_trans != (bool) v_trans) {
1425
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1426
- return false;
1427
- }
1428
-
1429
- // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1430
- for (const auto & layer : layers) {
1431
- const uint32_t il = layer.il;
1432
-
1433
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1434
-
1435
- // Read type of key
1436
- int32_t k_type_i_ref;
1437
- io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1438
- const int32_t k_type_i = (int32_t) layer.k->type;
1439
- if (k_type_i != k_type_i_ref) {
1440
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1441
- return false;
1442
- }
1443
-
1444
- // Read row size of key
1445
- uint64_t k_size_row_ref;
1446
- io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1447
- const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1448
- if (k_size_row != k_size_row_ref) {
1449
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1450
- return false;
1451
- }
1452
-
1453
- if (cell_count) {
1454
- // Read and set the keys for the whole cell range
1455
- ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1456
- }
1457
- }
1458
-
1459
- if (!this->v_trans) {
1460
- for (const auto & layer : layers) {
1461
- const uint32_t il = layer.il;
1462
-
1463
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1464
-
1465
- // Read type of value
1466
- int32_t v_type_i_ref;
1467
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1468
- const int32_t v_type_i = (int32_t)layer.v->type;
1469
- if (v_type_i != v_type_i_ref) {
1470
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1471
- return false;
1472
- }
1473
-
1474
- // Read row size of value
1475
- uint64_t v_size_row_ref;
1476
- io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1477
- const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1478
- if (v_size_row != v_size_row_ref) {
1479
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1480
- return false;
1481
- }
1482
-
1483
- if (cell_count) {
1484
- // Read and set the values for the whole cell range
1485
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1486
- }
1487
- }
1488
- } else {
1489
- // For each layer, read the values for each cell (transposed)
1490
- for (const auto & layer : layers) {
1491
- const uint32_t il = layer.il;
1492
-
1493
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1494
-
1495
- // Read type of value
1496
- int32_t v_type_i_ref;
1497
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1498
- const int32_t v_type_i = (int32_t)layer.v->type;
1499
- if (v_type_i != v_type_i_ref) {
1500
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1501
- return false;
1502
- }
1503
-
1504
- // Read element size of value
1505
- uint32_t v_size_el_ref;
1506
- io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1507
- const size_t v_size_el = ggml_type_size(layer.v->type);
1508
- if (v_size_el != v_size_el_ref) {
1509
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1510
- return false;
1511
- }
1512
-
1513
- // Read GQA embedding size
1514
- uint32_t n_embd_v_gqa_ref;
1515
- io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1516
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1517
- LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1518
- return false;
1519
- }
1520
-
1521
- if (cell_count) {
1522
- // For each row in the transposed matrix, read the values for the whole cell range
1523
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1524
- const size_t dst_offset = (head + j * cells.size()) * v_size_el;
1525
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1526
- }
1527
- }
1528
- }
1529
- }
1530
-
1531
- return true;
1532
- }
1533
-
1534
- //
1535
- // llama_kv_cache_unified_iswa
1536
- //
1537
-
1538
- llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
1539
- const llama_model & model,
1540
- ggml_type type_k,
1541
- ggml_type type_v,
1542
- bool v_trans,
1543
- bool offload,
1544
- bool swa_full,
1545
- uint32_t kv_size,
1546
- uint32_t n_seq_max,
1547
- uint32_t n_batch,
1548
- uint32_t n_pad) : hparams(model.hparams) {
1549
- llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
1550
- llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
1551
-
1552
- const uint32_t size_base = kv_size;
1553
-
1554
- uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
1555
-
1556
- // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
1557
- if (swa_full) {
1558
- LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
1559
- __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
1560
-
1561
- size_swa = size_base;
1562
- do_prune = false;
1563
- }
1564
-
1565
- LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
1566
-
1567
- kv_base = std::make_unique<llama_kv_cache_unified>(
1568
- model, std::move(filter_base), type_k, type_v,
1569
- v_trans, offload, size_base, n_seq_max, n_pad,
1570
- 0, LLAMA_SWA_TYPE_NONE);
1571
-
1572
- LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
1573
-
1574
- kv_swa = std::make_unique<llama_kv_cache_unified>(
1575
- model, std::move(filter_swa), type_k, type_v,
1576
- v_trans, offload, size_swa, n_seq_max, n_pad,
1577
- hparams.n_swa, hparams.swa_type);
1578
- }
1579
-
1580
- void llama_kv_cache_unified_iswa::clear() {
1581
- kv_base->clear();
1582
- kv_swa ->clear();
1583
- }
1584
-
1585
- bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1586
- bool res = true;
1587
-
1588
- res = res & kv_base->seq_rm(seq_id, p0, p1);
1589
- res = res & kv_swa ->seq_rm(seq_id, p0, p1);
1590
-
1591
- return res;
1592
- }
1593
-
1594
- void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
1595
- kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
1596
- kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
1597
- }
1598
-
1599
- void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
1600
- kv_base->seq_keep(seq_id);
1601
- kv_swa ->seq_keep(seq_id);
1602
- }
1603
-
1604
- void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
1605
- kv_base->seq_add(seq_id, p0, p1, shift);
1606
- kv_swa ->seq_add(seq_id, p0, p1, shift);
1607
- }
1608
-
1609
- void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
1610
- kv_base->seq_div(seq_id, p0, p1, d);
1611
- kv_swa ->seq_div(seq_id, p0, p1, d);
1612
- }
1613
-
1614
- llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
1615
- // the base cache is a superset of the SWA cache, so we can just check the SWA cache
1616
- return kv_swa->seq_pos_min(seq_id);
1617
- }
1618
-
1619
- llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
1620
- return kv_swa->seq_pos_max(seq_id);
1621
- }
1622
-
1623
- void llama_kv_cache_unified_iswa::restore() {
1624
- kv_base->restore();
1625
- kv_swa ->restore();
1626
- }
1627
-
1628
- void llama_kv_cache_unified_iswa::commit() {
1629
- kv_base->commit();
1630
- kv_swa ->commit();
1631
-
1632
- // slide the attention window, forgetting/pruning old tokens that are outside the window
1633
- if (do_prune) {
1634
- for (const auto & [seq_id, entry] : pending.pos) {
1635
- kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
1636
- }
1637
-
1638
- }
1639
-
1640
- pending.clear();
1641
- }
1642
-
1643
- bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
1644
- bool res = true;
1645
-
1646
- res = res & kv_base->update(lctx);
1647
- res = res & kv_swa ->update(lctx);
1648
-
1649
- return res;
1650
- }
1651
-
1652
- void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
1653
- kv_base->defrag_sched(thold);
1654
- kv_swa ->defrag_sched(thold);
1655
- }
1656
-
1657
- void llama_kv_cache_unified_iswa::set_full() {
1658
- kv_base->set_full();
1659
- kv_swa ->set_full();
1660
- }
1661
-
1662
- llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1663
- pending.clear();
1664
-
1665
- if (do_prune) {
1666
- for (int i = 0; i < batch.n_tokens; ++i) {
1667
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
1668
- const llama_seq_id seq_id = batch.seq_id[i][s];
1669
- const llama_pos pos = batch.pos[i];
1670
-
1671
- if (pending.pos.find(seq_id) == pending.pos.end()) {
1672
- pending.pos[seq_id].pmin = pos;
1673
- pending.pos[seq_id].pmax = pos;
1674
- } else {
1675
- pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
1676
- pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
1677
- }
1678
- }
1679
- }
1680
- }
1681
-
1682
- return llama_sbatch(batch, hparams.n_embd, true, logits_all);
1683
- }
1684
-
1685
- llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1686
- GGML_UNUSED(embd_pooled);
1687
- return sbatch.split_simple(n_ubatch);
1688
- }
1689
-
1690
- bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
1691
- bool res = true;
1692
-
1693
- res = res & kv_base->find_slot(batch);
1694
- res = res & kv_swa ->find_slot(batch);
1695
-
1696
- return res;
1697
- }
1698
-
1699
- bool llama_kv_cache_unified_iswa::get_can_shift() const {
1700
- return kv_base->get_size() == kv_swa->get_size();
1701
- }
1702
-
1703
- void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1704
- kv_base->state_write(io, seq_id);
1705
- kv_swa ->state_write(io, seq_id);
1706
- }
1707
-
1708
- void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1709
- kv_base->state_read(io, seq_id);
1710
- kv_swa ->state_read(io, seq_id);
1711
- }
1712
-
1713
- llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
1714
- return kv_base.get();
1715
- }
1716
-
1717
- llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
1718
- return kv_swa.get();
1719
- }
1720
-
1721
- //
1722
- // llama_kv_cache_recurrent
1723
- //
1724
-
1725
- llama_kv_cache_recurrent::llama_kv_cache_recurrent(
1726
- const llama_model & model,
1727
- ggml_type type_k,
1728
- ggml_type type_v,
1729
- bool offload,
1730
- uint32_t kv_size,
1731
- uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
1732
- const int32_t n_layer = hparams.n_layer;
1733
-
1734
- LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
1735
- __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
1736
-
1737
- head = 0;
1738
- size = kv_size;
1739
- used = 0;
1740
-
1741
- cells.clear();
1742
- cells.resize(kv_size);
1743
-
1744
- // create a context for each buffer type
1745
- std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
1746
- auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
1747
- auto it = ctx_map.find(buft);
1748
- if (it == ctx_map.end()) {
1749
- ggml_init_params params = {
1750
- /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
1751
- /*.mem_buffer =*/ NULL,
1752
- /*.no_alloc =*/ true,
1753
- };
1754
-
1755
- ggml_context * ctx = ggml_init(params);
1756
- if (!ctx) {
1757
- return nullptr;
1758
- }
1759
-
1760
- ctx_map[buft] = ctx;
1761
- ctxs.emplace_back(ctx);
1762
-
1763
- return ctx;
1764
- }
1765
-
1766
- return it->second;
1767
- };
1768
-
1769
- k_l.reserve(n_layer);
1770
- v_l.reserve(n_layer);
1771
-
1772
- for (int i = 0; i < n_layer; i++) {
1773
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
1774
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
1775
-
1776
- const char * dev_name = "CPU";
1777
-
1778
- ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
1779
-
1780
- if (offload) {
1781
- auto * dev = model.dev_layer(i);
1782
- buft = ggml_backend_dev_buffer_type(dev);
1783
-
1784
- dev_name = ggml_backend_dev_name(dev);
1785
- }
1786
-
1787
- LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
1788
-
1789
- ggml_context * ctx = ctx_for_buft(buft);
1790
- if (!ctx) {
1791
- throw std::runtime_error("failed to create ggml context for kv cache");
1792
- }
1793
-
1794
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
1795
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
1796
- ggml_format_name(k, "cache_k_l%d", i);
1797
- ggml_format_name(v, "cache_v_l%d", i);
1798
- k_l.push_back(k);
1799
- v_l.push_back(v);
1800
- }
1801
-
1802
- // allocate tensors and initialize the buffers to avoid NaNs in the padding
1803
- for (auto it : ctx_map) {
1804
- auto * buft = it.first;
1805
- auto * ctx = it.second;
1806
-
1807
- ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
1808
- if (!buf) {
1809
- throw std::runtime_error("failed to allocate buffer for kv cache");
1810
- }
1811
- ggml_backend_buffer_clear(buf, 0);
1812
- LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
1813
- bufs.emplace_back(buf);
1814
- }
1815
-
1816
- {
1817
- const size_t memory_size_k = size_k_bytes();
1818
- const size_t memory_size_v = size_v_bytes();
1819
-
1820
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
1821
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
1822
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
1823
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
1824
- }
1825
- }
1826
-
1827
- void llama_kv_cache_recurrent::clear() {
1828
- for (int32_t i = 0; i < (int32_t) size; ++i) {
1829
- cells[i].pos = -1;
1830
- cells[i].seq_id.clear();
1831
- cells[i].src = -1;
1832
- cells[i].tail = -1;
1833
- }
1834
- head = 0;
1835
- used = 0;
1836
-
1837
- for (auto & buf : bufs) {
1838
- ggml_backend_buffer_clear(buf.get(), 0);
1839
- }
1840
- }
1841
-
1842
- bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1843
- uint32_t new_head = size;
1844
-
1845
- if (p0 < 0) {
1846
- p0 = 0;
1847
- }
1848
-
1849
- if (p1 < 0) {
1850
- p1 = std::numeric_limits<llama_pos>::max();
1851
- }
1852
-
1853
- // models like Mamba or RWKV can't have a state partially erased
1854
- if (seq_id >= (int64_t) size) {
1855
- // could be fatal
1856
- return false;
1857
- }
1858
- if (0 <= seq_id) {
1859
- int32_t & tail_id = cells[seq_id].tail;
1860
- if (tail_id >= 0) {
1861
- const kv_cell & cell = cells[tail_id];
1862
- // partial intersection is invalid
1863
- if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1864
- return false;
1865
- }
1866
- // invalidate tails which will be cleared
1867
- if (p0 <= cell.pos && cell.pos < p1) {
1868
- tail_id = -1;
1869
- }
1870
- }
1871
- } else {
1872
- // seq_id is negative, then the range should include everything or nothing
1873
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1874
- return false;
1875
- }
1876
- }
1877
-
1878
- for (uint32_t i = 0; i < size; ++i) {
1879
- if (cells[i].pos >= p0 && cells[i].pos < p1) {
1880
- if (seq_id < 0) {
1881
- cells[i].seq_id.clear();
1882
- } else if (cells[i].has_seq_id(seq_id)) {
1883
- cells[i].seq_id.erase(seq_id);
1884
- } else {
1885
- continue;
1886
- }
1887
- if (cells[i].is_empty()) {
1888
- // keep count of the number of used cells
1889
- if (cells[i].pos >= 0) {
1890
- used--;
1891
- }
1892
- cells[i].pos = -1;
1893
- cells[i].src = -1;
1894
- if (new_head == size) {
1895
- new_head = i;
1896
- }
1897
- }
1898
- }
1899
- }
1900
-
1901
- // If we freed up a slot, set head to it so searching can start there.
1902
- if (new_head != size && new_head < head) {
1903
- head = new_head;
1904
- }
1905
-
1906
- return true;
1907
- }
1908
-
1909
- void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
1910
- if (seq_id_src == seq_id_dst) {
1911
- return;
1912
- }
1913
-
1914
- if (p0 < 0) {
1915
- p0 = 0;
1916
- }
1917
-
1918
- if (p1 < 0) {
1919
- p1 = std::numeric_limits<llama_pos>::max();
1920
- }
1921
-
1922
- if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
1923
- kv_cell & tail_src = cells[seq_id_src];
1924
- kv_cell & tail_dst = cells[seq_id_dst];
1925
- if (tail_dst.tail >= 0) {
1926
- // clear destination seq_id if it wasn't empty
1927
- kv_cell & cell_dst = cells[tail_dst.tail];
1928
-
1929
- cell_dst.seq_id.erase(seq_id_dst);
1930
- tail_dst.tail = -1;
1931
- if (cell_dst.seq_id.empty()) {
1932
- cell_dst.pos = -1;
1933
- cell_dst.src = -1;
1934
- used -= 1;
1935
- }
1936
- }
1937
- if (tail_src.tail >= 0) {
1938
- kv_cell & cell_src = cells[tail_src.tail];
1939
-
1940
- cell_src.seq_id.insert(seq_id_dst);
1941
- tail_dst.tail = tail_src.tail;
1942
- }
1943
- }
1944
- }
1945
-
1946
- void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
1947
- uint32_t new_head = size;
1948
-
1949
- for (uint32_t i = 0; i < size; ++i) {
1950
- if ((llama_seq_id) i != seq_id) {
1951
- cells[i].tail = -1;
1952
- }
1953
-
1954
- if (!cells[i].has_seq_id(seq_id)) {
1955
- if (cells[i].pos >= 0) {
1956
- used--;
1957
- }
1958
-
1959
- cells[i].pos = -1;
1960
- cells[i].src = -1;
1961
- cells[i].seq_id.clear();
1962
-
1963
- if (new_head == size){
1964
- new_head = i;
1965
- }
1966
- } else {
1967
- cells[i].seq_id.clear();
1968
- cells[i].seq_id.insert(seq_id);
1969
- }
1970
- }
1276
+ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+ const int64_t n_tokens = ubatch->n_tokens;
 
- // If we freed up a slot, set head to it so searching can start there.
- if (new_head != size && new_head < head) {
- head = new_head;
- }
- }
+ GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
+ const auto & cells = v_cells[0];
 
- void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
- if (shift == 0) {
- return;
- }
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
- if (p0 < 0) {
- p0 = 0;
- }
+ int32_t * data = (int32_t *) dst->data;
 
- if (p1 < 0) {
- p1 = std::numeric_limits<llama_pos>::max();
- }
+ const int32_t n_kv = dst->ne[0];
 
- // If there is no range then return early to avoid looping over the
- if (p0 == p1) {
- return;
- }
+ for (int h = 0; h < 1; ++h) {
+ for (int i = 0; i < n_tokens; ++i) {
+ for (int j = 0; j < n_kv; ++j) {
+ // the position when the cells is empty is irrelevant - it will be masked out later in the attention
+ const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
 
- // for Mamba-like or RWKV models, only the pos needs to be shifted
- if (0 <= seq_id && seq_id < (int64_t) size) {
- const int32_t tail_id = cells[seq_id].tail;
- if (tail_id >= 0) {
- kv_cell & cell = cells[tail_id];
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
- cell.pos += shift;
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
  }
  }
  }
  }
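// [Editor's illustrative sketch - not part of the diff.]
// set_input_pos_bucket() fills an [n_kv, n_tokens] matrix where entry (i, j) depends
// only on the relative distance between batch token i and cache cell j. The sketch
// below just prints those raw distances for toy positions; mapping a distance to a
// bucket id is done by llama_relative_position_bucket(), which is not reproduced here.
#include <cstdio>

int main() {
    const int cell_pos[]  = {0, 1, 2, 3}; // toy positions stored in the cache cells
    const int token_pos[] = {4, 5};       // toy positions of the tokens in the ubatch
    const int n_kv = 4, n_tokens = 2;

    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            // the value written at data[i*n_kv + j] is a bucket id derived from this distance
            printf("%3d ", token_pos[i] - cell_pos[j]);
        }
        printf("\n");
    }
    return 0;
}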
2007
1300
 
2008
- void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
2009
- if (d == 1) {
2010
- return;
2011
- }
2012
-
2013
- if (p0 < 0) {
2014
- p0 = 0;
2015
- }
2016
-
2017
- if (p1 < 0) {
2018
- p1 = std::numeric_limits<llama_pos>::max();
2019
- }
1301
+ size_t llama_kv_cache::total_size() const {
+ size_t size = 0;
 
- // If there is no range then return early to avoid looping over the cache.
- if (p0 == p1) {
- return;
+ for (const auto & buf : bufs) {
+ size += ggml_backend_buffer_get_size(buf.get());
  }
 
- // for Mamba-like or RWKV models, only the pos needs to be changed
- if (0 <= seq_id && seq_id < (int64_t) size) {
- const int32_t tail_id = cells[seq_id].tail;
- if (tail_id >= 0) {
- kv_cell & cell = cells[tail_id];
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
- cell.pos /= d;
- }
- }
- }
+ return size;
  }
  }
2037
1310
 
2038
- llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
2039
- llama_pos result = std::numeric_limits<llama_pos>::max();
2040
-
2041
- for (uint32_t i = 0; i < size; ++i) {
2042
- if (cells[i].has_seq_id(seq_id)) {
2043
- result = std::min(result, cells[i].pos);
2044
- }
2045
- }
1311
+ size_t llama_kv_cache::size_k_bytes() const {
1312
+ size_t size_k_bytes = 0;
2046
1313
 
2047
- if (result == std::numeric_limits<llama_pos>::max()) {
2048
- result = -1;
1314
+ for (const auto & layer : layers) {
1315
+ size_k_bytes += ggml_nbytes(layer.k);
2049
1316
  }
2050
1317
 
2051
- return result;
1318
+ return size_k_bytes;
2052
1319
  }
2053
1320
 
2054
- llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
2055
- llama_pos result = -1;
1321
+ size_t llama_kv_cache::size_v_bytes() const {
1322
+ size_t size_v_bytes = 0;
2056
1323
 
2057
- for (uint32_t i = 0; i < size; ++i) {
2058
- if (cells[i].has_seq_id(seq_id)) {
2059
- result = std::max(result, cells[i].pos);
2060
- }
1324
+ for (const auto & layer : layers) {
1325
+ size_v_bytes += ggml_nbytes(layer.v);
2061
1326
  }
2062
1327
 
2063
- return result;
1328
+ return size_v_bytes;
2064
1329
  }
2065
1330
 
2066
- void llama_kv_cache_recurrent::restore() {
2067
- if (pending.ranges.empty()) {
2068
- return;
2069
- }
1331
+ ggml_tensor * llama_kv_cache::build_rope_shift(
1332
+ const llama_cparams & cparams,
1333
+ ggml_context * ctx,
1334
+ ggml_tensor * cur,
1335
+ ggml_tensor * shift,
1336
+ ggml_tensor * factors,
1337
+ float freq_base,
1338
+ float freq_scale) const {
1339
+ const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
2070
1340
 
2071
- seq_rm(-1, -1, -1);
2072
- }
1341
+ const auto & yarn_ext_factor = cparams.yarn_ext_factor;
1342
+ const auto & yarn_beta_fast = cparams.yarn_beta_fast;
1343
+ const auto & yarn_beta_slow = cparams.yarn_beta_slow;
2073
1344
 
2074
- void llama_kv_cache_recurrent::commit() {
2075
- pending.ranges.clear();
2076
- }
1345
+ const auto & n_rot = hparams.n_rot;
1346
+ const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
1347
+ // @ngxson : this is a workaround
1348
+ // for M-RoPE, we want to rotate the whole vector when doing KV shift
1349
+ // a normal RoPE should work, we just need to use the correct ordering
1350
+ // ref: https://github.com/ggml-org/llama.cpp/pull/13870
1351
+ ? LLAMA_ROPE_TYPE_NEOX
1352
+ : hparams.rope_type;
2077
1353
 
2078
- bool llama_kv_cache_recurrent::update(llama_context & ctx) {
2079
- GGML_UNUSED(ctx);
2080
- return false;
2081
- }
1354
+ // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
1355
+ // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
1356
+ const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
1357
+ ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
1358
+ : cparams.yarn_attn_factor;
2082
1359
 
2083
- void llama_kv_cache_recurrent::defrag_sched(float thold) {
2084
- GGML_UNUSED(thold);
2085
- // noop
2086
- }
1360
+ ggml_tensor * tmp;
2087
1361
 
2088
- void llama_kv_cache_recurrent::set_full() {
2089
- n = size;
2090
- head = 0;
2091
- }
1362
+ if (ggml_is_quantized(cur->type)) {
1363
+ // dequantize to f32 -> RoPE -> quantize back
1364
+ tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
2092
1365
 
2093
- llama_sbatch llama_kv_cache_recurrent::sbatch_init(
2094
- const llama_batch & batch,
2095
- bool logits_all) {
2096
- return llama_sbatch(batch, hparams.n_embd, false, logits_all);
2097
- }
1366
+ tmp = ggml_rope_ext(ctx, tmp,
1367
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1368
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
2098
1369
 
2099
- llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2100
- if (embd_pooled) {
2101
- // Pooled embeddings cannot be split across ubatches (yet)
2102
- return sbatch.split_seq(n_ubatch);
1370
+ tmp = ggml_cpy(ctx, tmp, cur);
1371
+ } else {
1372
+ // we rotate only the first n_rot dimensions
1373
+ tmp = ggml_rope_ext_inplace(ctx, cur,
1374
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1375
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
2103
1376
  }
2104
1377
 
2105
- return sbatch.split_equal(n_ubatch);
1378
+ return tmp;
2106
1379
  }
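// [Editor's illustrative sketch - not part of the diff.]
// The K-shift graph above re-rotates cached K vectors so they match their new positions.
// For a single RoPE dimension pair with frequency theta, shifting a token from position p
// to p + d is equivalent to rotating that pair by the angle d*theta - which is what feeding
// the per-cell shift into ggml_rope_ext achieves across the whole cache. Toy values only.
#include <cmath>
#include <cstdio>

int main() {
    const float theta = 0.1f;      // toy inverse frequency for one dimension pair
    const int   shift = -3;        // toy per-cell shift (e.g. after removing 3 tokens)
    float k[2] = {0.8f, 0.6f};     // toy K values for this dimension pair

    const float a  = shift * theta; // rotation angle applied by the shift
    const float k0 = k[0]*cosf(a) - k[1]*sinf(a);
    const float k1 = k[0]*sinf(a) + k[1]*cosf(a);

    printf("before: (%.3f, %.3f)  after shift by %d: (%.3f, %.3f)\n", k[0], k[1], shift, k0, k1);
    return 0;
}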
2107
1380
 
2108
- bool llama_kv_cache_recurrent::find_slot(
2109
- const llama_ubatch & ubatch) {
2110
- const uint32_t n_tokens = ubatch.n_tokens;
2111
- const uint32_t n_seqs = ubatch.n_seqs;
2112
-
2113
- const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
2114
-
2115
- // if we have enough unused cells before the current head ->
2116
- // better to start searching from the beginning of the cache, hoping to fill it
2117
- if (head > used + 2*n_tokens) {
2118
- head = 0;
2119
- }
1381
+ class llm_graph_input_k_shift : public llm_graph_input_i {
+ public:
+ llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
+ virtual ~llm_graph_input_k_shift() = default;
 
- // For recurrent state architectures (like Mamba or RWKV),
- // each cache cell can store the state for a whole sequence.
- // A slot should be always be contiguous.
+ void set_input(const llama_ubatch * ubatch) override;
 
- // can only process batches with an equal number of new tokens in each sequence
- GGML_ASSERT(ubatch.equal_seqs);
+ ggml_tensor * k_shift; // I32 [kv_size*n_stream]
 
- int32_t min = size - 1;
- int32_t max = 0;
+ const llama_kv_cache * kv_self;
+ };
2130
1392
 
2131
- // everything should fit if all seq_ids are smaller than the max
2132
- for (uint32_t s = 0; s < n_seqs; ++s) {
2133
- const uint32_t n_seq_id = ubatch.n_seq_id[s];
2134
- for (uint32_t j = 0; j < n_seq_id; ++j) {
2135
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
2136
-
2137
- if (seq_id < 0 || (uint32_t) seq_id >= size) {
2138
- // too big seq_id
2139
- // TODO: would it be possible to resize the cache instead?
2140
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
2141
- return false;
2142
- }
2143
- if (j > 0) {
2144
- kv_cell & seq = cells[seq_id];
2145
- if (seq.tail >= 0) {
2146
- kv_cell & cell = cells[seq.tail];
2147
- // clear cells from seq_ids that become shared
2148
- // (should not normally happen, but let's handle it anyway)
2149
- cell.seq_id.erase(seq_id);
2150
- seq.tail = -1;
2151
- if (cell.seq_id.empty()) {
2152
- cell.pos = -1;
2153
- cell.src = -1;
2154
- used -= 1;
2155
- }
2156
- }
2157
- }
2158
- }
2159
- }
1393
+ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+ GGML_UNUSED(ubatch);
 
- #ifndef NDEBUG
- {
- std::vector<int32_t> tails_verif;
- tails_verif.assign(size, -1);
- for (uint32_t i = 0; i < size; ++i) {
- kv_cell & cell = cells[i];
- for (llama_seq_id seq_id : cell.seq_id) {
- if (tails_verif[seq_id] != -1) {
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
- }
- tails_verif[seq_id] = i;
- }
- }
- for (uint32_t i = 0; i < size; ++i) {
- if (tails_verif[i] != cells[i].tail) {
- LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
- }
- }
+ if (k_shift) {
+ kv_self->set_input_k_shift(k_shift);
  }
- #endif
+ }
2181
1400
 
2182
- // find next empty cell
2183
- uint32_t next_empty_cell = head;
1401
+ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
1402
+ auto * ctx = res->get_ctx();
1403
+ auto * gf = res->get_gf();
2184
1404
 
2185
- for (uint32_t i = 0; i < size; ++i) {
2186
- if (next_empty_cell >= size) { next_empty_cell -= size; }
2187
- kv_cell & cell = cells[next_empty_cell];
2188
- if (cell.is_empty()) { break; }
2189
- next_empty_cell += 1;
2190
- }
1405
+ const auto & n_embd_head_k = hparams.n_embd_head_k;
1406
+ //const auto & n_embd_head_v = hparams.n_embd_head_v;
2191
1407
 
2192
- // find usable cell range
2193
- for (uint32_t s = 0; s < n_seqs; ++s) {
2194
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
2195
- kv_cell & seq_meta = cells[seq_id];
2196
- bool has_cell = false;
2197
- if (seq_meta.tail >= 0) {
2198
- kv_cell & cell = cells[seq_meta.tail];
2199
- GGML_ASSERT(cell.has_seq_id(seq_id));
2200
- // does this seq_id "own" the cell?
2201
- if (cell.seq_id.size() == 1) { has_cell = true; }
2202
- }
2203
- if (!has_cell) {
2204
- kv_cell & empty_cell = cells[next_empty_cell];
2205
- GGML_ASSERT(empty_cell.is_empty());
2206
- // copy old tail into the empty cell
2207
- if (seq_meta.tail >= 0) {
2208
- kv_cell & orig_cell = cells[seq_meta.tail];
2209
- empty_cell.pos = orig_cell.pos;
2210
- empty_cell.src = orig_cell.src;
2211
- orig_cell.seq_id.erase(seq_id);
2212
- empty_cell.seq_id.insert(seq_id); // will be overwritten
2213
- }
2214
- seq_meta.tail = next_empty_cell;
2215
- // find next empty cell
2216
- if (s + 1 < n_seqs) {
2217
- next_empty_cell += 1;
2218
- for (uint32_t i = 0; i < size; ++i) {
2219
- if (next_empty_cell >= size) { next_empty_cell -= size; }
2220
- kv_cell & cell = cells[next_empty_cell];
2221
- if (cell.is_empty()) { break; }
2222
- next_empty_cell += 1;
2223
- }
2224
- }
2225
- }
2226
- if (min > seq_meta.tail) { min = seq_meta.tail; }
2227
- if (max < seq_meta.tail) { max = seq_meta.tail; }
2228
- }
1408
+ auto inp = std::make_unique<llm_graph_input_k_shift>(this);
2229
1409
 
2230
- // gather and re-order
2231
- for (uint32_t s = 0; s < n_seqs; ++s) {
2232
- int32_t dst_id = s + min;
2233
- int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
2234
- if (dst_id != src_id) {
2235
- kv_cell & dst_cell = cells[dst_id];
2236
- kv_cell & src_cell = cells[src_id];
2237
-
2238
- std::swap(dst_cell.pos, src_cell.pos);
2239
- std::swap(dst_cell.src, src_cell.src);
2240
- std::swap(dst_cell.seq_id, src_cell.seq_id);
2241
-
2242
- // swap tails (assuming they NEVER overlap)
2243
- for (const llama_seq_id seq_id : src_cell.seq_id) {
2244
- cells[seq_id].tail = src_id;
2245
- }
2246
- for (const llama_seq_id seq_id : dst_cell.seq_id) {
2247
- cells[seq_id].tail = dst_id;
2248
- }
2249
- }
2250
- }
1410
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
1411
+ ggml_set_input(inp->k_shift);
2251
1412
 
2252
- // update the pos of the used seqs
2253
- for (uint32_t s = 0; s < n_seqs; ++s) {
2254
- const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
2255
- int32_t cell_id = s + min;
2256
- kv_cell & cell = cells[cell_id];
1413
+ const auto & cparams = lctx->get_cparams();
2257
1414
 
2258
- if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
2259
- // What should happen when the pos backtracks or skips a value?
2260
- // Clearing the state mid-batch would require special-casing which isn't done.
2261
- LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
2262
- __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
2263
- }
2264
- cell.pos = last_pos;
2265
- cell.seq_id.clear();
2266
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
2267
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
2268
- cell.seq_id.insert(seq_id);
2269
- cells[seq_id].tail = cell_id;
2270
- }
2271
- }
1415
+ for (const auto & layer : layers) {
1416
+ const uint32_t il = layer.il;
2272
1417
 
2273
- // allow getting the range of used cells, from head to head + n
2274
- head = min;
2275
- n = max - min + 1;
2276
- used = std::count_if(cells.begin(), cells.end(),
2277
- [](const kv_cell & cell){ return !cell.is_empty(); });
1418
+ const int64_t n_head_kv = hparams.n_head_kv(il);
1419
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
2278
1420
 
2279
- // sanity check
2280
- return n >= n_seqs;
2281
- }
1421
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
1422
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
2282
1423
 
2283
- bool llama_kv_cache_recurrent::get_can_shift() const {
2284
- return false;
2285
- }
1424
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
2286
1425
 
2287
- int32_t llama_kv_cache_recurrent::s_copy(int i) const {
2288
- const uint32_t cell_id = i + head;
1426
+ ggml_tensor * k =
1427
+ ggml_view_3d(ctx, layer.k,
1428
+ n_embd_head_k, n_head_kv, get_size()*n_stream,
1429
+ ggml_row_size(layer.k->type, n_embd_head_k),
1430
+ ggml_row_size(layer.k->type, n_embd_k_gqa),
1431
+ 0);
2289
1432
 
2290
- //////////////////////////////////////////////
2291
- // TODO: this should not mutate the KV cache !
2292
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
1433
+ ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
2293
1434
 
2294
- // prevent out-of-bound sources
2295
- if (cell.src < 0 || (uint32_t) cell.src >= size) {
2296
- cell.src = cell_id;
1435
+ ggml_build_forward_expand(gf, cur);
2297
1436
  }
2298
1437
 
2299
- int32_t res = cell.src;
1438
+ res->add_input(std::move(inp));
2300
1439
 
2301
- // TODO: do not mutate the KV cache
2302
- // ensure copy only happens once
2303
- if (cell.src != (int32_t) cell_id) {
2304
- cell.src = cell_id;
2305
- }
1440
+ return gf;
1441
+ }
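// [Editor's illustrative sketch - not part of the diff.]
// build_graph_shift() views the flat per-layer K buffer as a 3-D tensor
// [n_embd_head_k, n_head_kv, kv_size] using byte strides. Assuming F32 storage
// (4 bytes per element, so ggml_row_size(type, n) == 4*n), element (i0, i1, cell)
// of that view lives at the byte offset computed below. Toy sizes only.
#include <cstdio>
#include <cstddef>

int main() {
    const size_t n_embd_head_k = 4;  // toy head dimension
    const size_t n_head_kv     = 2;  // toy number of KV heads
    const size_t n_embd_k_gqa  = n_embd_head_k * n_head_kv;

    const size_t nb0 = 4;                   // sizeof(float)
    const size_t nb1 = nb0 * n_embd_head_k; // stride between heads within a cell
    const size_t nb2 = nb0 * n_embd_k_gqa;  // stride between cells

    const size_t i0 = 1, i1 = 1, cell = 3;  // toy element coordinates
    printf("byte offset = %zu\n", i0*nb0 + i1*nb1 + cell*nb2);
    return 0;
}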
2306
1442
 
2307
- return res;
1443
+ bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
+ return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
  }
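// [Editor's illustrative sketch - not part of the diff.]
// The helper llama_hparams::is_masked_swa() is not shown in this diff; the sketch below
// mirrors the checks done by the removed in-class implementation visible earlier in this
// file (standard window: mask when p1 - p0 >= n_swa; chunked: mask when p0 falls before
// the chunk containing p1). Enum names are local stand-ins, not the library's.
#include <cstdint>

enum swa_type_t { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

static bool is_masked_swa_sketch(uint32_t n_swa, swa_type_t swa_type, int32_t p0, int32_t p1) {
    switch (swa_type) {
        case SWA_NONE:
            break;
        case SWA_STANDARD:
            if (p1 - p0 >= (int32_t) n_swa) {
                return true;   // key is outside the sliding window of the query
            }
            break;
        case SWA_CHUNKED: {
            const int32_t pos_chunk_start = (p1 / (int32_t) n_swa) * (int32_t) n_swa;
            if (p0 < pos_chunk_start) {
                return true;   // key belongs to an earlier chunk
            }
            break;
        }
    }
    return false;
}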
2309
1446
 
2310
- float llama_kv_cache_recurrent::s_mask(int i) const {
2311
- const uint32_t cell_id = i + head;
1447
+ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
1448
+ GGML_UNUSED(flags);
2312
1449
 
2313
- //////////////////////////////////////////////
2314
- // TODO: this should not mutate the KV cache !
2315
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
1450
+ io.write(&n_stream, sizeof(n_stream));
2316
1451
 
2317
- float res = (float) (cell.src >= 0);
1452
+ for (uint32_t s = 0; s < n_stream; ++s) {
1453
+ cell_ranges_t cr { s, {} };
2318
1454
 
2319
- // only clear once
2320
- if (cell.src < 0) {
2321
- cell.src = cell_id;
2322
- }
1455
+ uint32_t cell_count = 0;
2323
1456
 
2324
- return res;
2325
- }
1457
+ const auto & cells = v_cells[s];
2326
1458
 
2327
- uint32_t llama_kv_cache_recurrent::cell_max() const {
2328
- for (uint32_t i = size; i > 0; --i) {
2329
- const kv_cell & cell = cells[i - 1];
1459
+ // Count the number of cells with the specified seq_id
1460
+ // Find all the ranges of cells with this seq id (or all, when -1)
1461
+ uint32_t cell_range_begin = cells.size();
2330
1462
 
2331
- if (cell.pos >= 0 && !cell.is_empty()) {
2332
- return i;
1463
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1464
+ if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1465
+ ++cell_count;
1466
+ if (cell_range_begin == cells.size()) {
1467
+ cell_range_begin = i;
1468
+ }
1469
+ } else {
1470
+ if (cell_range_begin != cells.size()) {
1471
+ cr.data.emplace_back(cell_range_begin, i);
1472
+ cell_range_begin = cells.size();
1473
+ }
1474
+ }
2333
1475
  }
2334
- }
2335
-
2336
- return 0;
2337
- }
2338
1476
 
2339
- size_t llama_kv_cache_recurrent::total_size() const {
2340
- size_t size = 0;
2341
- for (const auto & buf : bufs) {
2342
- size += ggml_backend_buffer_get_size(buf.get());
2343
- }
2344
-
2345
- return size;
2346
- }
2347
-
2348
- size_t llama_kv_cache_recurrent::size_k_bytes() const {
2349
- size_t size_k_bytes = 0;
1477
+ if (cell_range_begin != cells.size()) {
1478
+ cr.data.emplace_back(cell_range_begin, cells.size());
1479
+ }
2350
1480
 
2351
- for (const auto & k : k_l) {
2352
- size_k_bytes += ggml_nbytes(k);
2353
- }
1481
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1482
+ uint32_t cell_count_check = 0;
1483
+ for (const auto & range : cr.data) {
1484
+ cell_count_check += range.second - range.first;
1485
+ }
1486
+ GGML_ASSERT(cell_count == cell_count_check);
2354
1487
 
2355
- return size_k_bytes;
2356
- }
1488
+ io.write(&cell_count, sizeof(cell_count));
2357
1489
 
2358
- size_t llama_kv_cache_recurrent::size_v_bytes() const {
2359
- size_t size_v_bytes = 0;
1490
+ // skip empty streams
1491
+ if (cell_count == 0) {
1492
+ continue;
1493
+ }
2360
1494
 
2361
- for (const auto & v : v_l) {
2362
- size_v_bytes += ggml_nbytes(v);
1495
+ state_write_meta(io, cr, seq_id);
1496
+ state_write_data(io, cr);
2363
1497
  }
2364
-
2365
- return size_v_bytes;
2366
1498
  }
2367
1499
 
2368
- void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
2369
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
2370
- uint32_t cell_count = 0;
1500
+ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
1501
+ GGML_UNUSED(flags);
2371
1502
 
2372
- // Count the number of cells with the specified seq_id
2373
- // Find all the ranges of cells with this seq id (or all, when -1)
2374
- uint32_t cell_range_begin = size;
2375
- for (uint32_t i = 0; i < size; ++i) {
2376
- const auto & cell = cells[i];
2377
- if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
2378
- ++cell_count;
2379
- if (cell_range_begin == size) {
2380
- cell_range_begin = i;
2381
- }
2382
- } else {
2383
- if (cell_range_begin != size) {
2384
- cell_ranges.emplace_back(cell_range_begin, i);
2385
- cell_range_begin = size;
2386
- }
2387
- }
2388
- }
2389
- if (cell_range_begin != size) {
2390
- cell_ranges.emplace_back(cell_range_begin, size);
2391
- }
1503
+ GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
2392
1504
 
2393
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
2394
- uint32_t cell_count_check = 0;
2395
- for (const auto & range : cell_ranges) {
2396
- cell_count_check += range.second - range.first;
1505
+ uint32_t n_stream_cur;
1506
+ io.read_to(&n_stream_cur, sizeof(n_stream_cur));
1507
+ if (n_stream_cur != n_stream) {
1508
+ throw std::runtime_error("n_stream mismatch");
2397
1509
  }
2398
- GGML_ASSERT(cell_count == cell_count_check);
2399
1510
 
2400
- io.write(&cell_count, sizeof(cell_count));
1511
+ for (uint32_t s = 0; s < n_stream; ++s) {
1512
+ uint32_t cell_count;
1513
+ io.read_to(&cell_count, sizeof(cell_count));
2401
1514
 
2402
- state_write_meta(io, cell_ranges, seq_id);
2403
- state_write_data(io, cell_ranges);
2404
- }
2405
-
2406
- void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
2407
- uint32_t cell_count;
2408
- io.read_to(&cell_count, sizeof(cell_count));
1515
+ if (cell_count == 0) {
1516
+ continue;
1517
+ }
2409
1518
 
2410
- bool res = true;
1519
+ const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
2411
1520
 
2412
- res = res && state_read_meta(io, cell_count, seq_id);
2413
- res = res && state_read_data(io, cell_count);
1521
+ bool res = true;
1522
+ res = res && state_read_meta(io, strm, cell_count, seq_id);
1523
+ res = res && state_read_data(io, strm, cell_count);
2414
1524
 
2415
- if (!res) {
2416
- if (seq_id == -1) {
2417
- clear();
2418
- } else {
2419
- seq_rm(seq_id, -1, -1);
1525
+ if (!res) {
1526
+ if (seq_id == -1) {
1527
+ clear(true);
1528
+ } else {
1529
+ seq_rm(seq_id, -1, -1);
1530
+ }
1531
+ throw std::runtime_error("failed to restore kv cache");
2420
1532
  }
2421
- throw std::runtime_error("failed to restore kv cache");
2422
1533
  }
2423
1534
  }
2424
1535
 
2425
- void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
2426
- for (const auto & range : cell_ranges) {
1536
+ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
1537
+ const auto & cells = v_cells[cr.strm];
1538
+
1539
+ for (const auto & range : cr.data) {
2427
1540
  for (uint32_t i = range.first; i < range.second; ++i) {
2428
- const auto & cell = cells[i];
2429
- const llama_pos pos = cell.pos;
2430
- const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
1541
+ std::vector<llama_seq_id> seq_ids;
1542
+
1543
+ for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
1544
+ if (cur == seq_id || seq_id == -1) {
1545
+ if (cells.seq_has(i, cur)) {
1546
+ seq_ids.push_back(cur);
1547
+ }
1548
+ }
1549
+ }
1550
+
1551
+ const llama_pos pos = cells.pos_get(i);
1552
+ const uint32_t n_seq_id = seq_ids.size();
2431
1553
 
2432
1554
  io.write(&pos, sizeof(pos));
2433
1555
  io.write(&n_seq_id, sizeof(n_seq_id));
2434
1556
 
2435
- if (n_seq_id) {
2436
- for (auto seq_id : cell.seq_id) {
2437
- io.write(&seq_id, sizeof(seq_id));
2438
- }
1557
+ for (const auto & seq_id : seq_ids) {
1558
+ io.write(&seq_id, sizeof(seq_id));
2439
1559
  }
2440
1560
  }
2441
1561
  }
2442
1562
  }
2443
1563
 
2444
- void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
- const uint32_t v_trans = 0;
- const uint32_t n_layer = hparams.n_layer;
+ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+ const auto & cells = v_cells[cr.strm];
+
+ const uint32_t v_trans = this->v_trans ? 1 : 0;
+ const uint32_t n_layer = layers.size();

  io.write(&v_trans, sizeof(v_trans));
  io.write(&n_layer, sizeof(n_layer));
@@ -2452,56 +1574,69 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std

  // Iterate and write all the keys first, each row is a cell
  // Get whole range at a time
- for (uint32_t il = 0; il < n_layer; ++il) {
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+ auto * k = layer.k_stream[cr.strm];

  // Write key type
- const int32_t k_type_i = (int32_t)k_l[il]->type;
+ const int32_t k_type_i = (int32_t) k->type;
  io.write(&k_type_i, sizeof(k_type_i));

  // Write row size of key
- const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+ const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  io.write(&k_size_row, sizeof(k_size_row));

  // Read each range of cells of k_size length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
+ for (const auto & range : cr.data) {
  const size_t range_size = range.second - range.first;
  const size_t buf_size = range_size * k_size_row;
- io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
+ io.write_tensor(k, range.first * k_size_row, buf_size);
  }
  }

  if (!v_trans) {
- for (uint32_t il = 0; il < n_layer; ++il) {
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+ auto * v = layer.v_stream[cr.strm];

  // Write value type
- const int32_t v_type_i = (int32_t)v_l[il]->type;
+ const int32_t v_type_i = (int32_t) v->type;
  io.write(&v_type_i, sizeof(v_type_i));

  // Write row size of value
- const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+ const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  io.write(&v_size_row, sizeof(v_size_row));

  // Read each range of cells of v_size length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
+ for (const auto & range : cr.data) {
  const size_t range_size = range.second - range.first;
  const size_t buf_size = range_size * v_size_row;
- io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
+ io.write_tensor(v, range.first * v_size_row, buf_size);
  }
  }
  } else {
  // When v is transposed, we also need the element size and get the element ranges from each row
- const uint32_t kv_size = size;
- for (uint32_t il = 0; il < n_layer; ++il) {
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+ const uint32_t kv_size = cells.size();
+
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+ auto * v = layer.v_stream[cr.strm];

  // Write value type
- const int32_t v_type_i = (int32_t)v_l[il]->type;
+ const int32_t v_type_i = (int32_t) v->type;
  io.write(&v_type_i, sizeof(v_type_i));

  // Write element size
- const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
+ const uint32_t v_size_el = ggml_type_size(v->type);
  io.write(&v_size_el, sizeof(v_size_el));

  // Write GQA embedding size
@@ -2510,29 +1645,30 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
  // For each row, we get the element values of each cell
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  // Read each range of cells of v_size_el length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
+ for (const auto & range : cr.data) {
  const size_t range_size = range.second - range.first;
  const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  const size_t buf_size = range_size * v_size_el;
- io.write_tensor(v_l[il], src_offset, buf_size);
+ io.write_tensor(v, src_offset, buf_size);
  }
  }
  }
  }
  }

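state_write_data above lays out each layer's key section as a type id, a row size, and then one contiguous row per cell of every range. A short sketch of the byte count this produces for one layer's key section, assuming the (first, second) index pairs stored in cell_ranges_t::data (k_section_bytes is illustrative, not a function from the sources):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using cell_range = std::pair<uint32_t, uint32_t>; // [first, second) cell indices

// Sketch only: size of one layer's key section = headers + one row per cell.
static size_t k_section_bytes(const std::vector<cell_range> & ranges, uint64_t k_size_row) {
    size_t total = sizeof(int32_t) + sizeof(uint64_t); // k_type_i + k_size_row
    for (const auto & r : ranges) {
        total += (r.second - r.first) * k_size_row;    // contiguous rows, one per cell
    }
    return total;
}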
- bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+ auto & cells = v_cells[strm];
+ auto & head = v_heads[strm];
+
  if (dest_seq_id != -1) {
  // single sequence
-
  seq_rm(dest_seq_id, -1, -1);

- llama_sbatch sbatch;
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
+
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

- batch.n_tokens = cell_count;
- batch.n_seq_tokens = cell_count;
- batch.n_seqs = 1;
+ ubatch.seq_id_unq[0] = dest_seq_id;

  for (uint32_t i = 0; i < cell_count; ++i) {
  llama_pos pos;
@@ -2541,112 +1677,119 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
  io.read_to(&pos, sizeof(pos));
  io.read_to(&n_seq_id, sizeof(n_seq_id));

- if (n_seq_id != 0) {
+ if (n_seq_id != 1) {
  LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
  return false;
  }

- batch.pos[i] = pos;
+ // read the sequence id, but directly discard it - we will use dest_seq_id instead
+ {
+ llama_seq_id seq_id;
+ io.read_to(&seq_id, sizeof(seq_id));
+ }
+
+ ubatch.pos[i] = pos;
+ ubatch.n_seq_id[i] = n_seq_id;
+ ubatch.seq_id[i] = &dest_seq_id;
  }
- batch.n_seq_id[0] = 1;
- batch.seq_id[0] = &dest_seq_id;
- if (!find_slot(batch)) {
+
+ const auto sinfo = find_slot(ubatch, true);
+ if (sinfo.empty()) {
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
  return false;
  }
- commit();

- // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+ apply_ubatch(sinfo, ubatch);
+
+ const auto head_cur = sinfo.head();
+
+ // keep the head at the old position because we will read the KV data into it in state_read_data()
+ head = head_cur;
+
+ LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+
+ // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
  // Assume that this is one contiguous block of cells
- GGML_ASSERT(head + cell_count <= size);
- GGML_ASSERT(cells[head].pos == batch.pos[0]);
- GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
- GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
- GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+ GGML_ASSERT(head_cur + cell_count <= cells.size());
+ GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
+ GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
+ GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
+ GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
  } else {
  // whole KV cache restore

- if (cell_count > size) {
+ if (cell_count > cells.size()) {
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
  return false;
  }

- clear();
+ clear(true);

  for (uint32_t i = 0; i < cell_count; ++i) {
- kv_cell & cell = cells[i];
-
  llama_pos pos;
  uint32_t n_seq_id;

  io.read_to(&pos, sizeof(pos));
  io.read_to(&n_seq_id, sizeof(n_seq_id));

- cell.pos = pos;
+ cells.pos_set(i, pos);

  for (uint32_t j = 0; j < n_seq_id; ++j) {
  llama_seq_id seq_id;
  io.read_to(&seq_id, sizeof(seq_id));

- // TODO: llama_kv_cache_recurrent should have a notion of max sequences
- //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
- if (seq_id < 0) {
- //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
- LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+ if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
  return false;
  }

- cell.seq_id.insert(seq_id);
-
- int32_t & tail = cells[seq_id].tail;
- if (tail != -1) {
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
- return false;
- }
- tail = i;
+ cells.seq_add(i, seq_id);
  }
  }

  head = 0;
- used = cell_count;
- }
-
- for (uint32_t i = 0; i < cell_count; ++i) {
- uint32_t cell_id = head + i;
- // make sure the recurrent states will keep their restored state
- cells[cell_id].src = cell_id;
  }

  return true;
  }

- bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+ auto & cells = v_cells[strm];
+ auto & head = v_heads[strm];
+
  uint32_t v_trans;
  uint32_t n_layer;
+
  io.read_to(&v_trans, sizeof(v_trans));
  io.read_to(&n_layer, sizeof(n_layer));

- if (n_layer != hparams.n_layer) {
- LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+ if (n_layer != layers.size()) {
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
  return false;
  }
- if (cell_count > size) {
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+
+ if (cell_count > cells.size()) {
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
  return false;
  }
- if (false != (bool) v_trans) {
+
+ if (this->v_trans != (bool) v_trans) {
  LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
  return false;
  }

  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
- for (uint32_t il = 0; il < n_layer; ++il) {
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+ auto * k = layer.k_stream[strm];

  // Read type of key
  int32_t k_type_i_ref;
  io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
- const int32_t k_type_i = (int32_t) k_l[il]->type;
+ const int32_t k_type_i = (int32_t) k->type;
  if (k_type_i != k_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  return false;
@@ -2655,7 +1798,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
  // Read row size of key
  uint64_t k_size_row_ref;
  io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+ const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  if (k_size_row != k_size_row_ref) {
  LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  return false;
@@ -2663,18 +1806,22 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

  if (cell_count) {
  // Read and set the keys for the whole cell range
- ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+ ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  }
  }

- if (!v_trans) {
- for (uint32_t il = 0; il < n_layer; ++il) {
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+ if (!this->v_trans) {
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+ auto * v = layer.v_stream[strm];

  // Read type of value
  int32_t v_type_i_ref;
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
- const int32_t v_type_i = (int32_t)v_l[il]->type;
+ const int32_t v_type_i = (int32_t) v->type;
  if (v_type_i != v_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  return false;
@@ -2683,7 +1830,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
  // Read row size of value
  uint64_t v_size_row_ref;
  io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
- const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+ const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  if (v_size_row != v_size_row_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  return false;
@@ -2691,18 +1838,22 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

  if (cell_count) {
  // Read and set the values for the whole cell range
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  }
  }
  } else {
  // For each layer, read the values for each cell (transposed)
- for (uint32_t il = 0; il < n_layer; ++il) {
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+ auto * v = layer.v_stream[strm];

  // Read type of value
  int32_t v_type_i_ref;
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
- const int32_t v_type_i = (int32_t)v_l[il]->type;
+ const int32_t v_type_i = (int32_t) v->type;
  if (v_type_i != v_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  return false;
@@ -2711,7 +1862,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
  // Read element size of value
  uint32_t v_size_el_ref;
  io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
+ const size_t v_size_el = ggml_type_size(v->type);
  if (v_size_el != v_size_el_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  return false;
@@ -2728,8 +1879,8 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
  if (cell_count) {
  // For each row in the transposed matrix, read the values for the whole cell range
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- const size_t dst_offset = (head + j * size) * v_size_el;
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+ const size_t dst_offset = (head + j * cells.size()) * v_size_el;
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  }
  }
  }
@@ -2737,3 +1888,133 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

  return true;
  }
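When V is stored transposed, the code above addresses the cache as n_embd_v_gqa rows of kv_size elements, so cell c of row j lives at byte offset (c + j * kv_size) * v_size_el; this is exactly the src_offset written in state_write_data and the dst_offset restored in state_read_data. A tiny sketch of that addressing (v_trans_offset is illustrative, not a function from the sources):

#include <cstddef>
#include <cstdint>

// Sketch only: byte offset of cell `cell`, row `j` in a transposed V buffer.
static size_t v_trans_offset(uint32_t cell, uint32_t j, uint32_t kv_size, size_t v_size_el) {
    return (static_cast<size_t>(cell) + static_cast<size_t>(j) * kv_size) * v_size_el;
}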
+
+ //
+ // llama_kv_cache_context
+ //
+
+ llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
+
+ llama_kv_cache_context::llama_kv_cache_context(
+ llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+ n_kv = kv->get_size();
+
+ const uint32_t n_stream = kv->get_n_stream();
+
+ // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+ sinfos.resize(1);
+ sinfos[0].s0 = 0;
+ sinfos[0].s1 = n_stream - 1;
+ sinfos[0].idxs.resize(n_stream);
+ for (uint32_t s = 0; s < n_stream; ++s) {
+ sinfos[0].strm.push_back(s);
+ sinfos[0].idxs[s].resize(1, 0);
+ }
+ }
+
+ llama_kv_cache_context::llama_kv_cache_context(
+ llama_kv_cache * kv,
+ llama_context * lctx,
+ bool do_shift,
+ stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
+ if (!do_shift && this->sc_info.empty()) {
+ status = LLAMA_MEMORY_STATUS_NO_UPDATE;
+ }
+ }
+
+ llama_kv_cache_context::llama_kv_cache_context(
+ llama_kv_cache * kv,
+ llama_kv_cache::slot_info_vec_t sinfos,
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
+ }
+
+ llama_kv_cache_context::~llama_kv_cache_context() = default;
+
+ bool llama_kv_cache_context::next() {
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+ if (++i_cur >= ubatches.size()) {
+ return false;
+ }
+
+ return true;
+ }
+
+ bool llama_kv_cache_context::apply() {
+ assert(!llama_memory_status_is_fail(status));
+
+ // no ubatches -> this is a KV cache update
+ if (ubatches.empty()) {
+ kv->update(lctx, do_shift, sc_info);
+
+ return true;
+ }
+
+ kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
+ n_kv = kv->get_n_kv(sinfos[i_cur]);
+
+ return true;
+ }
+
+ llama_memory_status llama_kv_cache_context::get_status() const {
+ return status;
+ }
+
+ const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+ return ubatches[i_cur];
+ }
+
+ uint32_t llama_kv_cache_context::get_n_kv() const {
+ return n_kv;
+ }
+
+ ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
+ return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
+ }
+
+ ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
+ return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
+ }
+
+ ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+ return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+ }
+
+ ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+ return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+ }
+
+ ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ return kv->build_input_k_idxs(ctx, ubatch);
+ }
+
+ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ return kv->build_input_v_idxs(ctx, ubatch);
+ }
+
+ void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
+ kv->set_input_k_shift(dst);
+ }
+
+ void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+ kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+ }
+
+ void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+ kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+ }
+
+ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+ kv->set_input_kq_mask(dst, ubatch, causal_attn);
+ }
+
+ void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+ kv->set_input_pos_bucket(dst, ubatch);
+ }
+
+ uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
+ // the FA kernels require padding to avoid extra runtime boundary checks
+ return cparams.flash_attn ? 256u : 32u;
+ }
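The new llama_kv_cache_context exposes a small iteration protocol: check get_status(), then alternate apply() and get_ubatch() until next() returns false. A hedged usage sketch, assuming a context constructed with slot infos and ubatches (a decode pass); consume_kv_context and its callback are illustrative, not part of the sources:

#include "llama-kv-cache.h" // internal header; exact name assumed from these sources

// Sketch only: drive a kv-cache context the way the interface above suggests.
template <typename F>
static bool consume_kv_context(llama_kv_cache_context & mctx, F && process_ubatch) {
    if (mctx.get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false;
    }
    do {
        if (!mctx.apply()) {                // write the current ubatch into its slots
            return false;
        }
        process_ubatch(mctx.get_ubatch());  // e.g. build and evaluate the graph
    } while (mctx.next());                  // advance; false once all ubatches are consumed
    return true;
}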