@novastera-oss/llamarn 0.2.6 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (253) hide show
  1. package/android/src/main/cpp/include/llama.h +141 -38
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +58 -24
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +37 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +53 -40
  26. package/cpp/llama.cpp/common/common.h +6 -2
  27. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  28. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  29. package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
  30. package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
  31. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  32. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  33. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  34. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
  35. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  38. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  88. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  90. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  91. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
  93. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
  94. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
  97. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  105. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
  112. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
  113. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  115. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  117. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
  139. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  140. package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
  141. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
  142. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
  143. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  144. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  145. package/cpp/llama.cpp/include/llama.h +141 -38
  146. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  147. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  148. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  149. package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
  150. package/cpp/llama.cpp/src/llama-arch.h +25 -1
  151. package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
  152. package/cpp/llama.cpp/src/llama-batch.h +110 -57
  153. package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
  154. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  155. package/cpp/llama.cpp/src/llama-context.cpp +360 -266
  156. package/cpp/llama.cpp/src/llama-context.h +27 -23
  157. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  158. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  159. package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
  160. package/cpp/llama.cpp/src/llama-graph.h +126 -58
  161. package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
  162. package/cpp/llama.cpp/src/llama-hparams.h +16 -2
  163. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
  164. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
  165. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
  166. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
  167. package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
  168. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  169. package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
  170. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
  171. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
  172. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  173. package/cpp/llama.cpp/src/llama-memory.h +73 -36
  174. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  175. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  176. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  177. package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
  178. package/cpp/llama.cpp/src/llama-model.h +26 -0
  179. package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
  180. package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
  181. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  182. package/cpp/llama.cpp/src/llama.cpp +11 -7
  183. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  184. package/cpp/rn-completion.cpp +2 -2
  185. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  186. package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
  187. package/ios/include/chat.h +1 -1
  188. package/ios/include/common.h +6 -2
  189. package/ios/include/llama.h +141 -38
  190. package/ios/libs/llama.xcframework/Info.plist +15 -15
  191. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  192. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  193. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  194. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  195. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
  196. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  197. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  198. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  199. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  200. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  201. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  202. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  203. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  204. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  205. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  206. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
  207. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  208. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  209. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
  210. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  211. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  219. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  220. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  221. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  222. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  223. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
  224. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  225. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  226. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  227. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  228. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  231. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  232. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  233. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
  234. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  235. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  236. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
  237. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  238. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  239. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
  240. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
  241. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  242. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  243. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  244. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  245. package/package.json +1 -2
  246. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  247. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  248. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  249. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  250. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  251. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  252. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  253. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -1,6 +1,7 @@
1
1
  #include "llama-kv-cache-unified.h"
2
2
 
3
3
  #include "llama-impl.h"
4
+ #include "llama-io.h"
4
5
  #include "llama-model.h"
5
6
  #include "llama-context.h"
6
7
 
@@ -32,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
32
33
 
33
34
  GGML_ASSERT(kv_size % n_pad == 0);
34
35
 
36
+ // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
37
+ auto n_layer_cache = hparams.n_layer;
38
+ if (model.arch == LLM_ARCH_GEMMA3N) {
39
+ n_layer_cache = 20;
40
+ }
41
+
35
42
  // create a context for each buffer type
36
43
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
37
44
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
38
45
  auto it = ctx_map.find(buft);
