@novastera-oss/llamarn 0.2.5 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +140 -38
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +48 -67
  13. package/cpp/LlamaCppModel.h +8 -3
  14. package/cpp/PureCppImpl.cpp +1 -1
  15. package/cpp/PureCppImpl.h +2 -2
  16. package/cpp/build-info.cpp +2 -2
  17. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  18. package/cpp/llama.cpp/Makefile +2 -2
  19. package/cpp/llama.cpp/README.md +33 -13
  20. package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
  21. package/cpp/llama.cpp/common/arg.cpp +38 -12
  22. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  23. package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
  24. package/cpp/llama.cpp/common/chat-parser.h +4 -1
  25. package/cpp/llama.cpp/common/chat.cpp +16 -13
  26. package/cpp/llama.cpp/common/chat.h +1 -1
  27. package/cpp/llama.cpp/common/common.cpp +52 -40
  28. package/cpp/llama.cpp/common/common.h +5 -2
  29. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  30. package/cpp/llama.cpp/common/json-partial.h +2 -1
  31. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  32. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  33. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  34. package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  37. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  38. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
  39. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  41. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  79. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  82. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  83. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  112. package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
  113. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  114. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  115. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  116. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  117. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  118. package/cpp/llama.cpp/include/llama.h +140 -38
  119. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  120. package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
  121. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  122. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  123. package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
  124. package/cpp/llama.cpp/src/llama-batch.h +47 -17
  125. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  126. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  127. package/cpp/llama.cpp/src/llama-context.cpp +488 -313
  128. package/cpp/llama.cpp/src/llama-context.h +38 -17
  129. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  130. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  131. package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
  132. package/cpp/llama.cpp/src/llama-graph.h +109 -52
  133. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  134. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
  139. package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  141. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  142. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
  144. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  145. package/cpp/llama.cpp/src/llama-memory.h +89 -4
  146. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  147. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  148. package/cpp/llama.cpp/src/llama-model.cpp +735 -143
  149. package/cpp/llama.cpp/src/llama-model.h +4 -0
  150. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  151. package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
  152. package/cpp/llama.cpp/src/llama.cpp +11 -7
  153. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  154. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  155. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  156. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  157. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  158. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  159. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  160. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  161. package/cpp/rn-completion.cpp +65 -10
  162. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  163. package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
  164. package/ios/include/chat.h +1 -1
  165. package/ios/include/common/minja/chat-template.hpp +1 -1
  166. package/ios/include/common/minja/minja.hpp +1 -1
  167. package/ios/include/common.h +5 -2
  168. package/ios/include/json-schema-to-grammar.h +4 -4
  169. package/ios/include/llama.h +140 -38
  170. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  171. package/ios/libs/llama.xcframework/Info.plist +20 -20
  172. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
  174. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  175. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
  176. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  177. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  178. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  179. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
  180. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  181. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  182. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  184. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  185. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
  186. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  187. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
  188. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  189. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
  190. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  191. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  192. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
  193. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  194. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  195. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  196. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
  197. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  198. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
  199. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
  202. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
  203. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  204. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  205. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  206. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  207. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
  208. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  209. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
  210. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  211. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  212. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
  213. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
  214. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  215. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  216. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  217. package/package.json +1 -2
  218. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  219. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  221. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
  222. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
  223. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  224. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  225. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -1,14 +1,16 @@
1
1
  #include "llama-context.h"
2
2
 
3
3
  #include "llama-impl.h"
4
+ #include "llama-batch.h"
4
5
  #include "llama-io.h"
6
+ #include "llama-memory.h"
5
7
  #include "llama-mmap.h"
6
8
  #include "llama-model.h"
7
- #include "llama-kv-cache.h"
8
9
 
10
+ #include <cinttypes>
9
11
  #include <cstring>
12
+ #include <limits>
10
13
  #include <stdexcept>
11
- #include <cinttypes>
12
14
 
13
15
  //
14
16
  // llama_context
@@ -17,7 +19,8 @@
17
19
  llama_context::llama_context(
18
20
  const llama_model & model,
19
21
  llama_context_params params) :
