@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -1,10 +1,11 @@
1
1
  #include "llama-context.h"
2
2
 
3
3
  #include "llama-impl.h"
4
+ #include "llama-batch.h"
4
5
  #include "llama-io.h"
6
+ #include "llama-memory.h"
5
7
  #include "llama-mmap.h"
6
8
  #include "llama-model.h"
7
- #include "llama-kv-cache.h"
8
9
 
9
10
  #include <cinttypes>
10
11
  #include <cstring>
@@ -18,7 +19,8 @@
18
19
  llama_context::llama_context(
19
20
  const llama_model & model,
20
21
  llama_context_params params) :
21
- model(model) {
22
+ model(model),
23
+ batch_allocr(std::make_unique<llama_batch_allocr>()) {
22
24
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
23
25
 
24
26
  t_start_us = model.t_start_us;
@@ -27,8 +29,8 @@ llama_context::llama_context(
27
29
  const auto & hparams = model.hparams;
28
30
 
29
31
  cparams.n_seq_max = std::max(1u, params.n_seq_max);
30
- if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
31
- throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
32
+ if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
33
+ throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
32
34
  }
33
35
 
34
36
  cparams.n_threads = params.n_threads;
@@ -123,7 +125,7 @@ llama_context::llama_context(
123
125
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
124
126
  }
125
127
 
126
- if (!params.swa_full && cparams.n_seq_max > 1) {
128
+ if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
127
129
  LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
128
130
  __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
129
131
  }
@@ -277,10 +279,9 @@ llama_context::llama_context(
277
279
  int n_nodes_tg = -1;
278
280
 
279
281
  // simulate full KV cache
280
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
281
282
 
282
- const auto kv_state = kv_self->init_full();
283
- if (!kv_state) {
283
+ const auto mstate = memory->init_full();
284
+ if (!mstate) {
284
285
  throw std::runtime_error("failed to initialize KV cache");
285
286
  }
286
287
 
@@ -288,7 +289,7 @@ llama_context::llama_context(
288
289
 
289
290
  // reserve pp graph first so that buffers are only allocated once
290
291
  {
291
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
292
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
292
293
  if (!gf) {
293
294
  throw std::runtime_error("failed to allocate compute pp buffers");
294
295
  }
@@ -299,7 +300,7 @@ llama_context::llama_context(
299
300
 
300
301
  // reserve with tg graph to get the number of splits and nodes
301
302
  {
302
- auto * gf = graph_reserve(1, 1, 1, kv_state.get());
303
+ auto * gf = graph_reserve(1, 1, 1, mstate.get());
303
304
  if (!gf) {
304
305
  throw std::runtime_error("failed to allocate compute tg buffers");
305
306
  }
@@ -310,7 +311,7 @@ llama_context::llama_context(
310
311
 
311
312
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
312
313
  {
313
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
314
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
314
315
  if (!gf) {
315
316
  throw std::runtime_error("failed to allocate compute pp buffers");
316
317
  }
@@ -419,40 +420,68 @@ uint32_t llama_context::n_threads_batch() const {
419
420
  return cparams.n_threads_batch;
420
421
  }
421
422
 
422
- llama_kv_cache * llama_context::get_kv_self() {
423
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
424
- return kv_self;
423
+ llama_memory_t llama_context::get_memory() const {
424
+ return memory.get();
425
425
  }
426
426
 
427
- const llama_kv_cache * llama_context::get_kv_self() const {
428
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
429
- return kv_self;
427
+ // deprecated
428
+ void llama_context::kv_self_defrag_sched() {
429
+ if (!memory) {
430
+ return;
431
+ }
432
+
433
+ memory_force_optimize = true;
430
434
  }
431
435
 
432
- bool llama_context::kv_self_update() {
436
+ // deprecated
437
+ bool llama_context::kv_self_update(bool optimize) {
433
438
  if (!memory) {
434
439
  return false;
435
440
  }
436
441
 
437
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
442
+ {
443
+ // TODO: remove in the future
444
+ optimize |= memory_force_optimize;
445
+ memory_force_optimize = false;
438
446
 
439
- if (!kv_self->update(*this)) {
440
- // no updates have been performed
441
- return false;
442
- }
447
+ const auto mstate = memory->init_update(this, optimize);
448
+ switch (mstate->get_status()) {
449
+ case LLAMA_MEMORY_STATUS_SUCCESS:
450
+ {
451
+ // noop
452
+ } break;
453
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
454
+ {
455
+ // no updates need to be performed
456
+ return false;
457
+ }
458
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
459
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
460
+ {
461
+ LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
462
+ return false;
463
+ }
464
+ }
443
465
 
444
- // if the KV cache did any computation, we have to reserve a new worst-case graph
445
- const auto kv_state = kv_self->init_full();
446
- if (!kv_state) {
447
- throw std::runtime_error("failed to initialize KV cache");
466
+ if (!mstate->apply()) {
467
+ LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
468
+ }
448
469
  }
449
470
 
450
- const uint32_t n_seqs = cparams.n_seq_max;
451
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
471
+ // if the memory module did any computation, we have to reserve a new worst-case graph
472
+ {
473
+ const auto mstate = memory->init_full();
474
+ if (!mstate) {
475
+ throw std::runtime_error("failed to initialize memory state");
476
+ }
452
477
 
453
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
454
- if (!gf) {
455
- LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
478
+ const uint32_t n_seqs = cparams.n_seq_max;
479
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
480
+
481
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
482
+ if (!gf) {
483
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
484
+ }
456
485
  }
457
486
 
458
487
  return true;
@@ -467,7 +496,7 @@ float * llama_context::get_logits() {
467
496
  }
468
497
 
469
498
  float * llama_context::get_logits_ith(int32_t i) {
470
- int32_t j = -1;
499
+ int64_t j = -1;
471
500
 
472
501
  try {
473
502
  if (logits == nullptr) {
@@ -490,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
490
519
  }
491
520
  if (j >= n_outputs) {
492
521
  // This should not happen
493
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
522
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
494
523
  }
495
524
 
496
525
  return logits + j*model.vocab.n_tokens();
@@ -509,7 +538,7 @@ float * llama_context::get_embeddings() {
509
538
  }
510
539
 
511
540
  float * llama_context::get_embeddings_ith(int32_t i) {
512
- int32_t j = -1;
541
+ int64_t j = -1;
513
542
 
514
543
  try {
515
544
  if (embd == nullptr) {
@@ -532,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
532
561
  }
533
562
  if (j >= n_outputs) {
534
563
  // This should not happen
535
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
564
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
536
565
  }
537
566
 
538
567
  return embd + j*model.hparams.n_embd;
@@ -692,52 +721,41 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
692
721
  return res;
693
722
  }
694
723
 
695
- int llama_context::encode(llama_batch & inp_batch) {
696
- if (inp_batch.n_tokens == 0) {
724
+ int llama_context::encode(const llama_batch & batch_inp) {
725
+ if (batch_inp.n_tokens == 0) {
697
726
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
698
727
  return -1;
699
728
  }
700
729
 
701
- // temporary allocate memory for the input batch if needed
702
730
  // note: during encode, we always pass the full sequence starting from pos = 0
703
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
731
+ if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
732
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
733
+ return -1;
734
+ }
704
735
 
705
- const llama_batch & batch = batch_allocr.batch;
706
- const int32_t n_tokens = batch.n_tokens;
736
+ const llama_batch & batch = batch_allocr->get_batch();
707
737
 
708
- const auto & hparams = model.hparams;
738
+ const uint32_t n_tokens = batch.n_tokens;
709
739
 
710
740
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
711
741
 
712
- // TODO: move the validation to the llama_batch_allocr
713
- if (batch.token) {
714
- for (int32_t i = 0; i < n_tokens; ++i) {
715
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
716
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
717
- return -1;
718
- }
719
-
720
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
721
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
722
- throw -1;
723
- }
724
- }
725
- }
726
-
727
742
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
728
- GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
743
+ GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
729
744
 
730
745
  if (t_compute_start_us == 0) {
731
746
  t_compute_start_us = ggml_time_us();
732
747
  }
733
748
 
749
+ // TODO: this clear of the buffer can easily be forgotten - need something better
734
750
  embd_seq.clear();
735
751
 
736
752
  n_queued_tokens += n_tokens;
737
753
 
754
+ const auto & hparams = model.hparams;
755
+
738
756
  const int64_t n_embd = hparams.n_embd;
739
757
 
740
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
758
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
741
759
 
742
760
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
743
761
 
@@ -747,7 +765,7 @@ int llama_context::encode(llama_batch & inp_batch) {
747
765
  return -2;
748
766
  };
749
767
 
750
- for (int32_t i = 0; i < n_tokens; ++i) {
768
+ for (uint32_t i = 0; i < n_tokens; ++i) {
751
769
  output_ids[i] = i;
752
770
  }
753
771
 
@@ -803,7 +821,8 @@ int llama_context::encode(llama_batch & inp_batch) {
803
821
 
804
822
  GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
805
823
 
806
- for (int32_t i = 0; i < n_tokens; i++) {
824
+ // TODO: fix indexing [UBATCH_IDX]
825
+ for (uint32_t i = 0; i < n_tokens; i++) {
807
826
  const llama_seq_id seq_id = ubatch.seq_id[i][0];
808
827
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
809
828
  continue;
@@ -814,16 +833,18 @@ int llama_context::encode(llama_batch & inp_batch) {
814
833
  } break;
815
834
  case LLAMA_POOLING_TYPE_RANK:
816
835
  {
817
- // extract the rerank score - a single float per sequence
836
+ // extract the rerank score - n_cls_out floats per sequence
818
837
  auto & embd_seq_out = embd_seq;
838
+ const uint32_t n_cls_out = hparams.n_cls_out;
819
839
 
840
+ // TODO: fix indexing [UBATCH_IDX]
820
841
  for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
821
842
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
822
843
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
823
844
  continue;
824
845
  }
825
- embd_seq_out[seq_id].resize(1);
826
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
846
+ embd_seq_out[seq_id].resize(n_cls_out);
847
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
827
848
  }
828
849
  } break;
829
850
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -850,10 +871,10 @@ int llama_context::encode(llama_batch & inp_batch) {
850
871
 
851
872
  // remember the sequence ids used during the encoding - needed for cross attention later
852
873
  cross.seq_ids_enc.resize(n_tokens);
853
- for (int32_t i = 0; i < n_tokens; i++) {
874
+ for (uint32_t i = 0; i < n_tokens; i++) {
854
875
  cross.seq_ids_enc[i].clear();
855
- for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
856
- llama_seq_id seq_id = ubatch.seq_id[i][s];
876
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
877
+ llama_seq_id seq_id = batch.seq_id[i][s];
857
878
  cross.seq_ids_enc[i].insert(seq_id);
858
879
  }
859
880
  }
@@ -862,53 +883,45 @@ int llama_context::encode(llama_batch & inp_batch) {
862
883
  return 0;
863
884
  }
864
885
 
865
- int llama_context::decode(llama_batch & inp_batch) {
886
+ int llama_context::decode(const llama_batch & batch_inp) {
866
887
  if (!memory) {
867
888
  LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
868
- return encode(inp_batch);
889
+ return encode(batch_inp);
869
890
  }
870
891
 
871
- if (inp_batch.n_tokens == 0) {
892
+ if (batch_inp.n_tokens == 0) {
872
893
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
873
894
  return -1;
874
895
  }
875
896
 
876
- if (!inp_batch.pos) {
877
- if (inp_batch.seq_id) {
878
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
879
- return -1;
880
- }
881
- }
882
-
883
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
897
+ // when computing embeddings, all tokens are output
898
+ const bool embd_all = cparams.embeddings;
884
899
 
885
- // temporary allocate memory for the input batch if needed
886
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
900
+ if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
901
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
902
+ return -1;
903
+ }
887
904
 
888
- const llama_batch & batch = batch_allocr.batch;
905
+ const llama_batch & batch = batch_allocr->get_batch();
889
906
 
890
907
  const auto & vocab = model.vocab;
891
908
  const auto & hparams = model.hparams;
892
909
 
893
910
  const int32_t n_vocab = vocab.n_tokens();
911
+ const int64_t n_embd = hparams.n_embd;
894
912
 
895
- const int64_t n_tokens_all = batch.n_tokens;
896
- const int64_t n_embd = hparams.n_embd;
913
+ const uint32_t n_tokens_all = batch.n_tokens;
897
914
 
898
915
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
899
916
 
900
- // TODO: move the validation to the llama_batch_allocr
901
- if (batch.token) {
902
- for (int64_t i = 0; i < n_tokens_all; ++i) {
903
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
904
- LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
905
- return -1;
906
- }
917
+ const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
907
918
 
908
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
909
- LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
910
- return -1;
911
- }
919
+ if (embd_all) {
920
+ // require that all tokens are output
921
+ if (n_outputs_all != n_tokens_all) {
922
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
923
+ __func__, n_outputs_all, n_tokens_all);
924
+ return -1;
912
925
  }
913
926
  }
914
927
 
@@ -921,61 +934,52 @@ int llama_context::decode(llama_batch & inp_batch) {
921
934
  }
922
935
  n_queued_tokens += n_tokens_all;
923
936
 
924
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
925
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
926
-
937
+ // TODO: this clear of the buffer can easily be forgotten - need something better
927
938
  embd_seq.clear();
928
939
 
929
- int64_t n_outputs_all = 0;
930
-
931
- // count outputs
932
- if (batch.logits && !embd_pooled) {
933
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
934
- n_outputs_all += batch.logits[i] != 0;
935
- }
936
- } else if (embd_pooled) {
937
- n_outputs_all = n_tokens_all;
938
- } else {
939
- // keep last output only
940
- n_outputs_all = 1;
941
- }
940
+ bool did_optimize = false;
942
941
 
943
942
  // handle any pending defrags/shifts
944
- kv_self_update();
943
+ kv_self_update(false);
945
944
 
946
- llama_memory_state_ptr kv_state;
947
-
948
- bool did_defrag = false;
945
+ llama_memory_state_ptr mstate;
949
946
 
950
947
  while (true) {
951
- kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
952
- if (!kv_state) {
948
+ mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
949
+ if (!mstate) {
953
950
  return -2;
954
951
  }
955
952
 
956
- switch (kv_state->get_status()) {
953
+ switch (mstate->get_status()) {
957
954
  case LLAMA_MEMORY_STATUS_SUCCESS:
958
955
  {
959
956
  } break;
957
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
958
+ {
959
+ LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
960
+
961
+ return -2;
962
+ }
960
963
  case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
961
964
  {
962
- if (!did_defrag) {
963
- did_defrag = true;
965
+ if (!did_optimize) {
966
+ did_optimize = true;
964
967
 
965
- kv_self->defrag_sched(-1.0f);
966
- if (kv_self_update()) {
967
- LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
968
+ if (kv_self_update(true)) {
969
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
968
970
 
969
971
  continue;
970
972
  }
971
973
  }
972
974
 
973
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
975
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
974
976
 
975
977
  return 1;
976
978
  }
977
979
  case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
978
980
  {
981
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
982
+
979
983
  return -2;
980
984
  }
981
985
  }
@@ -985,16 +989,16 @@ int llama_context::decode(llama_batch & inp_batch) {
985
989
 
986
990
  // reserve output buffer
987
991
  if (output_reserve(n_outputs_all) < n_outputs_all) {
988
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
992
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
989
993
  return -2;
990
994
  };
991
995
 
992
996
  int64_t n_outputs_prev = 0;
993
997
 
994
998
  do {
995
- const auto & ubatch = kv_state->get_ubatch();
999
+ const auto & ubatch = mstate->get_ubatch();
996
1000
 
997
- // count the outputs in this u_batch
1001
+ // count the outputs in this ubatch
998
1002
  {
999
1003
  int32_t n_outputs_new = 0;
1000
1004
 
@@ -1015,26 +1019,30 @@ int llama_context::decode(llama_batch & inp_batch) {
1015
1019
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1016
1020
 
1017
1021
  ggml_status status;
1018
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
1022
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
1019
1023
 
1020
1024
  if (!res) {
1021
1025
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1022
- llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
1026
+ llama_pos pos_min[LLAMA_MAX_SEQ];
1027
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1028
+ pos_min[s] = std::numeric_limits<llama_pos>::max();
1029
+ }
1023
1030
 
1031
+ // TODO: fix sequence indexing
1024
1032
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1025
1033
  const auto & seq_id = ubatch.seq_id[i][0];
1026
1034
 
1027
1035
  pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
1028
1036
  }
1029
1037
 
1030
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
1038
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1031
1039
  if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
1032
1040
  continue;
1033
1041
  }
1034
1042
 
1035
1043
  LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1036
1044
 
1037
- llama_kv_self_seq_rm(this, s, pos_min[s], -1);
1045
+ memory->seq_rm(s, pos_min[s], -1);
1038
1046
  }
1039
1047
 
1040
1048
  switch (status) {
@@ -1050,7 +1058,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1050
1058
  // ggml_graph_dump_dot(gf, NULL, "llama.dot");
1051
1059
  //}
1052
1060
 
1053
- auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
1061
+ auto * t_logits = res->get_logits();
1054
1062
  auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1055
1063
 
1056
1064
  if (t_embd && res->get_embd_pooled()) {
@@ -1128,20 +1136,20 @@ int llama_context::decode(llama_batch & inp_batch) {
1128
1136
  }
1129
1137
 
1130
1138
  n_outputs_prev += n_outputs;
1131
- } while (kv_state->next());
1139
+ } while (mstate->next());
1132
1140
 
1133
1141
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1134
1142
  n_outputs = n_outputs_all;
1135
1143
 
1136
1144
  // set output mappings
1137
- {
1145
+ if (n_outputs > 0) {
1138
1146
  bool sorted_output = true;
1139
1147
 
1140
- auto & out_ids = kv_state->out_ids();
1148
+ auto & out_ids = mstate->out_ids();
1141
1149
 
1142
- GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1150
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1143
1151
 
1144
- for (int64_t i = 0; i < n_outputs_all; ++i) {
1152
+ for (int64_t i = 0; i < n_outputs; ++i) {
1145
1153
  int64_t out_id = out_ids[i];
1146
1154
  output_ids[out_id] = i;
1147
1155
  if (out_id != i) {
@@ -1153,20 +1161,22 @@ int llama_context::decode(llama_batch & inp_batch) {
1153
1161
  // note: this is mostly relevant for recurrent models atm
1154
1162
  if (!sorted_output) {
1155
1163
  const uint32_t n_vocab = model.vocab.n_tokens();
1156
- const uint32_t n_embd = model.hparams.n_embd;
1164
+ const uint64_t n_embd = model.hparams.n_embd;
1157
1165
 
1158
1166
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1159
1167
 
1160
1168
  // TODO: is there something more efficient which also minimizes swaps?
1161
1169
  // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1162
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
1163
- int32_t j_min = i;
1164
- for (int32_t j = i + 1; j < n_outputs; ++j) {
1170
+ for (uint32_t i = 0; i < n_outputs - 1; ++i) {
1171
+ uint32_t j_min = i;
1172
+ for (uint32_t j = i + 1; j < n_outputs; ++j) {
1165
1173
  if (out_ids[j] < out_ids[j_min]) {
1166
1174
  j_min = j;
1167
1175
  }
1168
1176
  }
1169
- if (j_min == i) { continue; }
1177
+ if (j_min == i) {
1178
+ continue;
1179
+ }
1170
1180
  std::swap(out_ids[i], out_ids[j_min]);
1171
1181
  if (logits_size > 0) {
1172
1182
  for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1179,8 +1189,10 @@ int llama_context::decode(llama_batch & inp_batch) {
1179
1189
  }
1180
1190
  }
1181
1191
  }
1192
+
1182
1193
  std::fill(output_ids.begin(), output_ids.end(), -1);
1183
- for (int32_t i = 0; i < n_outputs; ++i) {
1194
+
1195
+ for (uint32_t i = 0; i < n_outputs; ++i) {
1184
1196
  output_ids[out_ids[i]] = i;
1185
1197
  }
1186
1198
  }
@@ -1189,11 +1201,6 @@ int llama_context::decode(llama_batch & inp_batch) {
1189
1201
  // wait for the computation to finish (automatically done when obtaining the model output)
1190
1202
  //synchronize();
1191
1203
 
1192
- // decide if we need to defrag the kv cache
1193
- if (cparams.defrag_thold > 0.0f) {
1194
- kv_self->defrag_sched(cparams.defrag_thold);
1195
- }
1196
-
1197
1204
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1198
1205
  // overlap with device computation.
1199
1206
  ggml_backend_sched_reset(sched.get());
@@ -1205,7 +1212,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1205
1212
  // output
1206
1213
  //
1207
1214
 
1208
- int32_t llama_context::output_reserve(int32_t n_outputs) {
1215
+ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1209
1216
  const auto & hparams = model.hparams;
1210
1217
  const auto & vocab = model.vocab;
1211
1218
 
@@ -1215,9 +1222,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1215
1222
  const auto n_vocab = vocab.n_tokens();
1216
1223
  const auto n_embd = hparams.n_embd;
1217
1224
 
1218
- // TODO: use a per-batch flag for logits presence instead
1219
- bool has_logits = !cparams.embeddings;
1220
- bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1225
+ bool has_logits = true;
1226
+ bool has_embd = cparams.embeddings;
1221
1227
 
1222
1228
  // TODO: hacky enc-dec support
1223
1229
  if (model.arch == LLM_ARCH_T5) {
@@ -1271,8 +1277,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1271
1277
  // set all ids as invalid (negative)
1272
1278
  std::fill(output_ids.begin(), output_ids.end(), -1);
1273
1279
 
1274
- this->n_outputs = 0;
1275
- this->n_outputs_max = n_outputs_max;
1280
+ this->n_outputs = 0;
1276
1281
 
1277
1282
  return n_outputs_max;
1278
1283
  }
@@ -1301,7 +1306,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1301
1306
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1302
1307
 
1303
1308
  if (n_tokens % n_seqs != 0) {
1304
- n_tokens = (n_tokens / n_seqs) * n_seqs;
1309
+ n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1305
1310
  n_outputs = std::min(n_outputs, n_tokens);
1306
1311
 
1307
1312
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -1763,14 +1768,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1763
1768
 
1764
1769
  std::vector<int32_t> w_output_pos;
1765
1770
 
1766
- GGML_ASSERT(n_outputs <= n_outputs_max);
1767
-
1768
1771
  w_output_pos.resize(n_outputs);
1769
1772
 
1770
1773
  // build a more compact representation of the output ids
1771
1774
  for (size_t i = 0; i < n_batch(); ++i) {
1772
1775
  // map an output id to a position in the batch
1773
- int32_t pos = output_ids[i];
1776
+ int64_t pos = output_ids[i];
1774
1777
  if (pos >= 0) {
1775
1778
  GGML_ASSERT(pos < n_outputs);
1776
1779
  w_output_pos[pos] = i;
@@ -1810,11 +1813,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1810
1813
  }
1811
1814
  }
1812
1815
 
1813
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1814
-
1815
- if (kv_self != nullptr) {
1816
+ if (memory != nullptr) {
1816
1817
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1817
- kv_self->state_write(io);
1818
+ memory->state_write(io);
1818
1819
  }
1819
1820
 
1820
1821
  return io.n_bytes();
@@ -1901,9 +1902,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1901
1902
  if (memory) {
1902
1903
  LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
1903
1904
 
1904
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1905
-
1906
- kv_self->state_read(io);
1905
+ memory->state_read(io);
1907
1906
  }
1908
1907
 
1909
1908
  return io.n_bytes();
@@ -1913,9 +1912,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
1913
1912
  GGML_UNUSED(seq_id);
1914
1913
 
1915
1914
  if (memory) {
1916
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1917
-
1918
- kv_self->state_write(io, seq_id);
1915
+ memory->state_write(io, seq_id);
1919
1916
  }
1920
1917
 
1921
1918
  return io.n_bytes();
@@ -1925,9 +1922,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
1925
1922
  GGML_UNUSED(seq_id);
1926
1923
 
1927
1924
  if (memory) {
1928
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1929
-
1930
- kv_self->state_read(io, seq_id);
1925
+ memory->state_read(io, seq_id);
1931
1926
  }
1932
1927
 
1933
1928
  return io.n_bytes();
@@ -2032,9 +2027,7 @@ void llama_context::opt_epoch_iter(
2032
2027
  const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
2033
2028
  const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
2034
2029
 
2035
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
2036
-
2037
- kv_self->clear();
2030
+ memory->clear(true);
2038
2031
 
2039
2032
  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
2040
2033
  batch.n_tokens = n_batch;
@@ -2050,38 +2043,35 @@ void llama_context::opt_epoch_iter(
2050
2043
 
2051
2044
  n_queued_tokens += n_tokens_all;
2052
2045
 
2053
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
2054
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
2055
-
2056
2046
  embd_seq.clear();
2057
2047
 
2058
- int64_t n_outputs_all = n_tokens_all;
2048
+ uint32_t n_outputs_all = n_tokens_all;
2059
2049
 
2060
- auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
2061
- if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2050
+ auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
2051
+ if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2062
2052
  LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2063
2053
  break;
2064
2054
  }
2065
2055
 
2066
2056
  // reserve output buffer
2067
2057
  if (output_reserve(n_outputs_all) < n_outputs_all) {
2068
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
2058
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
2069
2059
  GGML_ABORT("TODO: handle this error");
2070
2060
  };
2071
2061
 
2072
2062
  uint32_t pos_batch = 0;
2073
2063
  do {
2074
- const auto & ubatch = kv_state->get_ubatch();
2064
+ const auto & ubatch = mstate->get_ubatch();
2075
2065
 
2076
2066
  n_outputs = ubatch.n_tokens;
2077
2067
 
2078
- if (!kv_state->apply()) {
2068
+ if (!mstate->apply()) {
2079
2069
  LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
2080
2070
  break;
2081
2071
  }
2082
2072
 
2083
2073
  auto * gf = graph_init();
2084
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
2074
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
2085
2075
 
2086
2076
  struct ggml_context * ctx_compute_opt;
2087
2077
  {
@@ -2116,7 +2106,7 @@ void llama_context::opt_epoch_iter(
2116
2106
  ggml_free(ctx_compute_opt);
2117
2107
 
2118
2108
  pos_batch += ubatch.n_tokens;
2119
- } while (kv_state->next());
2109
+ } while (mstate->next());
2120
2110
  }
2121
2111
  }
2122
2112
 
@@ -2277,13 +2267,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
2277
2267
  return &ctx->get_model();
2278
2268
  }
2279
2269
 
2270
+ // deprecated
2280
2271
  llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2281
- return ctx->get_kv_self();
2272
+ return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
2282
2273
  }
2283
2274
 
2284
2275
  // deprecated
2285
2276
  void llama_kv_self_update(llama_context * ctx) {
2286
- ctx->kv_self_update();
2277
+ ctx->kv_self_update(false);
2287
2278
  }
2288
2279
 
2289
2280
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2398,13 +2389,118 @@ int32_t llama_apply_adapter_cvec(
2398
2389
  return res ? 0 : -1;
2399
2390
  }
2400
2391
 
2392
+ //
2393
+ // memory
2394
+ //
2395
+
2396
+ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
2397
+ return ctx->get_memory();
2398
+ }
2399
+
2400
+ void llama_memory_clear(llama_memory_t mem, bool data) {
2401
+ if (!mem) {
2402
+ return;
2403
+ }
2404
+
2405
+ mem->clear(data);
2406
+ }
2407
+
2408
+ bool llama_memory_seq_rm(
2409
+ llama_memory_t mem,
2410
+ llama_seq_id seq_id,
2411
+ llama_pos p0,
2412
+ llama_pos p1) {
2413
+ if (!mem) {
2414
+ return true;
2415
+ }
2416
+
2417
+ return mem->seq_rm(seq_id, p0, p1);
2418
+ }
2419
+
2420
+ void llama_memory_seq_cp(
2421
+ llama_memory_t mem,
2422
+ llama_seq_id seq_id_src,
2423
+ llama_seq_id seq_id_dst,
2424
+ llama_pos p0,
2425
+ llama_pos p1) {
2426
+ if (!mem) {
2427
+ return;
2428
+ }
2429
+
2430
+ mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2431
+ }
2432
+
2433
+ void llama_memory_seq_keep(
2434
+ llama_memory_t mem,
2435
+ llama_seq_id seq_id) {
2436
+ if (!mem) {
2437
+ return;
2438
+ }
2439
+
2440
+ mem->seq_keep(seq_id);
2441
+ }
2442
+
2443
+ void llama_memory_seq_add(
2444
+ llama_memory_t mem,
2445
+ llama_seq_id seq_id,
2446
+ llama_pos p0,
2447
+ llama_pos p1,
2448
+ llama_pos delta) {
2449
+ if (!mem) {
2450
+ return;
2451
+ }
2452
+
2453
+ mem->seq_add(seq_id, p0, p1, delta);
2454
+ }
2455
+
2456
+ void llama_memory_seq_div(
2457
+ llama_memory_t mem,
2458
+ llama_seq_id seq_id,
2459
+ llama_pos p0,
2460
+ llama_pos p1,
2461
+ int d) {
2462
+ if (!mem) {
2463
+ return;
2464
+ }
2465
+
2466
+ mem->seq_div(seq_id, p0, p1, d);
2467
+ }
2468
+
2469
+ llama_pos llama_memory_seq_pos_min(
2470
+ llama_memory_t mem,
2471
+ llama_seq_id seq_id) {
2472
+ if (!mem) {
2473
+ return -1;
2474
+ }
2475
+
2476
+ return mem->seq_pos_min(seq_id);
2477
+ }
2478
+
2479
+ llama_pos llama_memory_seq_pos_max(
2480
+ llama_memory_t mem,
2481
+ llama_seq_id seq_id) {
2482
+ if (!mem) {
2483
+ return -1;
2484
+ }
2485
+
2486
+ return mem->seq_pos_max(seq_id);
2487
+ }
2488
+
2489
+ bool llama_memory_can_shift(llama_memory_t mem) {
2490
+ if (!mem) {
2491
+ return false;
2492
+ }
2493
+
2494
+ return mem->get_can_shift();
2495
+ }
2496
+
2401
2497
  //
2402
2498
  // kv cache
2403
2499
  //
2404
2500
 
2405
2501
  // deprecated
2406
2502
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2407
- const auto * kv = ctx->get_kv_self();
2503
+ const auto * kv = llama_get_memory(ctx);
2408
2504
  if (!kv) {
2409
2505
  return 0;
2410
2506
  }
@@ -2426,7 +2522,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2426
2522
  // deprecated
2427
2523
  // note: this is the same as above - will be removed anyway, so it's ok
2428
2524
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2429
- const auto * kv = ctx->get_kv_self();
2525
+ const auto * kv = llama_get_memory(ctx);
2430
2526
  if (!kv) {
2431
2527
  return 0;
2432
2528
  }
@@ -2445,115 +2541,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2445
2541
  return res;
2446
2542
  }
2447
2543
 
2544
+ // deprecated
2448
2545
  void llama_kv_self_clear(llama_context * ctx) {
2449
- auto * kv = ctx->get_kv_self();
2546
+ auto * kv = llama_get_memory(ctx);
2450
2547
  if (!kv) {
2451
2548
  return;
2452
2549
  }
2453
2550
 
2454
- kv->clear();
2551
+ llama_memory_clear(kv, true);
2455
2552
  }
2456
2553
 
2554
+ // deprecated
2457
2555
  bool llama_kv_self_seq_rm(
2458
2556
  llama_context * ctx,
2459
2557
  llama_seq_id seq_id,
2460
2558
  llama_pos p0,
2461
2559
  llama_pos p1) {
2462
- auto * kv = ctx->get_kv_self();
2560
+ auto * kv = llama_get_memory(ctx);
2463
2561
  if (!kv) {
2464
2562
  return true;
2465
2563
  }
2466
2564
 
2467
- return kv->seq_rm(seq_id, p0, p1);
2565
+ return llama_memory_seq_rm(kv, seq_id, p0, p1);
2468
2566
  }
2469
2567
 
2568
+ // deprecated
2470
2569
  void llama_kv_self_seq_cp(
2471
2570
  llama_context * ctx,
2472
2571
  llama_seq_id seq_id_src,
2473
2572
  llama_seq_id seq_id_dst,
2474
2573
  llama_pos p0,
2475
2574
  llama_pos p1) {
2476
- auto * kv = ctx->get_kv_self();
2575
+ auto * kv = llama_get_memory(ctx);
2477
2576
  if (!kv) {
2478
2577
  return;
2479
2578
  }
2480
2579
 
2481
- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2580
+ llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
2482
2581
  }
2483
2582
 
2583
+ // deprecated
2484
2584
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2485
- auto * kv = ctx->get_kv_self();
2585
+ auto * kv = llama_get_memory(ctx);
2486
2586
  if (!kv) {
2487
2587
  return;
2488
2588
  }
2489
2589
 
2490
- kv->seq_keep(seq_id);
2590
+ llama_memory_seq_keep(kv, seq_id);
2491
2591
  }
2492
2592
 
2593
+ // deprecated
2493
2594
  void llama_kv_self_seq_add(
2494
2595
  llama_context * ctx,
2495
2596
  llama_seq_id seq_id,
2496
2597
  llama_pos p0,
2497
2598
  llama_pos p1,
2498
2599
  llama_pos delta) {
2499
- auto * kv = ctx->get_kv_self();
2600
+ auto * kv = llama_get_memory(ctx);
2500
2601
  if (!kv) {
2501
2602
  return;
2502
2603
  }
2503
2604
 
2504
- kv->seq_add(seq_id, p0, p1, delta);
2605
+ llama_memory_seq_add(kv, seq_id, p0, p1, delta);
2505
2606
  }
2506
2607
 
2608
+ // deprecated
2507
2609
  void llama_kv_self_seq_div(
2508
2610
  llama_context * ctx,
2509
2611
  llama_seq_id seq_id,
2510
2612
  llama_pos p0,
2511
2613
  llama_pos p1,
2512
2614
  int d) {
2513
- auto * kv = ctx->get_kv_self();
2615
+ auto * kv = llama_get_memory(ctx);
2514
2616
  if (!kv) {
2515
2617
  return;
2516
2618
  }
2517
2619
 
2518
- kv->seq_div(seq_id, p0, p1, d);
2620
+ llama_memory_seq_div(kv, seq_id, p0, p1, d);
2519
2621
  }
2520
2622
 
2623
+ // deprecated
2521
2624
  llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2522
- const auto * kv = ctx->get_kv_self();
2625
+ auto * kv = llama_get_memory(ctx);
2523
2626
  if (!kv) {
2524
2627
  return -1;
2525
2628
  }
2526
2629
 
2527
- return kv->seq_pos_min(seq_id);
2630
+ return llama_memory_seq_pos_min(kv, seq_id);
2528
2631
  }
2529
2632
 
2633
+ // deprecated
2530
2634
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2531
- const auto * kv = ctx->get_kv_self();
2635
+ auto * kv = llama_get_memory(ctx);
2532
2636
  if (!kv) {
2533
2637
  return -1;
2534
2638
  }
2535
2639
 
2536
- return kv->seq_pos_max(seq_id);
2640
+ return llama_memory_seq_pos_max(kv, seq_id);
2537
2641
  }
2538
2642
 
2539
2643
  // deprecated
2540
2644
  void llama_kv_self_defrag(llama_context * ctx) {
2541
- auto * kv = ctx->get_kv_self();
2542
- if (!kv) {
2543
- return;
2544
- }
2545
-
2546
2645
  // force defrag
2547
- kv->defrag_sched(-1.0f);
2646
+ ctx->kv_self_defrag_sched();
2548
2647
  }
2549
2648
 
2649
+ // deprecated
2550
2650
  bool llama_kv_self_can_shift(const llama_context * ctx) {
2551
- const auto * kv = ctx->get_kv_self();
2651
+ auto * kv = llama_get_memory(ctx);
2552
2652
  if (!kv) {
2553
2653
  return false;
2554
2654
  }
2555
2655
 
2556
- return kv->get_can_shift();
2656
+ return llama_memory_can_shift(kv);
2557
2657
  }
2558
2658
 
2559
2659
  // llama state API