39
46
  if (it == ctx_map.end()) {
40
47
  ggml_init_params params = {
41
- /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
48
+ /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
42
49
  /*.mem_buffer =*/ NULL,
43
50
  /*.no_alloc =*/ true,
44
51
  };
@@ -61,14 +68,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
61
68
 
62
69
  cells.resize(kv_size);
63
70
 
64
- for (uint32_t il = 0; il < hparams.n_layer; il++) {
71
+ for (uint32_t il = 0; il < n_layer_cache; il++) {
65
72
  if (filter && !filter(il)) {
66
73
  LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
67
74
  continue;
68
75
  }
69
76
 
70
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
71
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
77
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
78
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
72
79
 
73
80
  const char * dev_name = "CPU";
74
81
 
@@ -101,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
101
108
  layers.push_back({ il, k, v });
102
109
  }
103
110
 
111
+ // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
112
+ if (model.arch == LLM_ARCH_GEMMA3N) {
113
+ LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
114
+
115
+ for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
116
+ if (filter && !filter(il)) {
117
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
118
+ continue;
119
+ }
120
+
121
+ const bool is_swa = hparams.is_swa(il);
122
+ const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
123
+
124
+ GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
125
+ map_layer_ids[il] = map_layer_ids[il_reuse];
126
+
127
+ LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
128
+ }
129
+ }
130
+
104
131
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
105
132
  for (auto it : ctx_map) {
106
133
  auto * buft = it.first;
@@ -126,15 +153,20 @@ llama_kv_cache_unified::llama_kv_cache_unified(
126
153
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
127
154
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
128
155
  }
156
+
157
+ const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
158
+ debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
129
159
  }
130
160
 
131
- void llama_kv_cache_unified::clear() {
161
+ void llama_kv_cache_unified::clear(bool data) {
132
162
  cells.reset();
133
163
 
134
164
  head = 0;
135
165
 
136
- for (auto & buf : bufs) {
137
- ggml_backend_buffer_clear(buf.get(), 0);
166
+ if (data) {
167
+ for (auto & buf : bufs) {
168
+ ggml_backend_buffer_clear(buf.get(), 0);
169
+ }
138
170
  }
139
171
  }
140
172
 
@@ -149,12 +181,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
149
181
  p1 = std::numeric_limits<llama_pos>::max();
150
182
  }
151
183
 
152
- for (uint32_t i = 0; i < cells.size(); ++i) {
153
- if (!cells.pos_in(i, p0, p1)) {
154
- continue;
184
+ if (seq_id >= 0) {
185
+ for (uint32_t i = 0; i < cells.size(); ++i) {
186
+ if (!cells.pos_in(i, p0, p1)) {
187
+ continue;
188
+ }
189
+
190
+ if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
191
+ if (new_head == cells.size()) {
192
+ new_head = i;
193
+ }
194
+ }
155
195
  }
196
+ } else {
197
+ // match any sequence
198
+ for (uint32_t i = 0; i < cells.size(); ++i) {
199
+ if (!cells.pos_in(i, p0, p1)) {
200
+ continue;
201
+ }
202
+
203
+ cells.rm(i);
156
204
 
157
- if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
158
205
  if (new_head == cells.size()) {
159
206
  new_head = i;
160
207
  }
@@ -286,35 +333,77 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
286
333
  return cells.seq_pos_max(seq_id);
287
334
  }
288
335
 
289
- llama_memory_state_ptr llama_kv_cache_unified::init_batch(
290
- const llama_batch & batch,
336
+ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
337
+ llama_batch_allocr & balloc,
291
338
  uint32_t n_ubatch,
292
- bool embd_pooled,
293
- bool logits_all) {
294
- GGML_UNUSED(embd_pooled);
339
+ bool embd_all) {
340
+ GGML_UNUSED(embd_all);
295
341
 
296
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
342
+ do {
343
+ balloc.split_reset();
297
344
 
298
- std::vector<llama_ubatch> ubatches;
299
- while (sbatch.n_tokens > 0) {
300
- ubatches.push_back(sbatch.split_simple(n_ubatch));
301
- }
345
+ std::vector<llama_ubatch> ubatches;
346
+ while (true) {
347
+ auto ubatch = balloc.split_simple(n_ubatch);
302
348
 
303
- auto heads = prepare(ubatches);
304
- if (heads.empty()) {
305
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
306
- }
349
+ if (ubatch.n_tokens == 0) {
350
+ break;
351
+ }
352
+
353
+ ubatches.push_back(std::move(ubatch)); // NOLINT
354
+ }
355
+
356
+ auto heads = prepare(ubatches);
357
+ if (heads.empty()) {
358
+ break;
359
+ }
360
+
361
+ return std::make_unique<llama_kv_cache_unified_context>(
362
+ this, std::move(heads), std::move(ubatches));
363
+ } while (false);
307
364
 
308
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
309
- this, std::move(sbatch), std::move(heads), std::move(ubatches));
365
+ return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
310
366
  }
311
367
 
312
- llama_memory_state_ptr llama_kv_cache_unified::init_full() {
313
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
368
+ llama_memory_context_ptr llama_kv_cache_unified::init_full() {
369
+ return std::make_unique<llama_kv_cache_unified_context>(this);
314
370
  }
315
371
 
316
- std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
317
- std::vector<uint32_t> res;
372
+ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
373
+ bool do_shift = get_has_shift();
374
+
375
+ defrag_info dinfo;
376
+
377
+ // see if we need to defrag
378
+ {
379
+ bool do_defrag = optimize;
380
+
381
+ const auto thold = lctx->get_cparams().defrag_thold;
382
+
383
+ if (!do_defrag && thold > 0.0f) {
384
+ const auto n_kv = cells.used_max_p1();
385
+
386
+ // - do not defrag small contexts (i.e. < 2048 tokens)
387
+ // - count the padding towards the number of used tokens
388
+ const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
389
+
390
+ if (fragmentation > thold) {
391
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
392
+
393
+ do_defrag = true;
394
+ }
395
+ }
396
+
397
+ if (do_defrag) {
398
+ dinfo = defrag_prepare(lctx->graph_max_nodes());
399
+ }
400
+ }
401
+
402
+ return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
403
+ }
404
+
405
+ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
406
+ llama_kv_cache_unified::ubatch_heads res;
318
407
 
319
408
  struct state {
320
409
  uint32_t head_old; // old position of the head, before placing the ubatch
@@ -359,12 +448,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
359
448
  return res;
360
449
  }
361
450
 
362
- bool llama_kv_cache_unified::update(llama_context & lctx) {
451
+ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
363
452
  bool updated = false;
364
453
 
365
- auto * sched = lctx.get_sched();
454
+ auto * sched = lctx->get_sched();
366
455
 
367
- if (cells.get_has_shift()) {
456
+ if (do_shift) {
368
457
  if (!get_can_shift()) {
369
458
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
370
459
  }
@@ -375,9 +464,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
375
464
  if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
376
465
  ggml_backend_sched_reset(sched);
377
466
 
378
- auto * gf = lctx.graph_init();
467
+ auto * gf = lctx->graph_init();
379
468
 
380
- auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
469
+ auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
381
470
  if (!res) {
382
471
  LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
383
472
  return updated;
@@ -390,7 +479,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
390
479
 
391
480
  res->set_inputs(nullptr);
392
481
 
393
- if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
482
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
394
483
  LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
395
484
  return updated;
396
485
  }
@@ -401,54 +490,53 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
401
490
  cells.reset_shift();
402
491
  }
403
492
 
404
- if (do_defrag) {
493
+ if (!dinfo.empty()) {
405
494
  LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
406
495
 
407
- if (defrag_prepare(lctx.graph_max_nodes())) {
408
- ggml_backend_sched_reset(sched);
496
+ // apply moves:
497
+ {
498
+ const auto n_kv = dinfo.ids.size();
409
499
 
410
- auto * gf = lctx.graph_init();
500
+ for (uint32_t i = 0; i < n_kv; ++i) {
501
+ assert(dinfo.ids[i] <= n_kv);
411
502
 
412
- auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
413
- if (!res) {
414
- LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
415
- return updated;
416
- }
417
-
418
- if (!ggml_backend_sched_alloc_graph(sched, gf)) {
419
- LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
420
- return updated;
421
- }
422
-
423
- res->set_inputs(nullptr);
503
+ if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
504
+ continue;
505
+ }
424
506
 
425
- if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
426
- LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
427
- return updated;
507
+ cells.mv(i, dinfo.ids[i]);
428
508
  }
429
509
 
430
- updated = true;
510
+ // reset the head so we can find the first free slot during the next ubatch
511
+ head = 0;
431
512
  }
432
513
 
433
- do_defrag = false;
434
- }
514
+ ggml_backend_sched_reset(sched);
435
515
 
436
- return updated;
437
- }
516
+ auto * gf = lctx->graph_init();
438
517
 
439
- void llama_kv_cache_unified::defrag_sched(float thold) {
440
- const auto n_kv = cells.used_max_p1();
518
+ auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
519
+ if (!res) {
520
+ LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
521
+ return updated;
522
+ }
441
523
 
442
- // - do not defrag small contexts (i.e. < 2048 tokens)
443
- // - count the padding towards the number of used tokens
444
- const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
524
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
525
+ LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
526
+ return updated;
527
+ }
528
+
529
+ res->set_inputs(nullptr);
445
530
 
446
- // queue defragmentation for next llama_kv_cache_update
447
- if (fragmentation > thold) {
448
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
531
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
532
+ LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
533
+ return updated;
534
+ }
449
535
 
450
- do_defrag = true;
536
+ updated = true;
451
537
  }
538
+
539
+ return updated;
452
540
  }
453
541
 
454
542
  int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
@@ -462,43 +550,68 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
462
550
  head_cur = 0;
463
551
  }
464
552
 
465
- // otherwise, one cell per token.
466
-
467
553
  if (n_tokens > cells.size()) {
468
554
  LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
469
555
  return -1;
470
556
  }
471
557
 
472
- //#define FIND_SLOT_DEBUG 1
473
- #if FIND_SLOT_DEBUG
474
- LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
558
+ if (debug > 0) {
559
+ LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
475
560
 
476
- // for debugging
477
- {
478
- std::string ss;
479
- if (n_swa > 0) {
561
+ if ((debug == 2 && n_swa > 0) || debug > 2) {
562
+ std::string ss;
480
563
  for (uint32_t i = 0; i < cells.size(); ++i) {
481
564
  if (cells.is_empty(i)) {
482
565
  ss += '.';
483
566
  } else {
484
- ss += std::to_string(cells.seq_get(i));
567
+ assert(cells.seq_count(i) >= 1);
568
+
569
+ if (cells.seq_count(i) == 1) {
570
+ ss += std::to_string(cells.seq_get(i));
571
+ } else {
572
+ ss += 'M';
573
+ }
485
574
  }
486
575
  if (i%256 == 255) {
576
+ ss += " *";
487
577
  ss += '\n';
488
578
  }
489
579
  }
580
+ LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
490
581
  }
491
- LLAMA_LOG_WARN("\n%s\n", ss.c_str());
492
- }
493
582
 
494
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
495
- if (cells.seq_pos_min(s) < 0) {
496
- continue;
583
+ if ((debug == 2 && n_swa > 0) || debug > 2) {
584
+ std::string ss;
585
+ for (uint32_t i = 0; i < cells.size(); ++i) {
586
+ std::string cur;
587
+ if (cells.is_empty(i)) {
588
+ cur = '.';
589
+ } else {
590
+ cur = std::to_string(cells.pos_get(i));
591
+ }
592
+ const int n = cur.size();
593
+ for (int j = 0; j < 5 - n; ++j) {
594
+ cur += ' ';
595
+ }
596
+ ss += cur;
597
+ if (i%256 == 255) {
598
+ ss += " *";
599
+ }
600
+ if (i%64 == 63) {
601
+ ss += '\n';
602
+ }
603
+ }
604
+ LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
497
605
  }
498
606
 
499
- LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
607
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
608
+ if (cells.seq_pos_min(s) < 0) {
609
+ continue;
610
+ }
611
+
612
+ LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
613
+ }
500
614
  }
501
- #endif
502
615
 
503
616
  uint32_t n_tested = 0;
504
617
 
@@ -509,21 +622,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
509
622
  continue;
510
623
  }
511
624
 
512
- // keep track of what the minimum sequence positions would be if we accept the ubatch
513
- llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
514
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
515
- seq_pos_min[s] = cells.seq_pos_min(s);
516
- }
517
-
518
625
  bool found = true;
519
626
  for (uint32_t i = 0; i < n_tokens; i++) {
520
- const llama_pos pos = ubatch.pos[i];
521
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
627
+ //const llama_pos pos = ubatch.pos[i];
628
+ //const llama_seq_id seq_id = ubatch.seq_id[i][0];
522
629
 
523
630
  // can we use this cell? either:
524
631
  // - the cell is empty
525
632
  // - the cell is occupied only by one sequence:
526
- // - mask causally, if the sequence is the same as the one we are inserting
633
+ // - (disabled) mask causally, if the sequence is the same as the one we are inserting
527
634
  // - mask SWA, using current max pos for that sequence in the cache
528
635
  // always insert in the cell with minimum pos
529
636
  bool can_use = cells.is_empty(head_cur + i);
@@ -531,21 +638,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
531
638
  if (!can_use && cells.seq_count(head_cur + i) == 1) {
532
639
  const llama_pos pos_cell = cells.pos_get(head_cur + i);
533
640
 
534
- // causal mask
535
- if (cells.seq_has(head_cur + i, seq_id)) {
536
- can_use = pos_cell >= pos;
537
- }
641
+ // (disabled) causal mask
642
+ // note: it's better to purge any "future" tokens beforehand
643
+ //if (cells.seq_has(head_cur + i, seq_id)) {
644
+ // can_use = pos_cell >= pos;
645
+ //}
538
646
 
539
647
  if (!can_use) {
540
648
  const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
541
649
 
542
650
  // SWA mask
543
- // note: we insert only in the cell with minimum pos in order to preserve the invariant that
544
- // all positions between [pos_min, pos_max] for each sequence will be present in the cache
545
- // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
546
- if (pos_cell == seq_pos_min[seq_id_cell] &&
547
- is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
548
- seq_pos_min[seq_id_cell]++;
651
+ if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
549
652
  can_use = true;
550
653
  }
551
654
  }
@@ -573,15 +676,45 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
573
676
  }
574
677
 
575
678
  void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
679
+ // keep track of the max sequence position that we would overwrite with this ubatch
680
+ // for non-SWA cache, this would be always empty
681
+ llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
682
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
683
+ seq_pos_max_rm[s] = -1;
684
+ }
685
+
576
686
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
577
687
  if (!cells.is_empty(head_cur + i)) {
688
+ assert(cells.seq_count(head_cur + i) == 1);
689
+
690
+ const llama_seq_id seq_id = cells.seq_get(head_cur + i);
691
+ const llama_pos pos = cells.pos_get(head_cur + i);
692
+
693
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
694
+
578
695
  cells.rm(head_cur + i);
579
696
  }
580
697
 
581
698
  cells.pos_set(head_cur + i, ubatch.pos[i]);
582
699
 
583
- for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
584
- cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
700
+ for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
701
+ cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
702
+ }
703
+ }
704
+
705
+ // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
706
+ // will be present in the cache. so we have to purge any position which is less than those we would overwrite
707
+ // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
708
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
709
+ if (seq_pos_max_rm[s] == -1) {
710
+ continue;
711
+ }
712
+
713
+ if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
714
+ LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
715
+ __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
716
+
717
+ seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
585
718
  }