20
- model(model) {
22
+ model(model),
23
+ batch_allocr(std::make_unique<llama_batch_allocr>()) {
21
24
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
22
25
 
23
26
  t_start_us = model.t_start_us;
@@ -26,8 +29,8 @@ llama_context::llama_context(
26
29
  const auto & hparams = model.hparams;
27
30
 
28
31
  cparams.n_seq_max = std::max(1u, params.n_seq_max);
29
- if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
30
- throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
32
+ if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
33
+ throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
31
34
  }
32
35
 
33
36
  cparams.n_threads = params.n_threads;
@@ -122,6 +125,11 @@ llama_context::llama_context(
122
125
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
123
126
  }
124
127
 
128
+ if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
129
+ LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
130
+ __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
131
+ }
132
+
125
133
  if (!hparams.vocab_only) {
126
134
  // GPU backends
127
135
  for (auto * dev : model.devices) {
@@ -259,15 +267,9 @@ llama_context::llama_context(
259
267
 
260
268
  // reserve worst-case graph
261
269
  if (!hparams.vocab_only && memory) {
262
- const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
270
+ const uint32_t n_seqs = cparams.n_seq_max;
263
271
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
264
272
 
265
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
266
-
267
- // restore later
268
- // TODO: something cleaner
269
- const auto n_outputs_save = n_outputs;
270
-
271
273
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
272
274
 
273
275
  int n_splits_pp = -1;
@@ -277,25 +279,18 @@ llama_context::llama_context(
277
279
  int n_nodes_tg = -1;
278
280
 
279
281
  // simulate full KV cache
280
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
281
282
 
282
- kv_self->set_full();
283
+ const auto mstate = memory->init_full();
284
+ if (!mstate) {
285
+ throw std::runtime_error("failed to initialize KV cache");
286
+ }
283
287
 
284
288
  cross.v_embd.clear();
285
289
 
286
290
  // reserve pp graph first so that buffers are only allocated once
287
291
  {
288
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
289
-
290
- // max number of outputs
291
- n_outputs = ubatch_pp.n_tokens;
292
-
293
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
294
-
295
- auto * gf = graph_init();
296
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
297
-
298
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
292
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
293
+ if (!gf) {
299
294
  throw std::runtime_error("failed to allocate compute pp buffers");
300
295
  }
301
296
 
@@ -305,16 +300,8 @@ llama_context::llama_context(
305
300
 
306
301
  // reserve with tg graph to get the number of splits and nodes
307
302
  {
308
- llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
309
-
310
- n_outputs = ubatch_tg.n_tokens;
311
-
312
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
313
-
314
- auto * gf = graph_init();
315
- graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
316
-
317
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
303
+ auto * gf = graph_reserve(1, 1, 1, mstate.get());
304
+ if (!gf) {
318
305
  throw std::runtime_error("failed to allocate compute tg buffers");
319
306
  }
320
307
 
@@ -324,22 +311,12 @@ llama_context::llama_context(
324
311
 
325
312
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
326
313
  {
327
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
328
-
329
- n_outputs = ubatch_pp.n_tokens;
330
-
331
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
332
-
333
- auto * gf = graph_init();
334
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
335
-
336
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
314
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
315
+ if (!gf) {
337
316
  throw std::runtime_error("failed to allocate compute pp buffers");
338
317
  }
339
318
  }
340
319
 
341
- n_outputs = n_outputs_save;
342
-
343
320
  for (size_t i = 0; i < backend_ptrs.size(); ++i) {
344
321
  ggml_backend_t backend = backend_ptrs[i];
345
322
  ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -443,46 +420,71 @@ uint32_t llama_context::n_threads_batch() const {
443
420
  return cparams.n_threads_batch;
444
421
  }
445
422
 
446
- llama_kv_cache * llama_context::get_kv_self() {
447
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
448
- return kv_self;
449
- }
450
-
451
- const llama_kv_cache * llama_context::get_kv_self() const {
452
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
453
- return kv_self;
423
+ llama_memory_t llama_context::get_memory() const {
424
+ return memory.get();
454
425
  }
455
426
 
456
- void llama_context::kv_self_update() {
457
- bool need_reserve = false;
427
+ // deprecated
428
+ void llama_context::kv_self_defrag_sched() {
429
+ if (!memory) {
430
+ return;
431
+ }
458
432
 
459
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
433
+ memory_force_optimize = true;
434
+ }
460
435
 
461
- need_reserve = kv_self->update(*this);
436
+ // deprecated
437
+ bool llama_context::kv_self_update(bool optimize) {
438
+ if (!memory) {
439
+ return false;
440
+ }
462
441
 
463
- // reserve a worst case graph if needed
464
- if (need_reserve) {
465
- LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
442
+ {
443
+ // TODO: remove in the future
444
+ optimize |= memory_force_optimize;
445
+ memory_force_optimize = false;
466
446
 
467
- // build worst-case graph
468
- uint32_t n_seqs = 1; // TODO: worst-case number of sequences
469
- uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
447
+ const auto mstate = memory->init_update(this, optimize);
448
+ switch (mstate->get_status()) {
449
+ case LLAMA_MEMORY_STATUS_SUCCESS:
450
+ {
451
+ // noop
452
+ } break;
453
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
454
+ {
455
+ // no updates need to be performed
456
+ return false;
457
+ }
458
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
459
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
460
+ {
461
+ LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
462
+ return false;
463
+ }
464
+ }
470
465
 
471
- // simulate full KV cache
472
- kv_self->set_full();
466
+ if (!mstate->apply()) {
467
+ LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
468
+ }
469
+ }
473
470
 
474
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
475
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
471
+ // if the memory module did any computation, we have to reserve a new worst-case graph
472
+ {
473
+ const auto mstate = memory->init_full();
474
+ if (!mstate) {
475
+ throw std::runtime_error("failed to initialize memory state");
476
+ }
476
477
 
477
- auto * gf = graph_init();
478
- graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
478
+ const uint32_t n_seqs = cparams.n_seq_max;
479
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
479
480
 
480
- // initialize scheduler with the worst-case graph
481
- ggml_backend_sched_reset(sched.get());
482
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
483
- LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
481
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
482
+ if (!gf) {
483
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
484
484
  }
485
485
  }
486
+
487
+ return true;
486
488
  }
487
489
 
488
490
  enum llama_pooling_type llama_context::pooling_type() const {
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
494
496
  }
495
497
 
496
498
  float * llama_context::get_logits_ith(int32_t i) {
497
- int32_t j = -1;
499
+ int64_t j = -1;
498
500
 
499
501
  try {
500
502
  if (logits == nullptr) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
517
519
  }
518
520
  if (j >= n_outputs) {
519
521
  // This should not happen
520
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
522
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
521
523
  }
522
524
 
523
525
  return logits + j*model.vocab.n_tokens();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
536
538
  }
537
539
 
538
540
  float * llama_context::get_embeddings_ith(int32_t i) {
539
- int32_t j = -1;
541
+ int64_t j = -1;
540
542
 
541
543
  try {
542
544
  if (embd == nullptr) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
559
561
  }
560
562
  if (j >= n_outputs) {
561
563
  // This should not happen
562
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
564
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
563
565
  }
564
566
 
565
567
  return embd + j*model.hparams.n_embd;
@@ -676,52 +678,84 @@ bool llama_context::apply_adapter_cvec(
676
678
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
677
679
  }
678
680
 
679
- int llama_context::encode(llama_batch & inp_batch) {
680
- if (inp_batch.n_tokens == 0) {
681
+ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
682
+ if (mstate && !mstate->apply()) {
683
+ LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
684
+ ret = GGML_STATUS_FAILED;
685
+ return nullptr;
686
+ }
687
+
688
+ auto * gf = graph_init();
689
+ if (!gf) {
690
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
691
+ ret = GGML_STATUS_FAILED;
692
+ return nullptr;
693
+ }
694
+
695
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
696
+ if (!res) {
697
+ LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
698
+ ret = GGML_STATUS_FAILED;
699
+ return nullptr;
700
+ }
701
+
702
+ // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
703
+
704
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
705
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
706
+ ret = GGML_STATUS_ALLOC_FAILED;
707
+ return nullptr;
708
+ }
709
+
710
+ res->set_inputs(&ubatch);
711
+
712
+ const auto status = graph_compute(gf, ubatch.n_tokens > 1);
713
+ if (status != GGML_STATUS_SUCCESS) {
714
+ LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
715
+ ret = status;
716
+ return nullptr;
717
+ }
718
+
719
+ ret = GGML_STATUS_SUCCESS;
720
+
721
+ return res;
722
+ }
723
+
724
+ int llama_context::encode(const llama_batch & batch_inp) {
725
+ if (batch_inp.n_tokens == 0) {
681
726
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
682
727
  return -1;
683
728
  }
684
729
 
685
- // temporary allocate memory for the input batch if needed
686
730
  // note: during encode, we always pass the full sequence starting from pos = 0
687
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
731
+ if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
732
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
733
+ return -1;
734
+ }
688
735
 
689
- const llama_batch & batch = batch_allocr.batch;
690
- const int32_t n_tokens = batch.n_tokens;
736
+ const llama_batch & batch = batch_allocr->get_batch();
691
737
 
692
- const auto & hparams = model.hparams;
738
+ const uint32_t n_tokens = batch.n_tokens;
693
739
 
694
740
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
695
741
 
696
- // TODO: move the validation to the llama_batch_allocr
697
- if (batch.token) {
698
- for (int32_t i = 0; i < n_tokens; ++i) {
699
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
700
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
701
- return -1;
702
- }
703
-
704
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
705
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
706
- throw -1;
707
- }
708
- }
709
- }
710
-
711
742
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
712
- GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
743
+ GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
713
744
 
714
745
  if (t_compute_start_us == 0) {
715
746
  t_compute_start_us = ggml_time_us();
716
747
  }
717
748
 
749
+ // TODO: this clear of the buffer can easily be forgotten - need something better
718
750
  embd_seq.clear();
719
751
 
720
752
  n_queued_tokens += n_tokens;
721
753
 
754
+ const auto & hparams = model.hparams;
755
+
722
756
  const int64_t n_embd = hparams.n_embd;
723
757
 
724
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
758
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
725
759
 
726
760
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
727
761
 
@@ -731,14 +765,12 @@ int llama_context::encode(llama_batch & inp_batch) {
731
765
  return -2;
732
766
  };
733
767
 
734
- for (int32_t i = 0; i < n_tokens; ++i) {
768
+ for (uint32_t i = 0; i < n_tokens; ++i) {
735
769
  output_ids[i] = i;
736
770
  }
737
771
 
738
772
  n_outputs = n_tokens;
739
773
 
740
- //batch_manager->prepare(ubatch);
741
-
742
774
  ggml_backend_sched_reset(sched.get());
743
775
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
744
776
 
@@ -749,26 +781,18 @@ int llama_context::encode(llama_batch & inp_batch) {
749
781
  // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
750
782
  cparams.causal_attn = false;
751
783
 
752
- auto * gf = graph_init();
753
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
754
-
755
- ggml_backend_sched_alloc_graph(sched.get(), gf);
756
-
757
- res->set_inputs(&ubatch);
784
+ ggml_status status;
785
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
758
786
 
759
787
  cparams.causal_attn = causal_attn_org;
760
788
 
761
- const auto compute_status = graph_compute(gf, n_tokens > 1);
762
- switch (compute_status) {
763
- case GGML_STATUS_SUCCESS:
764
- break;
765
- case GGML_STATUS_ABORTED:
766
- return 2;
767
- case GGML_STATUS_ALLOC_FAILED:
768
- return -2;
769
- case GGML_STATUS_FAILED:
770
- default:
771
- return -3;
789
+ if (!res) {
790
+ switch (status) {
791
+ case GGML_STATUS_ABORTED: return 2;
792
+ case GGML_STATUS_ALLOC_FAILED: return -2;
793
+ case GGML_STATUS_FAILED: return -3;
794
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
795
+ }
772
796
  }
773
797
 
774
798
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -797,7 +821,8 @@ int llama_context::encode(llama_batch & inp_batch) {
797
821
 
798
822
  GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
799
823
 
800
- for (int32_t i = 0; i < n_tokens; i++) {
824
+ // TODO: fix indexing [UBATCH_IDX]
825
+ for (uint32_t i = 0; i < n_tokens; i++) {
801
826
  const llama_seq_id seq_id = ubatch.seq_id[i][0];
802
827
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
803
828
  continue;
@@ -808,16 +833,18 @@ int llama_context::encode(llama_batch & inp_batch) {
808
833
  } break;
809
834
  case LLAMA_POOLING_TYPE_RANK:
810
835
  {
811
- // extract the rerank score - a single float per sequence
836
+ // extract the rerank score - n_cls_out floats per sequence
812
837
  auto & embd_seq_out = embd_seq;
838
+ const uint32_t n_cls_out = hparams.n_cls_out;
813
839
 
840
+ // TODO: fix indexing [UBATCH_IDX]
814
841
  for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
815
842
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
816
843
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
817
844
  continue;
818
845
  }
819
- embd_seq_out[seq_id].resize(1);
820
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
846
+ embd_seq_out[seq_id].resize(n_cls_out);
847
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
821
848
  }
822
849
  } break;
823
850
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -844,10 +871,10 @@ int llama_context::encode(llama_batch & inp_batch) {
844
871
 
845
872
  // remember the sequence ids used during the encoding - needed for cross attention later
846
873
  cross.seq_ids_enc.resize(n_tokens);
847
- for (int32_t i = 0; i < n_tokens; i++) {
874
+ for (uint32_t i = 0; i < n_tokens; i++) {
848
875
  cross.seq_ids_enc[i].clear();
849
- for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
850
- llama_seq_id seq_id = ubatch.seq_id[i][s];
876
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
877
+ llama_seq_id seq_id = batch.seq_id[i][s];
851
878
  cross.seq_ids_enc[i].insert(seq_id);
852
879
  }
853
880
  }
@@ -856,55 +883,45 @@ int llama_context::encode(llama_batch & inp_batch) {
856
883
  return 0;
857
884
  }
858
885
 
859
- int llama_context::decode(llama_batch & inp_batch) {
886
+ int llama_context::decode(const llama_batch & batch_inp) {
860
887
  if (!memory) {
861
888
  LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
862
- return encode(inp_batch);
889
+ return encode(batch_inp);
863
890
  }
864
891
 
865
- if (inp_batch.n_tokens == 0) {
892
+ if (batch_inp.n_tokens == 0) {
866
893
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
867
894
  return -1;
868
895
  }
869
896
 
870
- if (!inp_batch.pos) {
871
- if (inp_batch.seq_id) {
872
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
873
- return -1;
874
- }
875
- }
876
-
877
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
897
+ // when computing embeddings, all tokens are output
898
+ const bool embd_all = cparams.embeddings;
878
899
 
879
- // temporary allocate memory for the input batch if needed
880
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
900
+ if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
901
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
902
+ return -1;
903
+ }
881
904
 
882
- const llama_batch & batch = batch_allocr.batch;
905
+ const llama_batch & batch = batch_allocr->get_batch();
883
906
 
884
907
  const auto & vocab = model.vocab;
885
908
  const auto & hparams = model.hparams;
886
909
 
887
910
  const int32_t n_vocab = vocab.n_tokens();
911
+ const int64_t n_embd = hparams.n_embd;
888
912
 
889
- const int64_t n_tokens_all = batch.n_tokens;
890
- const int64_t n_embd = hparams.n_embd;
891
-
892
- llama_kv_cache_guard kv_guard(kv_self);
913
+ const uint32_t n_tokens_all = batch.n_tokens;
893
914
 
894
915
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
895
916
 
896
- // TODO: move the validation to the llama_batch_allocr
897
- if (batch.token) {
898
- for (int64_t i = 0; i < n_tokens_all; ++i) {
899
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
900
- LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
901
- return -1;
902
- }
917
+ const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
903
918
 
904
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
905
- LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
906
- return -1;
907
- }
919
+ if (embd_all) {
920
+ // require that all tokens are output
921
+ if (n_outputs_all != n_tokens_all) {
922
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
923
+ __func__, n_outputs_all, n_tokens_all);
924
+ return -1;
908
925
  }
909
926
  }
910
927
 
@@ -917,42 +934,71 @@ int llama_context::decode(llama_batch & inp_batch) {
917
934
  }
918
935
  n_queued_tokens += n_tokens_all;
919
936
 
920
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
921
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
922
-
937
+ // TODO: this clear of the buffer can easily be forgotten - need something better
923
938
  embd_seq.clear();
924
939
 
925
- int64_t n_outputs_all = 0;
940
+ bool did_optimize = false;
941
+
942
+ // handle any pending defrags/shifts
943
+ kv_self_update(false);
944
+
945
+ llama_memory_state_ptr mstate;
926
946
 
927
- // count outputs
928
- if (batch.logits && !embd_pooled) {
929
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
930
- n_outputs_all += batch.logits[i] != 0;
947
+ while (true) {
948
+ mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
949
+ if (!mstate) {
950
+ return -2;
951
+ }
952
+
953
+ switch (mstate->get_status()) {
954
+ case LLAMA_MEMORY_STATUS_SUCCESS:
955
+ {
956
+ } break;
957
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
958
+ {
959
+ LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
960
+
961
+ return -2;
962
+ }
963
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
964
+ {
965
+ if (!did_optimize) {
966
+ did_optimize = true;
967
+
968
+ if (kv_self_update(true)) {
969
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
970
+
971
+ continue;
972
+ }
973
+ }
974
+
975
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
976
+
977
+ return 1;
978
+ }
979
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
980
+ {
981
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
982
+
983
+ return -2;
984
+ }
931
985
  }
932
- } else if (embd_pooled) {
933
- n_outputs_all = n_tokens_all;
934
- } else {
935
- // keep last output only
936
- n_outputs_all = 1;
937
- }
938
986
 
939
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
987
+ break;
988
+ }
940
989
 
941
990
  // reserve output buffer
942
991
  if (output_reserve(n_outputs_all) < n_outputs_all) {
943
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
992
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
944
993
  return -2;
945
994
  };
946
995
 
947
- // handle any pending defrags/shifts
948
- kv_self_update();
949
-
950
996
  int64_t n_outputs_prev = 0;
951
997
 
952
- while (sbatch.n_tokens > 0) {
953
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
998
+ do {
999
+ const auto & ubatch = mstate->get_ubatch();
954
1000
 
955
- // count the outputs in this u_batch
1001
+ // count the outputs in this ubatch
956
1002
  {
957
1003
  int32_t n_outputs_new = 0;
958
1004
 
@@ -969,33 +1015,41 @@ int llama_context::decode(llama_batch & inp_batch) {
969
1015
  n_outputs = n_outputs_new;
970
1016
  }
971
1017
 
972
- // find KV slot
973
- if (!kv_self->find_slot(ubatch)) {
974
- return 1;
975
- }
976
-
977
1018
  ggml_backend_sched_reset(sched.get());
978
1019
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
979
1020
 
980
- auto * gf = graph_init();
981
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
1021
+ ggml_status status;
1022
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
982
1023
 
983
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1024
+ if (!res) {
1025
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1026
+ llama_pos pos_min[LLAMA_MAX_SEQ];
1027
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1028
+ pos_min[s] = std::numeric_limits<llama_pos>::max();
1029
+ }
984
1030
 
985
- ggml_backend_sched_alloc_graph(sched.get(), gf);
1031
+ // TODO: fix sequence indexing
1032
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1033
+ const auto & seq_id = ubatch.seq_id[i][0];
986
1034
 
987
- res->set_inputs(&ubatch);
1035
+ pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
1036
+ }
988
1037
 
989
- const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
990
- if (compute_status != GGML_STATUS_SUCCESS) {
991
- switch (compute_status) {
992
- case GGML_STATUS_ABORTED:
993
- return 2;
994
- case GGML_STATUS_ALLOC_FAILED:
995
- return -2;
996
- case GGML_STATUS_FAILED:
997
- default:
998
- return -3;
1038
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1039
+ if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
1040
+ continue;
1041
+ }
1042
+
1043
+ LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1044
+
1045
+ memory->seq_rm(s, pos_min[s], -1);
1046
+ }
1047
+
1048
+ switch (status) {
1049
+ case GGML_STATUS_ABORTED: return 2;
1050
+ case GGML_STATUS_ALLOC_FAILED: return -2;
1051
+ case GGML_STATUS_FAILED: return -3;
1052
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
999
1053
  }
1000
1054
  }
1001
1055
 
@@ -1004,7 +1058,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1004
1058
  // ggml_graph_dump_dot(gf, NULL, "llama.dot");
1005
1059
  //}
1006
1060
 
1007
- auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
1061
+ auto * t_logits = res->get_logits();
1008
1062
  auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1009
1063
 
1010
1064
  if (t_embd && res->get_embd_pooled()) {
@@ -1082,23 +1136,20 @@ int llama_context::decode(llama_batch & inp_batch) {
1082
1136
  }
1083
1137
 
1084
1138
  n_outputs_prev += n_outputs;
1085
- }
1086
-
1087
- // finalize the batch processing
1088
- kv_guard.commit();
1139
+ } while (mstate->next());
1089
1140
 
1090
1141
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1091
1142
  n_outputs = n_outputs_all;
1092
1143
 
1093
1144
  // set output mappings
1094
- {
1145
+ if (n_outputs > 0) {
1095
1146
  bool sorted_output = true;
1096
1147
 
1097
- auto & out_ids = sbatch.out_ids;
1148
+ auto & out_ids = mstate->out_ids();
1098
1149
 
1099
- GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1150
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1100
1151
 
1101
- for (int64_t i = 0; i < n_outputs_all; ++i) {
1152
+ for (int64_t i = 0; i < n_outputs; ++i) {
1102
1153
  int64_t out_id = out_ids[i];
1103
1154
  output_ids[out_id] = i;
1104
1155
  if (out_id != i) {
@@ -1110,20 +1161,22 @@ int llama_context::decode(llama_batch & inp_batch) {
1110
1161
  // note: this is mostly relevant for recurrent models atm
1111
1162
  if (!sorted_output) {
1112
1163
  const uint32_t n_vocab = model.vocab.n_tokens();
1113
- const uint32_t n_embd = model.hparams.n_embd;
1164
+ const uint64_t n_embd = model.hparams.n_embd;
1114
1165
 
1115
1166
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1116
1167
 
1117
1168
  // TODO: is there something more efficient which also minimizes swaps?
1118
1169
  // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1119
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
1120
- int32_t j_min = i;
1121
- for (int32_t j = i + 1; j < n_outputs; ++j) {
1170
+ for (uint32_t i = 0; i < n_outputs - 1; ++i) {
1171
+ uint32_t j_min = i;
1172
+ for (uint32_t j = i + 1; j < n_outputs; ++j) {
1122
1173
  if (out_ids[j] < out_ids[j_min]) {
1123
1174
  j_min = j;
1124
1175
  }
1125
1176
  }
1126
- if (j_min == i) { continue; }
1177
+ if (j_min == i) {
1178
+ continue;
1179
+ }
1127
1180
  std::swap(out_ids[i], out_ids[j_min]);
1128
1181
  if (logits_size > 0) {
1129
1182
  for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1136,8 +1189,10 @@ int llama_context::decode(llama_batch & inp_batch) {
1136
1189
  }
1137
1190
  }
1138
1191
  }
1192
+
1139
1193
  std::fill(output_ids.begin(), output_ids.end(), -1);
1140
- for (int32_t i = 0; i < n_outputs; ++i) {
1194
+
1195
+ for (uint32_t i = 0; i < n_outputs; ++i) {
1141
1196
  output_ids[out_ids[i]] = i;
1142
1197
  }
1143
1198
  }
@@ -1146,11 +1201,6 @@ int llama_context::decode(llama_batch & inp_batch) {
1146
1201
  // wait for the computation to finish (automatically done when obtaining the model output)
1147
1202
  //synchronize();
1148
1203
 
1149
- // decide if we need to defrag the kv cache
1150
- if (cparams.defrag_thold > 0.0f) {
1151
- kv_self->defrag_sched(cparams.defrag_thold);
1152
- }
1153
-
1154
1204
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1155
1205
  // overlap with device computation.
1156
1206
  ggml_backend_sched_reset(sched.get());
@@ -1162,7 +1212,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1162
1212
  // output
1163
1213
  //
1164
1214
 
1165
- int32_t llama_context::output_reserve(int32_t n_outputs) {
1215
+ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1166
1216
  const auto & hparams = model.hparams;
1167
1217
  const auto & vocab = model.vocab;
1168
1218
 
@@ -1172,9 +1222,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1172
1222
  const auto n_vocab = vocab.n_tokens();
1173
1223
  const auto n_embd = hparams.n_embd;
1174
1224
 
1175
- // TODO: use a per-batch flag for logits presence instead
1176
- bool has_logits = !cparams.embeddings;
1177
- bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1225
+ bool has_logits = true;
1226
+ bool has_embd = cparams.embeddings;
1178
1227
 
1179
1228
  // TODO: hacky enc-dec support
1180
1229
  if (model.arch == LLM_ARCH_T5) {
@@ -1228,8 +1277,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1228
1277
  // set all ids as invalid (negative)
1229
1278
  std::fill(output_ids.begin(), output_ids.end(), -1);
1230
1279
 
1231
- this->n_outputs = 0;
1232
- this->n_outputs_max = n_outputs_max;
1280
+ this->n_outputs = 0;
1233
1281
 
1234
1282
  return n_outputs_max;
1235
1283
  }
@@ -1254,11 +1302,52 @@ ggml_cgraph * llama_context::graph_init() {
1254
1302
  return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1255
1303
  }
1256
1304
 
1305
+ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
1306
+ LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1307
+
1308
+ if (n_tokens % n_seqs != 0) {
1309
+ n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1310
+ n_outputs = std::min(n_outputs, n_tokens);
1311
+
1312
+ LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1313
+ }
1314
+
1315
+ // store the n_outputs as it is, and restore it afterwards
1316
+ // TODO: not sure if needed, might simplify in the future by removing this
1317
+ const auto save_n_outputs = this->n_outputs;
1318
+
1319
+ this->n_outputs = n_outputs;
1320
+
1321
+ llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
1322
+ llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
1323
+
1324
+ auto * gf = graph_init();
1325
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
1326
+
1327
+ this->n_outputs = save_n_outputs;
1328
+
1329
+ if (!res) {
1330
+ LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
1331
+ return nullptr;
1332
+ }
1333
+
1334
+ ggml_backend_sched_reset(sched.get());
1335
+
1336
+ // initialize scheduler with the specified graph
1337
+ if (!ggml_backend_sched_reserve(sched.get(), gf)) {
1338
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
1339
+ return nullptr;
1340
+ }
1341
+
1342
+ return gf;
1343
+ }
1344
+
1257
1345
  llm_graph_result_ptr llama_context::graph_build(
1258
- ggml_context * ctx,
1259
- ggml_cgraph * gf,
1260
- const llama_ubatch & ubatch,
1261
- llm_graph_type gtype) {
1346
+ ggml_context * ctx,
1347
+ ggml_cgraph * gf,
1348
+ const llama_ubatch & ubatch,
1349
+ llm_graph_type gtype,
1350
+ const llama_memory_state_i * mstate) {
1262
1351
  return model.build_graph(
1263
1352
  {
1264
1353
  /*.ctx =*/ ctx,
@@ -1270,7 +1359,7 @@ llm_graph_result_ptr llama_context::graph_build(
1270
1359
  /*.backend_cpu =*/ backend_cpu,
1271
1360
  /*.cvec =*/ &cvec,
1272
1361
  /*.loras =*/ &loras,
1273
- /*.memory =*/ memory.get(),
1362
+ /*.mstate =*/ mstate,
1274
1363
  /*.cross =*/ &cross,
1275
1364
  /*.n_outputs =*/ n_outputs,
1276
1365
  /*.cb =*/ graph_get_cb(),
@@ -1679,14 +1768,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1679
1768
 
1680
1769
  std::vector<int32_t> w_output_pos;
1681
1770
 
1682
- GGML_ASSERT(n_outputs <= n_outputs_max);
1683
-
1684
1771
  w_output_pos.resize(n_outputs);
1685
1772
 
1686
1773
  // build a more compact representation of the output ids
1687
1774
  for (size_t i = 0; i < n_batch(); ++i) {
1688
1775
  // map an output id to a position in the batch
1689
- int32_t pos = output_ids[i];
1776
+ int64_t pos = output_ids[i];
1690
1777
  if (pos >= 0) {
1691
1778
  GGML_ASSERT(pos < n_outputs);
1692
1779
  w_output_pos[pos] = i;
@@ -1726,11 +1813,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1726
1813
  }
1727
1814
  }
1728
1815
 
1729
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1730
-
1731
- if (kv_self != nullptr) {
1816
+ if (memory != nullptr) {
1732
1817
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1733
- kv_self->state_write(io);
1818
+ memory->state_write(io);
1734
1819
  }
1735
1820
 
1736
1821
  return io.n_bytes();
@@ -1817,9 +1902,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1817
1902
  if (memory) {
1818
1903
  LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
1819
1904
 
1820
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1821
-
1822
- kv_self->state_read(io);
1905
+ memory->state_read(io);
1823
1906
  }
1824
1907
 
1825
1908
  return io.n_bytes();
@@ -1829,9 +1912,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
1829
1912
  GGML_UNUSED(seq_id);
1830
1913
 
1831
1914
  if (memory) {
1832
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1833
-
1834
- kv_self->state_write(io, seq_id);
1915
+ memory->state_write(io, seq_id);
1835
1916
  }
1836
1917
 
1837
1918
  return io.n_bytes();
@@ -1841,9 +1922,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
1841
1922
  GGML_UNUSED(seq_id);
1842
1923
 
1843
1924
  if (memory) {
1844
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1845
-
1846
- kv_self->state_read(io, seq_id);
1925
+ memory->state_read(io, seq_id);
1847
1926
  }
1848
1927
 
1849
1928
  return io.n_bytes();
@@ -1948,10 +2027,7 @@ void llama_context::opt_epoch_iter(
1948
2027
  const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
1949
2028
  const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
1950
2029
 
1951
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1952
-
1953
- kv_self->clear();
1954
- llama_kv_cache_guard kv_guard(kv_self);
2030
+ memory->clear(true);
1955
2031
 
1956
2032
  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
1957
2033
  batch.n_tokens = n_batch;
@@ -1967,35 +2043,35 @@ void llama_context::opt_epoch_iter(
1967
2043
 
1968
2044
  n_queued_tokens += n_tokens_all;
1969
2045
 
1970
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
1971
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
1972
-
1973
2046
  embd_seq.clear();
1974
2047
 
1975
- int64_t n_outputs_all = n_tokens_all;
2048
+ uint32_t n_outputs_all = n_tokens_all;
1976
2049
 
1977
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
2050
+ auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
2051
+ if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2052
+ LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2053
+ break;
2054
+ }
1978
2055
 
1979
2056
  // reserve output buffer
1980
2057
  if (output_reserve(n_outputs_all) < n_outputs_all) {
1981
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
2058
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
1982
2059
  GGML_ABORT("TODO: handle this error");
1983
2060
  };
1984
2061
 
1985
- for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
1986
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
2062
+ uint32_t pos_batch = 0;
2063
+ do {
2064
+ const auto & ubatch = mstate->get_ubatch();
1987
2065
 
1988
2066
  n_outputs = ubatch.n_tokens;
1989
2067
 
1990
- // TODO: not sure if this is needed
1991
- if (!kv_self->find_slot(ubatch)) {
1992
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
1993
-
1994
- GGML_ABORT("TODO: handle this error");
2068
+ if (!mstate->apply()) {
2069
+ LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
2070
+ break;
1995
2071
  }
1996
2072
 
1997
2073
  auto * gf = graph_init();
1998
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
2074
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
1999
2075
 
2000
2076
  struct ggml_context * ctx_compute_opt;
2001
2077
  {
@@ -2010,6 +2086,7 @@ void llama_context::opt_epoch_iter(
2010
2086
  }
2011
2087
  ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
2012
2088
  ggml_opt_alloc(opt_ctx, train);
2089
+
2013
2090
  res->set_inputs(&ubatch);
2014
2091
  {
2015
2092
  struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2027,10 +2104,10 @@ void llama_context::opt_epoch_iter(
2027
2104
  callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
2028
2105
  }
2029
2106
  ggml_free(ctx_compute_opt);
2030
- }
2031
- }
2032
2107
 
2033
- kv_guard.commit();
2108
+ pos_batch += ubatch.n_tokens;
2109
+ } while (mstate->next());
2110
+ }
2034
2111
  }
2035
2112
 
2036
2113
  void llama_context::opt_epoch(
@@ -2190,12 +2267,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
2190
2267
  return &ctx->get_model();
2191
2268
  }
2192
2269
 
2270
+ // deprecated
2193
2271
  llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2194
- return ctx->get_kv_self();
2272
+ return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
2195
2273
  }
2196
2274
 
2275
+ // deprecated
2197
2276
  void llama_kv_self_update(llama_context * ctx) {
2198
- ctx->kv_self_update();
2277
+ ctx->kv_self_update(false);
2199
2278
  }
2200
2279
 
2201
2280
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2310,13 +2389,118 @@ int32_t llama_apply_adapter_cvec(
2310
2389
  return res ? 0 : -1;
2311
2390
  }
2312
2391
 
2392
+ //
2393
+ // memory
2394
+ //
2395
+
2396
+ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
2397
+ return ctx->get_memory();
2398
+ }
2399
+
2400
+ void llama_memory_clear(llama_memory_t mem, bool data) {
2401
+ if (!mem) {
2402
+ return;
2403
+ }
2404
+
2405
+ mem->clear(data);
2406
+ }
2407
+
2408
+ bool llama_memory_seq_rm(
2409
+ llama_memory_t mem,
2410
+ llama_seq_id seq_id,
2411
+ llama_pos p0,
2412
+ llama_pos p1) {
2413
+ if (!mem) {
2414
+ return true;
2415
+ }
2416
+
2417
+ return mem->seq_rm(seq_id, p0, p1);
2418
+ }
2419
+
2420
+ void llama_memory_seq_cp(
2421
+ llama_memory_t mem,
2422
+ llama_seq_id seq_id_src,
2423
+ llama_seq_id seq_id_dst,
2424
+ llama_pos p0,
2425
+ llama_pos p1) {
2426
+ if (!mem) {
2427
+ return;
2428
+ }
2429
+
2430
+ mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2431
+ }
2432
+
2433
+ void llama_memory_seq_keep(
2434
+ llama_memory_t mem,
2435
+ llama_seq_id seq_id) {
2436
+ if (!mem) {
2437
+ return;
2438
+ }
2439
+
2440
+ mem->seq_keep(seq_id);
2441
+ }
2442
+
2443
+ void llama_memory_seq_add(
2444
+ llama_memory_t mem,
2445
+ llama_seq_id seq_id,
2446
+ llama_pos p0,
2447
+ llama_pos p1,
2448
+ llama_pos delta) {
2449
+ if (!mem) {
2450
+ return;
2451
+ }
2452
+
2453
+ mem->seq_add(seq_id, p0, p1, delta);
2454
+ }
2455
+
2456
+ void llama_memory_seq_div(
2457
+ llama_memory_t mem,
2458
+ llama_seq_id seq_id,
2459
+ llama_pos p0,
2460
+ llama_pos p1,
2461
+ int d) {
2462
+ if (!mem) {
2463
+ return;
2464
+ }
2465
+
2466
+ mem->seq_div(seq_id, p0, p1, d);
2467
+ }
2468
+
2469
+ llama_pos llama_memory_seq_pos_min(
2470
+ llama_memory_t mem,
2471
+ llama_seq_id seq_id) {
2472
+ if (!mem) {
2473
+ return -1;
2474
+ }
2475
+
2476
+ return mem->seq_pos_min(seq_id);
2477
+ }
2478
+
2479
+ llama_pos llama_memory_seq_pos_max(
2480
+ llama_memory_t mem,
2481
+ llama_seq_id seq_id) {
2482
+ if (!mem) {
2483
+ return -1;
2484
+ }
2485
+
2486
+ return mem->seq_pos_max(seq_id);
2487
+ }
2488
+
2489
+ bool llama_memory_can_shift(llama_memory_t mem) {
2490
+ if (!mem) {
2491
+ return false;
2492
+ }
2493
+
2494
+ return mem->get_can_shift();
2495
+ }
2496
+
2313
2497
  //
2314
2498
  // kv cache
2315
2499
  //
2316
2500
 
2317
2501
  // deprecated
2318
2502
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2319
- const auto * kv = ctx->get_kv_self();
2503
+ const auto * kv = llama_get_memory(ctx);
2320
2504
  if (!kv) {
2321
2505
  return 0;
2322
2506
  }
@@ -2338,7 +2522,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2338
2522
  // deprecated
2339
2523
  // note: this is the same as above - will be removed anyway, so it's ok
2340
2524
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2341
- const auto * kv = ctx->get_kv_self();
2525
+ const auto * kv = llama_get_memory(ctx);
2342
2526
  if (!kv) {
2343
2527
  return 0;
2344
2528
  }
@@ -2357,114 +2541,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2357
2541
  return res;
2358
2542
  }
2359
2543
 
2544
+ // deprecated
2360
2545
  void llama_kv_self_clear(llama_context * ctx) {
2361
- auto * kv = ctx->get_kv_self();
2546
+ auto * kv = llama_get_memory(ctx);
2362
2547
  if (!kv) {
2363
2548
  return;
2364
2549
  }
2365
2550
 
2366
- kv->clear();
2551
+ llama_memory_clear(kv, true);
2367
2552
  }
2368
2553
 
2554
+ // deprecated
2369
2555
  bool llama_kv_self_seq_rm(
2370
2556
  llama_context * ctx,
2371
2557
  llama_seq_id seq_id,
2372
2558
  llama_pos p0,
2373
2559
  llama_pos p1) {
2374
- auto * kv = ctx->get_kv_self();
2560
+ auto * kv = llama_get_memory(ctx);
2375
2561
  if (!kv) {
2376
2562
  return true;
2377
2563
  }
2378
2564
 
2379
- return kv->seq_rm(seq_id, p0, p1);
2565
+ return llama_memory_seq_rm(kv, seq_id, p0, p1);
2380
2566
  }
2381
2567
 
2568
+ // deprecated
2382
2569
  void llama_kv_self_seq_cp(
2383
2570
  llama_context * ctx,
2384
2571
  llama_seq_id seq_id_src,
2385
2572
  llama_seq_id seq_id_dst,
2386
2573
  llama_pos p0,
2387
2574
  llama_pos p1) {
2388
- auto * kv = ctx->get_kv_self();
2575
+ auto * kv = llama_get_memory(ctx);
2389
2576
  if (!kv) {
2390
2577
  return;
2391
2578
  }
2392
2579
 
2393
- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2580
+ llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
2394
2581
  }
2395
2582
 
2583
+ // deprecated
2396
2584
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2397
- auto * kv = ctx->get_kv_self();
2585
+ auto * kv = llama_get_memory(ctx);
2398
2586
  if (!kv) {
2399
2587
  return;
2400
2588
  }
2401
2589
 
2402
- kv->seq_keep(seq_id);
2590
+ llama_memory_seq_keep(kv, seq_id);
2403
2591
  }
2404
2592
 
2593
+ // deprecated
2405
2594
  void llama_kv_self_seq_add(
2406
2595
  llama_context * ctx,
2407
2596
  llama_seq_id seq_id,
2408
2597
  llama_pos p0,
2409
2598
  llama_pos p1,
2410
2599
  llama_pos delta) {
2411
- auto * kv = ctx->get_kv_self();
2600
+ auto * kv = llama_get_memory(ctx);
2412
2601
  if (!kv) {
2413
2602
  return;
2414
2603
  }
2415
2604
 
2416
- kv->seq_add(seq_id, p0, p1, delta);
2605
+ llama_memory_seq_add(kv, seq_id, p0, p1, delta);
2417
2606
  }
2418
2607
 
2608
+ // deprecated
2419
2609
  void llama_kv_self_seq_div(
2420
2610
  llama_context * ctx,
2421
2611
  llama_seq_id seq_id,
2422
2612
  llama_pos p0,
2423
2613
  llama_pos p1,
2424
2614
  int d) {
2425
- auto * kv = ctx->get_kv_self();
2615
+ auto * kv = llama_get_memory(ctx);
2426
2616
  if (!kv) {
2427
2617
  return;
2428
2618
  }
2429
2619
 
2430
- kv->seq_div(seq_id, p0, p1, d);
2620
+ llama_memory_seq_div(kv, seq_id, p0, p1, d);
2431
2621
  }
2432
2622
 
2623
+ // deprecated
2433
2624
  llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2434
- const auto * kv = ctx->get_kv_self();
2625
+ auto * kv = llama_get_memory(ctx);
2435
2626
  if (!kv) {
2436
2627
  return -1;
2437
2628
  }
2438
2629
 
2439
- return kv->seq_pos_min(seq_id);
2630
+ return llama_memory_seq_pos_min(kv, seq_id);
2440
2631
  }
2441
2632
 
2633
+ // deprecated
2442
2634
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2443
- const auto * kv = ctx->get_kv_self();
2635
+ auto * kv = llama_get_memory(ctx);
2444
2636
  if (!kv) {
2445
2637
  return -1;
2446
2638
  }
2447
2639
 
2448
- return kv->seq_pos_max(seq_id);
2640
+ return llama_memory_seq_pos_max(kv, seq_id);
2449
2641
  }
2450
2642
 
2643
+ // deprecated
2451
2644
  void llama_kv_self_defrag(llama_context * ctx) {
2452
- auto * kv = ctx->get_kv_self();
2453
- if (!kv) {
2454
- return;
2455
- }
2456
-
2457
2645
  // force defrag
2458
- kv->defrag_sched(-1.0f);
2646
+ ctx->kv_self_defrag_sched();
2459
2647
  }
2460
2648
 
2649
+ // deprecated
2461
2650
  bool llama_kv_self_can_shift(const llama_context * ctx) {
2462
- const auto * kv = ctx->get_kv_self();
2651
+ auto * kv = llama_get_memory(ctx);
2463
2652
  if (!kv) {
2464
2653
  return false;
2465
2654
  }
2466
2655
 
2467
- return kv->get_can_shift();
2656
+ return llama_memory_can_shift(kv);
2468
2657
  }
2469
2658
 
2470
2659
  // llama state API
@@ -2589,22 +2778,8 @@ int32_t llama_encode(
2589
2778
  int32_t llama_decode(
2590
2779
  llama_context * ctx,
2591
2780
  llama_batch batch) {
2592
- int ret = ctx->decode(batch);
2593
-
2594
- // defrag and try again
2595
- // TODO: distinguish return code when we are sure that even after defrag there is no space available
2596
- if (ret == 1) {
2597
- llama_kv_self_defrag(ctx);
2598
- ret = ctx->decode(batch);
2599
-
2600
- if (ret == 1) {
2601
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
2602
-
2603
- return ret;
2604
- }
2605
- }
2606
-
2607
- if (ret != 0) {
2781
+ const int ret = ctx->decode(batch);
2782
+ if (ret != 0 && ret != 1) {
2608
2783
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2609
2784
  }
2610
2785