whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
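The file-level diff reproduced below is data/ext/sources/examples/talk-llama/llama-context.cpp (entry 38 above; the listing is truncated at the end). Its theme is that direct llama_kv_cache access is replaced by a generic memory interface: the context asks the memory module for a context object (init_full / init_update / init_batch), applies it, and then iterates over its ubatches. The sketch that follows is a minimal, self-contained toy of that loop shape, added here for orientation only; the types and methods in it (Memory, MemoryContext, Status) are simplified stand-ins and not the actual llama.cpp classes.

// toy_memory_context.cpp -- illustrative sketch only; simplified stand-ins for the real llama.cpp types
#include <algorithm>
#include <cstdio>
#include <memory>
#include <vector>

enum class Status { Success, NoUpdate, FailedPrepare, FailedCompute };

struct Ubatch {
    std::vector<int> tokens;
};

// stand-in for the object returned by memory->init_batch() / init_full()
struct MemoryContext {
    std::vector<Ubatch> ubatches;
    size_t cur = 0;

    Status get_status() const { return Status::Success; }
    bool   apply()            { return true; }               // place the current ubatch in the cache
    const Ubatch & get_ubatch() const { return ubatches[cur]; }
    bool   next()             { return ++cur < ubatches.size(); }
};

struct Memory {
    // split a batch into ubatches of at most n_ubatch tokens
    std::unique_ptr<MemoryContext> init_batch(const std::vector<int> & batch, size_t n_ubatch) {
        auto mctx = std::make_unique<MemoryContext>();
        for (size_t i = 0; i < batch.size(); i += n_ubatch) {
            const size_t end = std::min(i + n_ubatch, batch.size());
            mctx->ubatches.push_back({ std::vector<int>(batch.begin() + i, batch.begin() + end) });
        }
        return mctx;
    }
};

int main() {
    Memory memory;
    const std::vector<int> batch = {1, 2, 3, 4, 5, 6, 7};

    auto mctx = memory.init_batch(batch, /*n_ubatch=*/3);
    if (!mctx || mctx->get_status() != Status::Success) {
        return 1;
    }

    // same shape as the new do { ... } while (mctx->next()) loop in llama_context::decode()
    do {
        if (!mctx->apply()) {
            return 1;
        }
        std::printf("processing ubatch of %zu tokens\n", mctx->get_ubatch().tokens.size());
    } while (mctx->next());

    return 0;
}

With this shape, retrying after a cache optimization (as decode() does on LLAMA_MEMORY_STATUS_FAILED_PREPARE in the diff below) only requires requesting a fresh context object from the memory module.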
@@ -1,14 +1,16 @@
  #include "llama-context.h"
 
  #include "llama-impl.h"
+ #include "llama-batch.h"
  #include "llama-io.h"
+ #include "llama-memory.h"
  #include "llama-mmap.h"
  #include "llama-model.h"
- #include "llama-kv-cache.h"
 
+ #include <cinttypes>
  #include <cstring>
+ #include <limits>
  #include <stdexcept>
- #include <cinttypes>
 
  //
  // llama_context
@@ -17,7 +19,8 @@
  llama_context::llama_context(
  const llama_model & model,
  llama_context_params params) :
- model(model) {
+ model(model),
+ balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
  t_start_us = model.t_start_us;
@@ -26,8 +29,8 @@ llama_context::llama_context(
  const auto & hparams = model.hparams;
 
  cparams.n_seq_max = std::max(1u, params.n_seq_max);
- if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
- throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+ if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+ throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
  }
 
  cparams.n_threads = params.n_threads;
@@ -122,6 +125,11 @@ llama_context::llama_context(
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
  }
 
+ if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
+ LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+ __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+ }
+
  if (!hparams.vocab_only) {
  // GPU backends
  for (auto * dev : model.devices) {
@@ -259,15 +267,9 @@ llama_context::llama_context(
 
  // reserve worst-case graph
  if (!hparams.vocab_only && memory) {
- const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+ const uint32_t n_seqs = cparams.n_seq_max;
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
- // restore later
- // TODO: something cleaner
- const auto n_outputs_save = n_outputs;
-
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
  int n_splits_pp = -1;
@@ -277,25 +279,18 @@ llama_context::llama_context(
  int n_nodes_tg = -1;
 
  // simulate full KV cache
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
- kv_self->set_full();
+ const auto mctx = memory->init_full();
+ if (!mctx) {
+ throw std::runtime_error("failed to initialize KV cache");
+ }
 
  cross.v_embd.clear();
 
  // reserve pp graph first so that buffers are only allocated once
  {
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
- // max number of outputs
- n_outputs = ubatch_pp.n_tokens;
-
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
- auto * gf = graph_init();
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
  throw std::runtime_error("failed to allocate compute pp buffers");
  }
 
@@ -305,16 +300,8 @@ llama_context::llama_context(
 
  // reserve with tg graph to get the number of splits and nodes
  {
- llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
- n_outputs = ubatch_tg.n_tokens;
-
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
-
- auto * gf = graph_init();
- graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
-
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ auto * gf = graph_reserve(1, 1, 1, mctx.get());
+ if (!gf) {
  throw std::runtime_error("failed to allocate compute tg buffers");
  }
 
@@ -324,22 +311,12 @@ llama_context::llama_context(
 
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
  {
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
- n_outputs = ubatch_pp.n_tokens;
-
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
- auto * gf = graph_init();
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
  throw std::runtime_error("failed to allocate compute pp buffers");
  }
  }
 
- n_outputs = n_outputs_save;
-
 
  for (size_t i = 0; i < backend_ptrs.size(); ++i) {
  ggml_backend_t backend = backend_ptrs[i];
@@ -443,46 +420,71 @@ uint32_t llama_context::n_threads_batch() const {
  return cparams.n_threads_batch;
  }
 
- llama_kv_cache * llama_context::get_kv_self() {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
- return kv_self;
+ llama_memory_t llama_context::get_memory() const {
+ return memory.get();
  }
 
- const llama_kv_cache * llama_context::get_kv_self() const {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
- return kv_self;
- }
-
- void llama_context::kv_self_update() {
- bool need_reserve = false;
+ // deprecated
+ void llama_context::kv_self_defrag_sched() {
+ if (!memory) {
+ return;
+ }
 
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ memory_force_optimize = true;
+ }
 
- need_reserve = kv_self->update(*this);
+ // deprecated
+ bool llama_context::kv_self_update(bool optimize) {
+ if (!memory) {
+ return false;
+ }
 
- // reserve a worst case graph if needed
- if (need_reserve) {
- LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+ {
+ // TODO: remove in the future
+ optimize |= memory_force_optimize;
+ memory_force_optimize = false;
 
- // build worst-case graph
- uint32_t n_seqs = 1; // TODO: worst-case number of sequences
- uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+ const auto mctx = memory->init_update(this, optimize);
+ switch (mctx->get_status()) {
+ case LLAMA_MEMORY_STATUS_SUCCESS:
+ {
+ // noop
+ } break;
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
+ {
+ // no updates need to be performed
+ return false;
+ }
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+ {
+ LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+ return false;
+ }
+ }
 
- // simulate full KV cache
- kv_self->set_full();
+ if (!mctx->apply()) {
+ LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+ }
+ }
 
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+ // if the memory module did any computation, we have to reserve a new worst-case graph
+ {
+ const auto mctx = memory->init_full();
+ if (!mctx) {
+ throw std::runtime_error("failed to initialize memory context");
+ }
 
- auto * gf = graph_init();
- graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+ const uint32_t n_seqs = cparams.n_seq_max;
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
- // initialize scheduler with the worst-case graph
- ggml_backend_sched_reset(sched.get());
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
- LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
  }
  }
+
+ return true;
  }
 
  enum llama_pooling_type llama_context::pooling_type() const {
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
  }
 
  float * llama_context::get_logits_ith(int32_t i) {
- int32_t j = -1;
+ int64_t j = -1;
 
  try {
  if (logits == nullptr) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
  }
  if (j >= n_outputs) {
  // This should not happen
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
  }
 
  return logits + j*model.vocab.n_tokens();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
  }
 
  float * llama_context::get_embeddings_ith(int32_t i) {
- int32_t j = -1;
+ int64_t j = -1;
 
  try {
  if (embd == nullptr) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
  }
  if (j >= n_outputs) {
  // This should not happen
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
  }
 
  return embd + j*model.hparams.n_embd;
@@ -676,69 +678,95 @@ bool llama_context::apply_adapter_cvec(
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
  }
 
- int llama_context::encode(llama_batch & inp_batch) {
- if (inp_batch.n_tokens == 0) {
- LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
- return -1;
+ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+ if (mctx && !mctx->apply()) {
+ LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+ ret = GGML_STATUS_FAILED;
+ return nullptr;
  }
 
- // temporary allocate memory for the input batch if needed
- // note: during encode, we always pass the full sequence starting from pos = 0
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
+ auto * gf = graph_init();
+ if (!gf) {
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+ ret = GGML_STATUS_FAILED;
+ return nullptr;
+ }
 
- const llama_batch & batch = batch_allocr.batch;
- const int32_t n_tokens = batch.n_tokens;
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
+ if (!res) {
+ LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+ ret = GGML_STATUS_FAILED;
+ return nullptr;
+ }
 
- const auto & hparams = model.hparams;
+ // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+ ret = GGML_STATUS_ALLOC_FAILED;
+ return nullptr;
+ }
 
- // TODO: move the validation to the llama_batch_allocr
- if (batch.token) {
- for (int32_t i = 0; i < n_tokens; ++i) {
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
- return -1;
- }
+ res->set_inputs(&ubatch);
 
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
- throw -1;
- }
- }
+ const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+ if (status != GGML_STATUS_SUCCESS) {
+ LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+ ret = status;
+ return nullptr;
  }
 
+ ret = GGML_STATUS_SUCCESS;
+
+ return res;
+ }
+
+ int llama_context::encode(const llama_batch & batch_inp) {
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
+ if (batch_inp.n_tokens == 0) {
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+ return -1;
+ }
+
+ const auto & hparams = model.hparams;
+
+ const int64_t n_embd = hparams.n_embd;
+
+ // note: during encode, we always pass the full sequence starting from pos = 0
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return -1;
+ }
+
+ const uint32_t n_tokens = balloc->get_n_tokens();
+
+ const llama_ubatch ubatch = balloc->split_simple(n_tokens);
+
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
- GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
+ GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
  if (t_compute_start_us == 0) {
  t_compute_start_us = ggml_time_us();
  }
 
+ // TODO: this clear of the buffer can easily be forgotten - need something better
  embd_seq.clear();
 
  n_queued_tokens += n_tokens;
 
- const int64_t n_embd = hparams.n_embd;
-
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-
- const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
  // reserve output buffer
  if (output_reserve(n_tokens) < n_tokens) {
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
  return -2;
  };
 
- for (int32_t i = 0; i < n_tokens; ++i) {
+ for (uint32_t i = 0; i < n_tokens; ++i) {
  output_ids[i] = i;
  }
 
  n_outputs = n_tokens;
 
- //batch_manager->prepare(ubatch);
-
  ggml_backend_sched_reset(sched.get());
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -749,26 +777,18 @@ int llama_context::encode(llama_batch & inp_batch) {
  // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
  cparams.causal_attn = false;
 
- auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
-
- ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(&ubatch);
+ ggml_status status;
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
  cparams.causal_attn = causal_attn_org;
 
- const auto compute_status = graph_compute(gf, n_tokens > 1);
- switch (compute_status) {
- case GGML_STATUS_SUCCESS:
- break;
- case GGML_STATUS_ABORTED:
- return 2;
- case GGML_STATUS_ALLOC_FAILED:
- return -2;
- case GGML_STATUS_FAILED:
- default:
- return -3;
+ if (!res) {
+ switch (status) {
+ case GGML_STATUS_ABORTED: return 2;
+ case GGML_STATUS_ALLOC_FAILED: return -2;
+ case GGML_STATUS_FAILED: return -3;
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
+ }
  }
 
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -793,31 +813,28 @@ int llama_context::encode(llama_batch & inp_batch) {
  {
  // extract sequence embeddings
  auto & embd_seq_out = embd_seq;
- embd_seq_out.clear();
 
- GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
 
- for (int32_t i = 0; i < n_tokens; i++) {
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
  embd_seq_out[seq_id].resize(n_embd);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
  }
  } break;
  case LLAMA_POOLING_TYPE_RANK:
  {
- // extract the rerank score - a single float per sequence
+ // extract the rerank score - n_cls_out floats per sequence
  auto & embd_seq_out = embd_seq;
 
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
- embd_seq_out[seq_id].resize(1);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ const uint32_t n_cls_out = hparams.n_cls_out;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+ embd_seq_out[seq_id].resize(n_cls_out);
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
  }
  } break;
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -842,12 +859,16 @@ int llama_context::encode(llama_batch & inp_batch) {
  cross.v_embd.resize(cross.n_embd*cross.n_enc);
  memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
 
+ const auto & batch = balloc->get_batch();
+
  // remember the sequence ids used during the encoding - needed for cross attention later
  cross.seq_ids_enc.resize(n_tokens);
- for (int32_t i = 0; i < n_tokens; i++) {
+ for (uint32_t i = 0; i < n_tokens; i++) {
  cross.seq_ids_enc[i].clear();
- for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
- llama_seq_id seq_id = ubatch.seq_id[i][s];
+
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
+ const llama_seq_id seq_id = batch.seq_id[i][s];
+
  cross.seq_ids_enc[i].insert(seq_id);
  }
  }
@@ -856,55 +877,42 @@ int llama_context::encode(llama_batch & inp_batch) {
  return 0;
  }
 
- int llama_context::decode(llama_batch & inp_batch) {
+ int llama_context::decode(const llama_batch & batch_inp) {
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
  if (!memory) {
  LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
- return encode(inp_batch);
+ return encode(batch_inp);
  }
 
- if (inp_batch.n_tokens == 0) {
+ if (batch_inp.n_tokens == 0) {
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
  return -1;
  }
 
- if (!inp_batch.pos) {
- if (inp_batch.seq_id) {
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
- return -1;
- }
- }
-
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- // temporary allocate memory for the input batch if needed
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
-
- const llama_batch & batch = batch_allocr.batch;
-
  const auto & vocab = model.vocab;
  const auto & hparams = model.hparams;
 
  const int32_t n_vocab = vocab.n_tokens();
+ const int64_t n_embd = hparams.n_embd;
 
- const int64_t n_tokens_all = batch.n_tokens;
- const int64_t n_embd = hparams.n_embd;
+ // when computing embeddings, all tokens are output
+ const bool output_all = cparams.embeddings;
 
- llama_kv_cache_guard kv_guard(kv_self);
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return -1;
+ }
 
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
+ const uint32_t n_outputs_all = balloc->get_n_outputs();
 
- // TODO: move the validation to the llama_batch_allocr
- if (batch.token) {
- for (int64_t i = 0; i < n_tokens_all; ++i) {
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
- LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
- return -1;
- }
-
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
- LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
- return -1;
- }
+ if (output_all) {
+ // require that all tokens are output
+ if (n_outputs_all != n_tokens_all) {
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
+ __func__, n_outputs_all, n_tokens_all);
+ return -1;
  }
  }
 
@@ -917,49 +925,77 @@ int llama_context::decode(llama_batch & inp_batch) {
917
925
  }
918
926
  n_queued_tokens += n_tokens_all;
919
927
 
920
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
921
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
922
-
928
+ // TODO: this clear of the buffer can easily be forgotten - need something better
923
929
  embd_seq.clear();
924
930
 
925
- int64_t n_outputs_all = 0;
931
+ bool did_optimize = false;
932
+
933
+ // handle any pending defrags/shifts
934
+ kv_self_update(false);
926
935
 
927
- // count outputs
928
- if (batch.logits && !embd_pooled) {
929
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
930
- n_outputs_all += batch.logits[i] != 0;
936
+ llama_memory_context_ptr mctx;
937
+
938
+ while (true) {
939
+ mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
940
+ if (!mctx) {
941
+ return -2;
931
942
  }
932
- } else if (embd_pooled) {
933
- n_outputs_all = n_tokens_all;
934
- } else {
935
- // keep last output only
936
- n_outputs_all = 1;
937
- }
938
943
 
939
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
944
+ switch (mctx->get_status()) {
945
+ case LLAMA_MEMORY_STATUS_SUCCESS:
946
+ {
947
+ } break;
948
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
949
+ {
950
+ LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
951
+
952
+ return -2;
953
+ }
954
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
955
+ {
956
+ if (!did_optimize) {
957
+ did_optimize = true;
958
+
959
+ if (kv_self_update(true)) {
960
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
961
+
962
+ continue;
963
+ }
964
+ }
965
+
966
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
967
+
968
+ return 1;
969
+ }
970
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
971
+ {
972
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
973
+
974
+ return -2;
975
+ }
976
+ }
977
+
978
+ break;
979
+ }
940
980
 
941
981
  // reserve output buffer
942
982
  if (output_reserve(n_outputs_all) < n_outputs_all) {
943
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
983
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
944
984
  return -2;
945
985
  };
946
986
 
947
- // handle any pending defrags/shifts
948
- kv_self_update();
949
-
950
987
  int64_t n_outputs_prev = 0;
951
988
 
952
- while (sbatch.n_tokens > 0) {
953
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
989
+ do {
990
+ const auto & ubatch = mctx->get_ubatch();
954
991
 
955
- // count the outputs in this u_batch
992
+ // count the outputs in this ubatch
956
993
  {
957
994
  int32_t n_outputs_new = 0;
958
995
 
959
996
  if (n_outputs_all == n_tokens_all) {
960
997
  n_outputs_new = ubatch.n_tokens;
961
998
  } else {
962
- GGML_ASSERT(ubatch.output);
963
999
  for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
964
1000
  n_outputs_new += (int32_t) (ubatch.output[i] != 0);
965
1001
  }
@@ -969,33 +1005,40 @@ int llama_context::decode(llama_batch & inp_batch) {
969
1005
  n_outputs = n_outputs_new;
970
1006
  }
971
1007
 
972
- // find KV slot
973
- if (!kv_self->find_slot(ubatch)) {
974
- return 1;
975
- }
976
-
977
1008
  ggml_backend_sched_reset(sched.get());
978
1009
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
979
1010
 
980
- auto * gf = graph_init();
981
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
1011
+ ggml_status status;
1012
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
982
1013
 
983
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1014
+ if (!res) {
1015
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1016
+ llama_pos pos_min[LLAMA_MAX_SEQ];
1017
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1018
+ pos_min[s] = std::numeric_limits<llama_pos>::max();
1019
+ }
984
1020
 
985
- ggml_backend_sched_alloc_graph(sched.get(), gf);
1021
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1022
+ const auto & seq_id = ubatch.seq_id[i][0];
986
1023
 
987
- res->set_inputs(&ubatch);
1024
+ pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
1025
+ }
988
1026
 
989
- const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
990
- if (compute_status != GGML_STATUS_SUCCESS) {
991
- switch (compute_status) {
992
- case GGML_STATUS_ABORTED:
993
- return 2;
994
- case GGML_STATUS_ALLOC_FAILED:
995
- return -2;
996
- case GGML_STATUS_FAILED:
997
- default:
998
- return -3;
1027
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1028
+ if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
1029
+ continue;
1030
+ }
1031
+
1032
+ LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1033
+
1034
+ memory->seq_rm(s, pos_min[s], -1);
1035
+ }
1036
+
1037
+ switch (status) {
1038
+ case GGML_STATUS_ABORTED: return 2;
1039
+ case GGML_STATUS_ALLOC_FAILED: return -2;
1040
+ case GGML_STATUS_FAILED: return -3;
1041
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
999
1042
  }
1000
1043
  }
1001
1044
 
@@ -1004,7 +1047,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1004
1047
  // ggml_graph_dump_dot(gf, NULL, "llama.dot");
1005
1048
  //}
1006
1049
 
1007
- auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
1050
+ auto * t_logits = res->get_logits();
1008
1051
  auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1009
1052
 
1010
1053
  if (t_embd && res->get_embd_pooled()) {
@@ -1051,27 +1094,27 @@ int llama_context::decode(llama_batch & inp_batch) {
1051
1094
  // extract sequence embeddings (cleared before processing each batch)
1052
1095
  auto & embd_seq_out = embd_seq;
1053
1096
 
1054
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1055
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
1056
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1057
- continue;
1058
- }
1097
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1098
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1099
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1100
+
1059
1101
  embd_seq_out[seq_id].resize(n_embd);
1060
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1102
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
1061
1103
  }
1062
1104
  } break;
1063
1105
  case LLAMA_POOLING_TYPE_RANK:
1064
1106
  {
1065
- // extract the rerank score - a single float per sequence
1107
+ // extract the rerank score - n_cls_out floats per sequence
1066
1108
  auto & embd_seq_out = embd_seq;
1067
1109
 
1068
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1069
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
1070
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1071
- continue;
1072
- }
1073
- embd_seq_out[seq_id].resize(1);
1074
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
1110
+ const uint32_t n_cls_out = hparams.n_cls_out;
1111
+
1112
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1113
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1114
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1115
+
1116
+ embd_seq_out[seq_id].resize(n_cls_out);
1117
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
1075
1118
  }
1076
1119
  } break;
1077
1120
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1082,23 +1125,20 @@ int llama_context::decode(llama_batch & inp_batch) {
1082
1125
  }
1083
1126
 
1084
1127
  n_outputs_prev += n_outputs;
1085
- }
1086
-
1087
- // finalize the batch processing
1088
- kv_guard.commit();
1128
+ } while (mctx->next());
1089
1129
 
1090
1130
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1091
1131
  n_outputs = n_outputs_all;
1092
1132
 
1093
1133
  // set output mappings
1094
- {
1134
+ if (n_outputs > 0) {
1095
1135
  bool sorted_output = true;
1096
1136
 
1097
- auto & out_ids = sbatch.out_ids;
1137
+ auto & out_ids = balloc->get_out_ids();
1098
1138
 
1099
- GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1139
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1100
1140
 
1101
- for (int64_t i = 0; i < n_outputs_all; ++i) {
1141
+ for (int64_t i = 0; i < n_outputs; ++i) {
1102
1142
  int64_t out_id = out_ids[i];
1103
1143
  output_ids[out_id] = i;
1104
1144
  if (out_id != i) {
@@ -1110,20 +1150,22 @@ int llama_context::decode(llama_batch & inp_batch) {
1110
1150
  // note: this is mostly relevant for recurrent models atm
1111
1151
  if (!sorted_output) {
1112
1152
  const uint32_t n_vocab = model.vocab.n_tokens();
1113
- const uint32_t n_embd = model.hparams.n_embd;
1153
+ const uint64_t n_embd = model.hparams.n_embd;
1114
1154
 
1115
1155
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1116
1156
 
1117
1157
  // TODO: is there something more efficient which also minimizes swaps?
1118
1158
  // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1119
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
1120
- int32_t j_min = i;
1121
- for (int32_t j = i + 1; j < n_outputs; ++j) {
1159
+ for (uint32_t i = 0; i < n_outputs - 1; ++i) {
1160
+ uint32_t j_min = i;
1161
+ for (uint32_t j = i + 1; j < n_outputs; ++j) {
1122
1162
  if (out_ids[j] < out_ids[j_min]) {
1123
1163
  j_min = j;
1124
1164
  }
1125
1165
  }
1126
- if (j_min == i) { continue; }
1166
+ if (j_min == i) {
1167
+ continue;
1168
+ }
1127
1169
  std::swap(out_ids[i], out_ids[j_min]);
1128
1170
  if (logits_size > 0) {
1129
1171
  for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1136,8 +1178,10 @@ int llama_context::decode(llama_batch & inp_batch) {
1136
1178
  }
1137
1179
  }
1138
1180
  }
1181
+
1139
1182
  std::fill(output_ids.begin(), output_ids.end(), -1);
1140
- for (int32_t i = 0; i < n_outputs; ++i) {
1183
+
1184
+ for (uint32_t i = 0; i < n_outputs; ++i) {
1141
1185
  output_ids[out_ids[i]] = i;
1142
1186
  }
1143
1187
  }
@@ -1146,11 +1190,6 @@ int llama_context::decode(llama_batch & inp_batch) {
1146
1190
  // wait for the computation to finish (automatically done when obtaining the model output)
1147
1191
  //synchronize();
1148
1192
 
1149
- // decide if we need to defrag the kv cache
1150
- if (cparams.defrag_thold > 0.0f) {
1151
- kv_self->defrag_sched(cparams.defrag_thold);
1152
- }
1153
-
1154
1193
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1155
1194
  // overlap with device computation.
1156
1195
  ggml_backend_sched_reset(sched.get());
@@ -1162,7 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1162
1201
  // output
1163
1202
  //
1164
1203
 
1165
- int32_t llama_context::output_reserve(int32_t n_outputs) {
1204
+ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1166
1205
  const auto & hparams = model.hparams;
1167
1206
  const auto & vocab = model.vocab;
1168
1207
 
@@ -1172,9 +1211,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1172
1211
  const auto n_vocab = vocab.n_tokens();
1173
1212
  const auto n_embd = hparams.n_embd;
1174
1213
 
1175
- // TODO: use a per-batch flag for logits presence instead
1176
- bool has_logits = !cparams.embeddings;
1177
- bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1214
+ bool has_logits = true;
1215
+ bool has_embd = cparams.embeddings;
1178
1216
 
1179
1217
  // TODO: hacky enc-dec support
1180
1218
  if (model.arch == LLM_ARCH_T5) {
@@ -1228,8 +1266,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1228
1266
  // set all ids as invalid (negative)
1229
1267
  std::fill(output_ids.begin(), output_ids.end(), -1);
1230
1268
 
1231
- this->n_outputs = 0;
1232
- this->n_outputs_max = n_outputs_max;
1269
+ this->n_outputs = 0;
1233
1270
 
1234
1271
  return n_outputs_max;
1235
1272
  }
@@ -1254,11 +1291,52 @@ ggml_cgraph * llama_context::graph_init() {
1254
1291
  return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1255
1292
  }
1256
1293
 
1294
+ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
1295
+ LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1296
+
1297
+ if (n_tokens % n_seqs != 0) {
1298
+ n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1299
+ n_outputs = std::min(n_outputs, n_tokens);
1300
+
1301
+ LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1302
+ }
1303
+
1304
+ // store the n_outputs as it is, and restore it afterwards
1305
+ // TODO: not sure if needed, might simplify in the future by removing this
1306
+ const auto save_n_outputs = this->n_outputs;
1307
+
1308
+ this->n_outputs = n_outputs;
1309
+
1310
+ llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1311
+ llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1312
+
1313
+ auto * gf = graph_init();
1314
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
1315
+
1316
+ this->n_outputs = save_n_outputs;
1317
+
1318
+ if (!res) {
1319
+ LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
1320
+ return nullptr;
1321
+ }
1322
+
1323
+ ggml_backend_sched_reset(sched.get());
1324
+
1325
+ // initialize scheduler with the specified graph
1326
+ if (!ggml_backend_sched_reserve(sched.get(), gf)) {
1327
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
1328
+ return nullptr;
1329
+ }
1330
+
1331
+ return gf;
1332
+ }
1333
+
1257
1334
  llm_graph_result_ptr llama_context::graph_build(
1258
- ggml_context * ctx,
1259
- ggml_cgraph * gf,
1260
- const llama_ubatch & ubatch,
1261
- llm_graph_type gtype) {
1335
+ ggml_context * ctx,
1336
+ ggml_cgraph * gf,
1337
+ const llama_ubatch & ubatch,
1338
+ llm_graph_type gtype,
1339
+ const llama_memory_context_i * mctx) {
1262
1340
  return model.build_graph(
1263
1341
  {
1264
1342
  /*.ctx =*/ ctx,
@@ -1270,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
1270
1348
  /*.backend_cpu =*/ backend_cpu,
1271
1349
  /*.cvec =*/ &cvec,
1272
1350
  /*.loras =*/ &loras,
1273
- /*.memory =*/ memory.get(),
1351
+ /*.mctx =*/ mctx,
1274
1352
  /*.cross =*/ &cross,
1275
1353
  /*.n_outputs =*/ n_outputs,
1276
1354
  /*.cb =*/ graph_get_cb(),
@@ -1679,14 +1757,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1679
1757
 
1680
1758
  std::vector<int32_t> w_output_pos;
1681
1759
 
1682
- GGML_ASSERT(n_outputs <= n_outputs_max);
1683
-
1684
1760
  w_output_pos.resize(n_outputs);
1685
1761
 
1686
1762
  // build a more compact representation of the output ids
1687
1763
  for (size_t i = 0; i < n_batch(); ++i) {
1688
1764
  // map an output id to a position in the batch
1689
- int32_t pos = output_ids[i];
1765
+ int64_t pos = output_ids[i];
1690
1766
  if (pos >= 0) {
1691
1767
  GGML_ASSERT(pos < n_outputs);
1692
1768
  w_output_pos[pos] = i;
@@ -1726,11 +1802,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1726
1802
  }
1727
1803
  }
1728
1804
 
1729
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1730
-
1731
- if (kv_self != nullptr) {
1805
+ if (memory != nullptr) {
1732
1806
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1733
- kv_self->state_write(io);
1807
+ memory->state_write(io);
1734
1808
  }
1735
1809
 
1736
1810
  return io.n_bytes();
@@ -1817,9 +1891,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1817
1891
  if (memory) {
1818
1892
  LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
1819
1893
 
1820
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1821
-
1822
- kv_self->state_read(io);
1894
+ memory->state_read(io);
1823
1895
  }
1824
1896
 
1825
1897
  return io.n_bytes();
@@ -1829,9 +1901,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
1829
1901
  GGML_UNUSED(seq_id);
1830
1902
 
1831
1903
  if (memory) {
1832
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1833
-
1834
- kv_self->state_write(io, seq_id);
1904
+ memory->state_write(io, seq_id);
1835
1905
  }
1836
1906
 
1837
1907
  return io.n_bytes();
@@ -1841,9 +1911,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
1841
1911
  GGML_UNUSED(seq_id);
1842
1912
 
1843
1913
  if (memory) {
1844
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1845
-
1846
- kv_self->state_read(io, seq_id);
1914
+ memory->state_read(io, seq_id);
1847
1915
  }
1848
1916
 
1849
1917
  return io.n_bytes();
@@ -1948,10 +2016,7 @@ void llama_context::opt_epoch_iter(
1948
2016
  const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
1949
2017
  const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
1950
2018
 
1951
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1952
-
1953
- kv_self->clear();
1954
- llama_kv_cache_guard kv_guard(kv_self);
2019
+ memory->clear(true);
1955
2020
 
1956
2021
  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
1957
2022
  batch.n_tokens = n_batch;
@@ -1963,39 +2028,44 @@ void llama_context::opt_epoch_iter(
  batch.logits [pos_batch] = true;
  }
 
- const auto n_tokens_all = batch.n_tokens;
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return;
+ }
+
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
 
  n_queued_tokens += n_tokens_all;
 
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
  embd_seq.clear();
 
- int64_t n_outputs_all = n_tokens_all;
+ uint32_t n_outputs_all = n_tokens_all;
 
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+ auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+ if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+ LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+ break;
+ }
 
  // reserve output buffer
  if (output_reserve(n_outputs_all) < n_outputs_all) {
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
  GGML_ABORT("TODO: handle this error");
  };
 
- for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+ uint32_t pos_batch = 0;
+ do {
+ const auto & ubatch = mctx->get_ubatch();
 
  n_outputs = ubatch.n_tokens;
 
- // TODO: not sure if this is needed
- if (!kv_self->find_slot(ubatch)) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
- GGML_ABORT("TODO: handle this error");
+ if (!mctx->apply()) {
+ LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
+ break;
  }
 
  auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
  struct ggml_context * ctx_compute_opt;
  {
@@ -2010,6 +2080,7 @@ void llama_context::opt_epoch_iter(
  }
  ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
  ggml_opt_alloc(opt_ctx, train);
+
  res->set_inputs(&ubatch);
  {
  struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2027,10 +2098,10 @@ void llama_context::opt_epoch_iter(
  callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
  }
  ggml_free(ctx_compute_opt);
- }
- }
 
- kv_guard.commit();
+ pos_batch += ubatch.n_tokens;
+ } while (mctx->next());
+ }
  }
 
  void llama_context::opt_epoch(
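Because the opt_epoch_iter changes are spread over four hunks, a condensed view of the new control flow may help. This is only a restatement of the code above (error handling and the ggml_opt training calls trimmed), not additional behavior:

// Condensed from the hunks above: the memory module now hands out a memory
// context that drives ubatch iteration, replacing the old
// kv_self->sbatch_init()/ubatch_next()/find_slot() sequence.
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
    return;                                        // could not build the batch
}

auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    return;                                        // memory module rejected the batch
}

do {
    const auto & ubatch = mctx->get_ubatch();      // current micro-batch
    if (!mctx->apply()) {                          // update the memory for this ubatch
        break;
    }
    // ... build and run the graph for `ubatch` ...
} while (mctx->next());                            // advance to the next ubatch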
@@ -2190,12 +2261,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
  return &ctx->get_model();
  }
 
+ // deprecated
  llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
- return ctx->get_kv_self();
+ return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
  }
 
+ // deprecated
  void llama_kv_self_update(llama_context * ctx) {
- ctx->kv_self_update();
+ ctx->kv_self_update(false);
  }
 
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2310,13 +2383,118 @@ int32_t llama_apply_adapter_cvec(
  return res ? 0 : -1;
  }
 
+ //
+ // memory
+ //
+
+ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+ return ctx->get_memory();
+ }
+
+ void llama_memory_clear(llama_memory_t mem, bool data) {
+ if (!mem) {
+ return;
+ }
+
+ mem->clear(data);
+ }
+
+ bool llama_memory_seq_rm(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1) {
+ if (!mem) {
+ return true;
+ }
+
+ return mem->seq_rm(seq_id, p0, p1);
+ }
+
+ void llama_memory_seq_cp(
+ llama_memory_t mem,
+ llama_seq_id seq_id_src,
+ llama_seq_id seq_id_dst,
+ llama_pos p0,
+ llama_pos p1) {
+ if (!mem) {
+ return;
+ }
+
+ mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ }
+
+ void llama_memory_seq_keep(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
+ return;
+ }
+
+ mem->seq_keep(seq_id);
+ }
+
+ void llama_memory_seq_add(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ llama_pos delta) {
+ if (!mem) {
+ return;
+ }
+
+ mem->seq_add(seq_id, p0, p1, delta);
+ }
+
+ void llama_memory_seq_div(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ int d) {
+ if (!mem) {
+ return;
+ }
+
+ mem->seq_div(seq_id, p0, p1, d);
+ }
+
+ llama_pos llama_memory_seq_pos_min(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
+ return -1;
+ }
+
+ return mem->seq_pos_min(seq_id);
+ }
+
+ llama_pos llama_memory_seq_pos_max(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
+ return -1;
+ }
+
+ return mem->seq_pos_max(seq_id);
+ }
+
+ bool llama_memory_can_shift(llama_memory_t mem) {
+ if (!mem) {
+ return false;
+ }
+
+ return mem->get_can_shift();
+ }
+
  //
  // kv cache
  //
 
  // deprecated
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ const auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return 0;
  }
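As the deprecated wrappers further down show, the old llama_kv_self_* helpers now simply forward to this llama_memory_* API on llama_get_memory(ctx). A minimal caller-side migration sketch (hypothetical usage, not part of this diff; the ctx, src, and dst values are assumed to come from the caller):

#include "llama.h"

// Trim one sequence, fork it into another, then wipe the memory, using the new API.
static void example_memory_ops(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
    llama_memory_t mem = llama_get_memory(ctx);

    // old: llama_kv_self_seq_rm(ctx, src, 32, -1);
    llama_memory_seq_rm(mem, src, /*p0 =*/ 32, /*p1 =*/ -1);       // drop positions >= 32

    // old: llama_kv_self_seq_cp(ctx, src, dst, -1, -1);
    llama_memory_seq_cp(mem, src, dst, /*p0 =*/ -1, /*p1 =*/ -1);  // copy the whole sequence

    // old: llama_kv_self_clear(ctx);
    llama_memory_clear(mem, /*data =*/ true);                      // data = true matches the old wrapper
}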
@@ -2338,7 +2516,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
  // deprecated
  // note: this is the same as above - will be removed anyway, so it's ok
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ const auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return 0;
  }
@@ -2357,114 +2535,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
  return res;
  }
 
+ // deprecated
  void llama_kv_self_clear(llama_context * ctx) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }
 
- kv->clear();
+ llama_memory_clear(kv, true);
  }
 
+ // deprecated
  bool llama_kv_self_seq_rm(
  llama_context * ctx,
  llama_seq_id seq_id,
  llama_pos p0,
  llama_pos p1) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return true;
  }
 
- return kv->seq_rm(seq_id, p0, p1);
+ return llama_memory_seq_rm(kv, seq_id, p0, p1);
  }
 
+ // deprecated
  void llama_kv_self_seq_cp(
  llama_context * ctx,
  llama_seq_id seq_id_src,
  llama_seq_id seq_id_dst,
  llama_pos p0,
  llama_pos p1) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }
 
- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
  }
 
+ // deprecated
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }
 
- kv->seq_keep(seq_id);
+ llama_memory_seq_keep(kv, seq_id);
  }
 
+ // deprecated
  void llama_kv_self_seq_add(
  llama_context * ctx,
  llama_seq_id seq_id,
  llama_pos p0,
  llama_pos p1,
  llama_pos delta) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }
 
- kv->seq_add(seq_id, p0, p1, delta);
+ llama_memory_seq_add(kv, seq_id, p0, p1, delta);
  }
 
+ // deprecated
  void llama_kv_self_seq_div(
  llama_context * ctx,
  llama_seq_id seq_id,
  llama_pos p0,
  llama_pos p1,
  int d) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }
 
- kv->seq_div(seq_id, p0, p1, d);
+ llama_memory_seq_div(kv, seq_id, p0, p1, d);
  }
 
+ // deprecated
  llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return -1;
  }
 
- return kv->seq_pos_min(seq_id);
+ return llama_memory_seq_pos_min(kv, seq_id);
  }
 
+ // deprecated
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return -1;
  }
 
- return kv->seq_pos_max(seq_id);
+ return llama_memory_seq_pos_max(kv, seq_id);
  }
 
+ // deprecated
  void llama_kv_self_defrag(llama_context * ctx) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
- return;
- }
-
  // force defrag
- kv->defrag_sched(-1.0f);
+ ctx->kv_self_defrag_sched();
  }
 
+ // deprecated
  bool llama_kv_self_can_shift(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return false;
  }
 
- return kv->get_can_shift();
+ return llama_memory_can_shift(kv);
  }
 
  // llama state API
@@ -2589,22 +2772,8 @@ int32_t llama_encode(
  int32_t llama_decode(
  llama_context * ctx,
  llama_batch batch) {
- int ret = ctx->decode(batch);
-
- // defrag and try again
- // TODO: distinguish return code when we are sure that even after defrag there is no space available
- if (ret == 1) {
- llama_kv_self_defrag(ctx);
- ret = ctx->decode(batch);
-
- if (ret == 1) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
- return ret;
- }
- }
-
- if (ret != 0) {
+ const int ret = ctx->decode(batch);
+ if (ret != 0 && ret != 1) {
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
  }
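With the implicit defrag-and-retry removed, llama_decode() now returns 1 to the caller when no memory slot is found for the batch (as the removed warning indicated), instead of retrying internally. A hypothetical caller-side sketch of one possible retry policy (not part of this diff; whether and how to free space is entirely up to the application):

#include "llama.h"

// Decode once; on "no space" (ret == 1), free one sequence and retry a single time.
static bool decode_with_retry(llama_context * ctx, llama_batch batch) {
    int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // example recovery: drop sequence 0 entirely, then retry
        llama_memory_seq_rm(llama_get_memory(ctx), /*seq_id =*/ 0, /*p0 =*/ -1, /*p1 =*/ -1);
        ret = llama_decode(ctx, batch);
    }
    return ret == 0; // 1 = still no space, < 0 = hard error
}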