586
719
  }
587
720
 
@@ -597,6 +730,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
597
730
  return cells.size();
598
731
  }
599
732
 
733
+ bool llama_kv_cache_unified::get_has_shift() const {
734
+ return cells.get_has_shift();
735
+ }
736
+
600
737
  uint32_t llama_kv_cache_unified::get_n_kv() const {
601
738
  return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
602
739
  }
@@ -677,14 +814,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
677
814
  }
678
815
 
679
816
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
680
- const int64_t n_tokens = ubatch->n_tokens;
681
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
682
- const int64_t n_seqs = ubatch->n_seqs;
817
+ const uint32_t n_tokens = ubatch->n_tokens;
683
818
 
684
819
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
685
820
  float * data = (float *) dst->data;
686
821
 
687
- const auto n_kv = dst->ne[0];
822
+ const int64_t n_kv = dst->ne[0];
688
823
 
689
824
  // Use only the previous KV cells of the correct sequence for each token of the ubatch.
690
825
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -698,49 +833,47 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
698
833
  // xxxxx-----
699
834
  // xxxxx-----
700
835
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
701
- for (int h = 0; h < 1; ++h) {
702
- for (int s = 0; s < n_seqs; ++s) {
703
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
704
-
705
- for (int j = 0; j < n_seq_tokens; ++j) {
706
- const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
836
+ for (uint32_t h = 0; h < 1; ++h) {
837
+ for (uint32_t i = 0; i < n_tokens; ++i) {
838
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
707
839
 
708
- for (uint32_t i = 0; i < n_kv; ++i) {
709
- float f = 0.0f;
840
+ const llama_pos p1 = ubatch->pos[i];
710
841
 
711
- bool masked = false;
842
+ for (uint32_t j = 0; j < n_kv; ++j) {
843
+ float f = 0.0f;
712
844
 
713
- if (cells.is_empty(i)) {
714
- masked = true;
715
- } else {
716
- const llama_pos p0 = cells.pos_get(i);
845
+ bool masked = false;
717
846
 
718
- // mask the token if not the same sequence
719
- masked = masked || (!cells.seq_has(i, seq_id));
847
+ if (cells.is_empty(j)) {
848
+ masked = true;
849
+ } else {
850
+ const llama_pos p0 = cells.pos_get(j);
720
851
 
721
- // mask future tokens
722
- masked = masked || (causal_attn && p0 > p1);
852
+ // mask the token if not the same sequence
853
+ masked = masked || (!cells.seq_has(j, seq_id));
723
854
 
724
- // apply SWA if any
725
- masked = masked || (is_masked_swa(p0, p1));
855
+ // mask future tokens
856
+ masked = masked || (causal_attn && p0 > p1);
726
857
 
727
- if (!masked && hparams.use_alibi) {
728
- f = -std::abs(p0 - p1);
729
- }
730
- }
858
+ // apply SWA if any
859
+ masked = masked || (is_masked_swa(p0, p1));
731
860
 
732
- if (masked) {
733
- f = -INFINITY;
861
+ if (!masked && hparams.use_alibi) {
862
+ f = -std::abs(p0 - p1);
734
863
  }
864
+ }
735
865
 
736
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
866
+ if (masked) {
867
+ f = -INFINITY;
737
868
  }
869
+
870
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
738
871
  }
739
872
  }
740
873
 
741
874
  // mask padded tokens
742
875
  if (data) {
743
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
876
+ for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
744
877
  for (uint32_t j = 0; j < n_kv; ++j) {
745
878
  data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
746
879
  }
@@ -770,12 +903,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
770
903
  const int32_t n_kv = dst->ne[0];
771
904
 
772
905
  for (int h = 0; h < 1; ++h) {
773
- for (int j = 0; j < n_tokens; ++j) {
774
- for (int i = 0; i < n_kv; ++i) {
906
+ for (int i = 0; i < n_tokens; ++i) {
907
+ for (int j = 0; j < n_kv; ++j) {
775
908
  // the position when the cells is empty is irrelevant - it will be masked out later in the attention
776
- const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
909
+ const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
777
910
 
778
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
911
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
779
912
  }
780
913
  }
781
914
  }
@@ -890,11 +1023,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
890
1023
  const auto & n_embd_head_k = hparams.n_embd_head_k;
891
1024
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
892
1025
 
893
- //GGML_ASSERT(kv_self->size == n_ctx);
894
-
895
1026
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
896
1027
 
897
- inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
1028
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
898
1029
  ggml_set_input(inp->k_shift);
899
1030
 
900
1031
  for (const auto & layer : layers) {
@@ -926,12 +1057,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
926
1057
  }
927
1058
 
928
1059
  llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
929
- const llama_cparams & cparams,
930
- ggml_context * ctx,
931
- ggml_cgraph * gf) const {
1060
+ const llama_cparams & cparams,
1061
+ ggml_context * ctx,
1062
+ ggml_cgraph * gf,
1063
+ const defrag_info & dinfo) const {
932
1064
  auto res = std::make_unique<llm_graph_result>();
933
1065
 
934
- const auto & ids = defrag_info.ids;
1066
+ const auto & ids = dinfo.ids;
935
1067
 
936
1068
  #if 0
937
1069
  // CPU defrag
@@ -1072,7 +1204,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1072
1204
  return res;
1073
1205
  }
1074
1206
 
1075
- bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1207
+ llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
1076
1208
  const uint32_t n_layer = layers.size();
1077
1209
 
1078
1210
  const uint32_t n_kv = cells.used_max_p1();
@@ -1093,14 +1225,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1093
1225
  const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
1094
1226
 
1095
1227
  // determine which KV cells to move where
1096
- //
1097
- // cell i moves to ids[i]
1098
- //
1099
- // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
1100
- //
1101
- auto & ids = defrag_info.ids;
1228
+ defrag_info res;
1229
+ auto & ids = res.ids;
1102
1230
 
1103
- ids.clear();
1104
1231
  ids.resize(n_kv, n_kv);
1105
1232
 
1106
1233
  for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1164,11 +1291,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1164
1291
  // this cell goes to (i0 + nf)
1165
1292
  ids[i1] = i0 + nf;
1166
1293
 
1167
- // move the cell meta data
1168
- cells.mv(i1, i0 + nf);
1169
-
1170
- head = n_used;
1171
-
1172
1294
  if (!cont) {
1173
1295
  n_moves++;
1174
1296
  cont = true;
@@ -1191,14 +1313,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1191
1313
  }
1192
1314
 
1193
1315
  if (n_moves == 0) {
1194
- return false;
1316
+ return {};
1195
1317
  }
1196
1318
 
1197
1319
  LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
1198
1320
 
1199
1321
  LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
1200
1322
 
1201
- return true;
1323
+ return res;
1202
1324
  }
1203
1325
 
1204
1326
  bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@@ -1276,7 +1398,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
1276
1398
 
1277
1399
  if (!res) {
1278
1400
  if (seq_id == -1) {
1279
- clear();
1401
+ clear(true);
1280
1402
  } else {
1281
1403
  seq_rm(seq_id, -1, -1);
1282
1404
  }
@@ -1324,7 +1446,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1324
1446
  for (const auto & layer : layers) {
1325
1447
  const uint32_t il = layer.il;
1326
1448
 
1327
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1449
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1328
1450
 
1329
1451
  // Write key type
1330
1452
  const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1346,7 +1468,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1346
1468
  for (const auto & layer : layers) {
1347
1469
  const uint32_t il = layer.il;
1348
1470
 
1349
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1471
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1350
1472
 
1351
1473
  // Write value type
1352
1474
  const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1370,7 +1492,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1370
1492
  for (const auto & layer : layers) {
1371
1493
  const uint32_t il = layer.il;
1372
1494
 
1373
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1495
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1374
1496
 
1375
1497
  // Write value type
1376
1498
  const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1403,10 +1525,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1403
1525
 
1404
1526
  seq_rm(dest_seq_id, -1, -1);
1405
1527
 
1406
- llama_sbatch sbatch;
1407
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1528
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
1408
1529
 
1409
- batch.n_tokens = cell_count;
1530
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
1410
1531
 
1411
1532
  for (uint32_t i = 0; i < cell_count; ++i) {
1412
1533
  llama_pos pos;
@@ -1426,18 +1547,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1426
1547
  io.read_to(&seq_id, sizeof(seq_id));
1427
1548
  }
1428
1549
 
1429
- batch.pos[i] = pos;
1430
- batch.n_seq_id[i] = n_seq_id;
1431
- batch.seq_id[i] = &dest_seq_id;
1550
+ ubatch.pos[i] = pos;
1551
+ ubatch.n_seq_id[i] = n_seq_id;
1552
+ ubatch.seq_id[i] = &dest_seq_id;
1432
1553
  }
1433
1554
 
1434
- const auto head_cur = find_slot(batch);
1555
+ const auto head_cur = find_slot(ubatch);
1435
1556
  if (head_cur < 0) {
1436
1557
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1437
1558
  return false;
1438
1559
  }
1439
1560
 
1440
- apply_ubatch(head_cur, batch);
1561
+ apply_ubatch(head_cur, ubatch);
1441
1562
 
1442
1563
  // keep the head at the old position because we will read the KV data into it in state_read_data()
1443
1564
  head = head_cur;
@@ -1445,8 +1566,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1445
1566
  // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
1446
1567
  // Assume that this is one contiguous block of cells
1447
1568
  GGML_ASSERT(head_cur + cell_count <= cells.size());
1448
- GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
1449
- GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
1569
+ GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
1570
+ GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
1450
1571
  GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
1451
1572
  GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
1452
1573
  } else {
@@ -1457,7 +1578,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1457
1578
  return false;
1458
1579
  }
1459
1580
 
1460
- clear();
1581
+ clear(true);
1461
1582
 
1462
1583
  for (uint32_t i = 0; i < cell_count; ++i) {
1463
1584
  llama_pos pos;
@@ -1513,7 +1634,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1513
1634
  for (const auto & layer : layers) {
1514
1635
  const uint32_t il = layer.il;
1515
1636
 
1516
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1637
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1517
1638
 
1518
1639
  // Read type of key
1519
1640
  int32_t k_type_i_ref;
@@ -1543,7 +1664,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1543
1664
  for (const auto & layer : layers) {
1544
1665
  const uint32_t il = layer.il;
1545
1666
 
1546
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1667
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1547
1668
 
1548
1669
  // Read type of value
1549
1670
  int32_t v_type_i_ref;
@@ -1573,7 +1694,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1573
1694
  for (const auto & layer : layers) {
1574
1695
  const uint32_t il = layer.il;
1575
1696
 
1576
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1697
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1577
1698
 
1578
1699
  // Read type of value
1579
1700
  int32_t v_type_i_ref;
@@ -1615,34 +1736,36 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1615
1736
  }
1616
1737
 
1617
1738
  //
1618
- // llama_kv_cache_unified_state
1739
+ // llama_kv_cache_unified_context
1619
1740
  //
1620
1741
 
1621
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
1742
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
1622
1743
 
1623
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1624
- llama_memory_status status,
1625
- llama_kv_cache_unified * kv) : status(status), kv(kv) {
1626
- n_kv = kv->get_size();
1627
- head = 0;
1628
- }
1744
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1745
+ llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1746
+ n_kv = kv->get_size();
1747
+ head = 0;
1748
+ }
1629
1749
 
1630
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1631
- llama_memory_status status,
1632
- llama_kv_cache_unified * kv,
1633
- llama_sbatch sbatch,
1634
- std::vector<uint32_t> heads,
1635
- std::vector<llama_ubatch> ubatches)
1636
- : status(status),
1637
- kv(kv),
1638
- sbatch(std::move(sbatch)),
1639
- heads(std::move(heads)),
1640
- ubatches(std::move(ubatches)) {
1750
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1751
+ llama_kv_cache_unified * kv,
1752
+ llama_context * lctx,
1753
+ bool do_shift,
1754
+ defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
1755
+ if (!do_shift && this->dinfo.empty()) {
1756
+ status = LLAMA_MEMORY_STATUS_NO_UPDATE;
1641
1757
  }
1758
+ }
1642
1759
 
1643
- llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
1760
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1761
+ llama_kv_cache_unified * kv,
1762
+ llama_kv_cache_unified::ubatch_heads heads,
1763
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1764
+ }
1765
+
1766
+ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
1644
1767
 
1645
- bool llama_kv_cache_unified_state::next() {
1768
+ bool llama_kv_cache_unified_context::next() {
1646
1769
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1647
1770
 
1648
1771
  if (++i_next >= ubatches.size()) {
@@ -1652,9 +1775,16 @@ bool llama_kv_cache_unified_state::next() {
1652
1775
  return true;
1653
1776
  }
1654
1777
 
1655
- bool llama_kv_cache_unified_state::apply() {
1778
+ bool llama_kv_cache_unified_context::apply() {
1656
1779
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1657
1780
 
1781
+ // no ubatches -> this is a KV cache update
1782
+ if (ubatches.empty()) {
1783
+ kv->update(lctx, do_shift, dinfo);
1784
+
1785
+ return true;
1786
+ }
1787
+
1658
1788
  kv->apply_ubatch(heads[i_next], ubatches[i_next]);
1659
1789
 
1660
1790
  n_kv = kv->get_n_kv();
@@ -1663,51 +1793,45 @@ bool llama_kv_cache_unified_state::apply() {
1663
1793
  return true;
1664
1794
  }
1665
1795
 
1666
- std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
1667
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1668
-
1669
- return sbatch.out_ids;
1670
- }
1671
-
1672
- llama_memory_status llama_kv_cache_unified_state::get_status() const {
1796
+ llama_memory_status llama_kv_cache_unified_context::get_status() const {
1673
1797
  return status;
1674
1798
  }
1675
1799
 
1676
- const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
1800
+ const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
1677
1801
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1678
1802
 
1679
1803
  return ubatches[i_next];
1680
1804
  }
1681
1805
 
1682
- uint32_t llama_kv_cache_unified_state::get_n_kv() const {
1806
+ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
1683
1807
  return n_kv;
1684
1808
  }
1685
1809
 
1686
- ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
1810
+ ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
1687
1811
  return kv->get_k(ctx, il, n_kv);
1688
1812
  }
1689
1813
 
1690
- ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
1814
+ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
1691
1815
  return kv->get_v(ctx, il, n_kv);
1692
1816
  }
1693
1817
 
1694
- ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1818
+ ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1695
1819
  return kv->cpy_k(ctx, k_cur, il, head);
1696
1820
  }
1697
1821
 
1698
- ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1822
+ ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1699
1823
  return kv->cpy_v(ctx, v_cur, il, head);
1700
1824
  }
1701
1825
 
1702
- void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
1826
+ void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
1703
1827
  kv->set_input_k_shift(dst);
1704
1828
  }
1705
1829
 
1706
- void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1830
+ void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1707
1831
  kv->set_input_kq_mask(dst, ubatch, causal_attn);
1708
1832
  }
1709
1833
 
1710
- void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1834
+ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1711
1835
  kv->set_input_pos_bucket(dst, ubatch);
1712
1836
  }
1713
1837