@novastera-oss/llamarn 0.2.5 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225) hide show
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +140 -38
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +48 -67
  13. package/cpp/LlamaCppModel.h +8 -3
  14. package/cpp/PureCppImpl.cpp +1 -1
  15. package/cpp/PureCppImpl.h +2 -2
  16. package/cpp/build-info.cpp +2 -2
  17. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  18. package/cpp/llama.cpp/Makefile +2 -2
  19. package/cpp/llama.cpp/README.md +33 -13
  20. package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
  21. package/cpp/llama.cpp/common/arg.cpp +38 -12
  22. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  23. package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
  24. package/cpp/llama.cpp/common/chat-parser.h +4 -1
  25. package/cpp/llama.cpp/common/chat.cpp +16 -13
  26. package/cpp/llama.cpp/common/chat.h +1 -1
  27. package/cpp/llama.cpp/common/common.cpp +52 -40
  28. package/cpp/llama.cpp/common/common.h +5 -2
  29. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  30. package/cpp/llama.cpp/common/json-partial.h +2 -1
  31. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  32. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  33. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  34. package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  37. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  38. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
  39. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  41. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  79. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  82. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  83. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  112. package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
  113. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  114. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  115. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  116. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  117. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  118. package/cpp/llama.cpp/include/llama.h +140 -38
  119. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  120. package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
  121. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  122. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  123. package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
  124. package/cpp/llama.cpp/src/llama-batch.h +47 -17
  125. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  126. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  127. package/cpp/llama.cpp/src/llama-context.cpp +488 -313
  128. package/cpp/llama.cpp/src/llama-context.h +38 -17
  129. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  130. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  131. package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
  132. package/cpp/llama.cpp/src/llama-graph.h +109 -52
  133. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  134. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
  139. package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  141. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  142. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
  144. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  145. package/cpp/llama.cpp/src/llama-memory.h +89 -4
  146. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  147. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  148. package/cpp/llama.cpp/src/llama-model.cpp +735 -143
  149. package/cpp/llama.cpp/src/llama-model.h +4 -0
  150. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  151. package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
  152. package/cpp/llama.cpp/src/llama.cpp +11 -7
  153. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  154. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  155. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  156. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  157. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  158. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  159. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  160. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  161. package/cpp/rn-completion.cpp +65 -10
  162. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  163. package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
  164. package/ios/include/chat.h +1 -1
  165. package/ios/include/common/minja/chat-template.hpp +1 -1
  166. package/ios/include/common/minja/minja.hpp +1 -1
  167. package/ios/include/common.h +5 -2
  168. package/ios/include/json-schema-to-grammar.h +4 -4
  169. package/ios/include/llama.h +140 -38
  170. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  171. package/ios/libs/llama.xcframework/Info.plist +20 -20
  172. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
  174. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  175. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
  176. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  177. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  178. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  179. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
  180. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  181. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  182. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  184. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  185. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
  186. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  187. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
  188. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  189. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
  190. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  191. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  192. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
  193. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  194. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  195. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  196. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
  197. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  198. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
  199. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
  202. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
  203. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  204. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  205. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  206. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  207. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
  208. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  209. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
  210. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  211. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  212. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
  213. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
  214. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  215. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  216. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  217. package/package.json +1 -2
  218. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  219. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  221. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
  222. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
  223. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  224. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  225. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -17,10 +17,12 @@ struct ggml_tensor;
17
17
  struct llama_ubatch;
18
18
  struct llama_cparams;
19
19
 
20
- class llama_memory_i;
21
- class llama_kv_cache_unified;
22
- class llama_kv_cache_unified_iswa;
23
- class llama_kv_cache_recurrent;
20
+ struct llama_memory_state_i;
21
+
22
+ class llama_kv_cache_unified_state;
23
+ class llama_kv_cache_unified_iswa_state;
24
+ class llama_memory_recurrent_state;
25
+ class llama_memory_hybrid_state;
24
26
 
25
27
  // certain models (typically multi-modal) can produce different types of graphs
26
28
  enum llm_graph_type {
@@ -35,6 +37,7 @@ enum llm_ffn_op_type {
35
37
  LLM_FFN_RELU,
36
38
  LLM_FFN_RELU_SQR,
37
39
  LLM_FFN_SWIGLU,
40
+ LLM_FFN_GEGLU,
38
41
  };
39
42
 
40
43
  enum llm_ffn_gate_type {
@@ -133,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
133
136
  public:
134
137
  llm_graph_input_pos_bucket_kv(
135
138
  const llama_hparams & hparams,
136
- const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
139
+ const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
137
140
  virtual ~llm_graph_input_pos_bucket_kv() = default;
138
141
 
139
142
  void set_input(const llama_ubatch * ubatch) override;
@@ -141,7 +144,7 @@ public:
141
144
  ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
142
145
 
143
146
  const llama_hparams & hparams;
144
- const llama_kv_cache_unified * kv_self;
147
+ const llama_kv_cache_unified_state * kv_state;
145
148
  };
146
149
 
147
150
  class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -186,28 +189,16 @@ public:
186
189
  const llama_cparams & cparams;
187
190
  };
188
191
 
189
- class llm_graph_input_s_copy : public llm_graph_input_i {
192
+ class llm_graph_input_rs : public llm_graph_input_i {
190
193
  public:
191
- llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
192
- virtual ~llm_graph_input_s_copy() = default;
194
+ llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
195
+ virtual ~llm_graph_input_rs() = default;
193
196
 
194
197
  void set_input(const llama_ubatch * ubatch) override;
195
198
 
196
199
  ggml_tensor * s_copy; // I32 [kv_size]
197
200
 
198
- const llama_kv_cache_recurrent * kv_self;
199
- };
200
-
201
- class llm_graph_input_s_mask : public llm_graph_input_i {
202
- public:
203
- llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
204
- virtual ~llm_graph_input_s_mask() = default;
205
-
206
- void set_input(const llama_ubatch * ubatch) override;
207
-
208
- ggml_tensor * s_mask; // F32 [1, n_kv]
209
-
210
- const llama_kv_cache_recurrent * kv_self;
201
+ const llama_memory_recurrent_state * mem_state;
211
202
  };
212
203
 
213
204
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -247,10 +238,10 @@ public:
247
238
  llm_graph_input_attn_kv_unified(
248
239
  const llama_hparams & hparams,
249
240
  const llama_cparams & cparams,
250
- const llama_kv_cache_unified * kv_self) :
241
+ const llama_kv_cache_unified_state * kv_state) :
251
242
  hparams(hparams),
252
243
  cparams(cparams),
253
- kv_self(kv_self) {
244
+ kv_state(kv_state) {
254
245
  }
255
246
  ~llm_graph_input_attn_kv_unified() = default;
256
247
 
@@ -264,7 +255,7 @@ public:
264
255
  const llama_hparams & hparams;
265
256
  const llama_cparams & cparams;
266
257
 
267
- const llama_kv_cache_unified * kv_self;
258
+ const llama_kv_cache_unified_state * kv_state;
268
259
  };
269
260
 
270
261
  class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -272,10 +263,10 @@ public:
272
263
  llm_graph_input_attn_kv_unified_iswa(
273
264
  const llama_hparams & hparams,
274
265
  const llama_cparams & cparams,
275
- const llama_kv_cache_unified_iswa * kv_self) :
266
+ const llama_kv_cache_unified_iswa_state * kv_state) :
276
267
  hparams(hparams),
277
268
  cparams(cparams),
278
- kv_self(kv_self) {
269
+ kv_state(kv_state) {
279
270
  }
280
271
  ~llm_graph_input_attn_kv_unified_iswa() = default;
281
272
 
@@ -292,7 +283,7 @@ public:
292
283
  const llama_hparams & hparams;
293
284
  const llama_cparams & cparams;
294
285
 
295
- const llama_kv_cache_unified_iswa * kv_self;
286
+ const llama_kv_cache_unified_iswa_state * kv_state;
296
287
  };
297
288
 
298
289
  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -310,6 +301,33 @@ public:
310
301
  const llama_cross * cross = nullptr;
311
302
  };
312
303
 
304
+ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
305
+ public:
306
+ llm_graph_input_mem_hybrid(
307
+ const llama_hparams & hparams,
308
+ const llama_cparams & cparams,
309
+ const llama_memory_hybrid_state * mem_state) :
310
+ hparams(hparams),
311
+ cparams(cparams),
312
+ mem_state(mem_state) {
313
+ }
314
+ virtual ~llm_graph_input_mem_hybrid() = default;
315
+
316
+ void set_input(const llama_ubatch * ubatch) override;
317
+
318
+ ggml_tensor * s_copy; // I32 [kv_size]
319
+
320
+ ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
321
+
322
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
323
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
324
+
325
+ const llama_hparams & hparams;
326
+ const llama_cparams & cparams;
327
+
328
+ const llama_memory_hybrid_state * mem_state;
329
+ };
330
+
313
331
  //
314
332
  // llm_graph_result
315
333
  //
@@ -383,12 +401,12 @@ struct llm_graph_params {
383
401
  ggml_backend_sched_t sched;
384
402
  ggml_backend_t backend_cpu;
385
403
 
386
- const llama_adapter_cvec * cvec;
387
- const llama_adapter_loras * loras;
388
- const llama_memory_i * memory;
389
- const llama_cross * cross;
404
+ const llama_adapter_cvec * cvec;
405
+ const llama_adapter_loras * loras;
406
+ const llama_memory_state_i * mstate;
407
+ const llama_cross * cross;
390
408
 
391
- int32_t n_outputs;
409
+ uint32_t n_outputs;
392
410
 
393
411
  const llm_graph_cb & cb;
394
412
  };
@@ -422,8 +440,8 @@ struct llm_graph_context {
422
440
  const float norm_eps;
423
441
  const float norm_rms_eps;
424
442
 
425
- const int32_t n_tokens;
426
- const int32_t n_outputs;
443
+ const int64_t n_tokens;
444
+ const int64_t n_outputs;
427
445
  const int32_t n_ctx_orig; // yarn
428
446
 
429
447
  const enum llama_pooling_type pooling_type;
@@ -435,10 +453,10 @@ struct llm_graph_context {
435
453
 
436
454
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
437
455
 
438
- const llama_adapter_cvec * cvec;
439
- const llama_adapter_loras * loras;
440
- const llama_memory_i * memory;
441
- const llama_cross * cross;
456
+ const llama_adapter_cvec * cvec;
457
+ const llama_adapter_loras * loras;
458
+ const llama_memory_state_i * mstate;
459
+ const llama_cross * cross;
442
460
 
443
461
  const llm_graph_cb & cb_func;
444
462
 
@@ -518,14 +536,14 @@ struct llm_graph_context {
518
536
  ggml_tensor * build_inp_out_ids() const;
519
537
  ggml_tensor * build_inp_mean() const;
520
538
  ggml_tensor * build_inp_cls() const;
521
- ggml_tensor * build_inp_s_copy() const;
522
- ggml_tensor * build_inp_s_mask() const;
523
539
 
524
540
  ggml_tensor * build_inp_cross_embd() const;
525
541
  ggml_tensor * build_inp_pos_bucket_enc() const;
526
542
  ggml_tensor * build_inp_pos_bucket_dec() const;
527
543
  ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
528
544
 
545
+ llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
546
+
529
547
  //
530
548
  // attention
531
549
  //
@@ -600,23 +618,62 @@ struct llm_graph_context {
600
618
  float kq_scale,
601
619
  int il) const;
602
620
 
621
+ ggml_tensor * build_attn(
622
+ llm_graph_input_mem_hybrid * inp,
623
+ ggml_cgraph * gf,
624
+ ggml_tensor * wo,
625
+ ggml_tensor * wo_b,
626
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
627
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
628
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
629
+ ggml_tensor * kq_b,
630
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
631
+ float kq_scale,
632
+ int il) const;
603
633
  //
604
634
  // recurrent
605
635
  //
606
636
 
607
- ggml_tensor * build_copy_mask_state(
608
- ggml_cgraph * gf,
609
- ggml_tensor * s,
610
- ggml_tensor * state_copy,
611
- ggml_tensor * state_mask,
612
- int32_t n_state,
613
- int32_t n_seqs) const;
637
+ // TODO: avoid notion of "kv"
638
+ // TODO: move this implementation to llama_memory_recurrent.
639
+ // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
640
+ // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
641
+ // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
642
+ // `llama_memory_recurrent`
643
+ ggml_tensor * build_rs(
644
+ ggml_cgraph * gf,
645
+ ggml_tensor * s,
646
+ ggml_tensor * state_copy,
647
+ int32_t state_size,
648
+ int32_t n_seqs,
649
+ uint32_t n_kv,
650
+ uint32_t kv_head,
651
+ uint32_t kv_size,
652
+ int32_t rs_zero,
653
+ bool avoid_copies = false) const;
654
+
655
+ llm_graph_input_rs * build_rs_inp() const;
656
+
657
+ ggml_tensor * build_rs(
658
+ llm_graph_input_rs * inp,
659
+ ggml_cgraph * gf,
660
+ ggml_tensor * s,
661
+ int32_t state_size,
662
+ int32_t n_seqs,
663
+ bool avoid_copies = false) const;
664
+
665
+ ggml_tensor * build_rs(
666
+ llm_graph_input_mem_hybrid * inp,
667
+ ggml_cgraph * gf,
668
+ ggml_tensor * s,
669
+ int32_t state_size,
670
+ int32_t n_seqs,
671
+ bool avoid_copies = false) const;
614
672
 
615
673
  ggml_tensor * build_rwkv_token_shift_load(
616
- ggml_cgraph * gf,
617
- ggml_tensor * state_copy,
618
- ggml_tensor * state_mask,
619
- const llama_ubatch & ubatch,
674
+ llm_graph_input_rs * inp,
675
+ ggml_cgraph * gf,
676
+ const llama_ubatch & ubatch,
620
677
  int il) const;
621
678
 
622
679
  ggml_tensor * build_rwkv_token_shift_store(
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
65
65
  return n_embd_head_v * n_head_kv;
66
66
  }
67
67
 
68
- uint32_t llama_hparams::n_embd_k_s() const {
68
+ uint32_t llama_hparams::n_embd_r() const {
69
69
  if (wkv_head_size != 0) {
70
70
  // for RWKV models
71
71
  return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
76
76
  return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
77
77
  }
78
78
 
79
- uint32_t llama_hparams::n_embd_v_s() const {
79
+ uint32_t llama_hparams::n_embd_s() const {
80
80
  if (wkv_head_size != 0) {
81
81
  // corresponds to RWKV's wkv_states size
82
82
  return n_embd * wkv_head_size;
@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
86
86
  return ssm_d_state * ssm_d_inner;
87
87
  }
88
88
 
89
+ bool llama_hparams::is_recurrent(uint32_t il) const {
90
+ return recurrent_layer_arr[il];
91
+ }
92
+
89
93
  bool llama_hparams::is_swa(uint32_t il) const {
90
94
  if (il < n_layer) {
91
95
  return swa_layers[il];
@@ -115,6 +115,9 @@ struct llama_hparams {
115
115
  uint32_t ssm_d_state = 0;
116
116
  uint32_t ssm_dt_rank = 0;
117
117
 
118
+ // for hybrid state space models
119
+ std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
120
+
118
121
  bool ssm_dt_b_c_rms = false;
119
122
 
120
123
  float f_clamp_kqv = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
181
184
 
182
185
  // dimension of the rolling state embeddings
183
186
  // corresponds to Mamba's conv_states size or RWKV's token_shift states size
184
- uint32_t n_embd_k_s() const;
187
+ uint32_t n_embd_r() const;
185
188
 
186
189
  // dimension of the recurrent state embeddings
187
- uint32_t n_embd_v_s() const;
190
+ uint32_t n_embd_s() const;
191
+
192
+ // whether or not the given layer is recurrent (for hybrid models)
193
+ bool is_recurrent(uint32_t il) const;
188
194
 
189
195
  bool is_swa(uint32_t il) const;
190
196
  };
@@ -0,0 +1,281 @@
1
+ #include "llama-kv-cache-unified-iswa.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-batch.h"
5
+ #include "llama-model.h"
6
+
7
+ #include <algorithm>
8
+ #include <cassert>
9
+
10
+ //
11
+ // llama_kv_cache_unified_iswa
12
+ //
13
+
14
+ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15
+ const llama_model & model,
16
+ ggml_type type_k,
17
+ ggml_type type_v,
18
+ bool v_trans,
19
+ bool offload,
20
+ bool swa_full,
21
+ uint32_t kv_size,
22
+ uint32_t n_seq_max,
23
+ uint32_t n_ubatch,
24
+ uint32_t n_pad) : hparams(model.hparams) {
25
+ llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
26
+ llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
27
+
28
+ const uint32_t size_base = kv_size;
29
+
30
+ uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
31
+
32
+ // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
33
+ if (swa_full) {
34
+ LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
35
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
36
+
37
+ size_swa = size_base;
38
+ }
39
+
40
+ LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
41
+
42
+ kv_base = std::make_unique<llama_kv_cache_unified>(
43
+ model, std::move(filter_base), type_k, type_v,
44
+ v_trans, offload, size_base, n_seq_max, n_pad,
45
+ 0, LLAMA_SWA_TYPE_NONE);
46
+
47
+ LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
48
+
49
+ kv_swa = std::make_unique<llama_kv_cache_unified>(
50
+ model, std::move(filter_swa), type_k, type_v,
51
+ v_trans, offload, size_swa, n_seq_max, n_pad,
52
+ hparams.n_swa, hparams.swa_type);
53
+ }
54
+
55
+ void llama_kv_cache_unified_iswa::clear(bool data) {
56
+ kv_base->clear(data);
57
+ kv_swa ->clear(data);
58
+ }
59
+
60
+ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
61
+ bool res = true;
62
+
63
+ res = res & kv_base->seq_rm(seq_id, p0, p1);
64
+ res = res & kv_swa ->seq_rm(seq_id, p0, p1);
65
+
66
+ return res;
67
+ }
68
+
69
+ void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
70
+ kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
71
+ kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
72
+ }
73
+
74
+ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
75
+ kv_base->seq_keep(seq_id);
76
+ kv_swa ->seq_keep(seq_id);
77
+ }
78
+
79
+ void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
80
+ kv_base->seq_add(seq_id, p0, p1, shift);
81
+ kv_swa ->seq_add(seq_id, p0, p1, shift);
82
+ }
83
+
84
+ void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
85
+ kv_base->seq_div(seq_id, p0, p1, d);
86
+ kv_swa ->seq_div(seq_id, p0, p1, d);
87
+ }
88
+
89
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
90
+ // the base cache is a superset of the SWA cache, so we can just check the SWA cache
91
+ return kv_swa->seq_pos_min(seq_id);
92
+ }
93
+
94
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
95
+ return kv_swa->seq_pos_max(seq_id);
96
+ }
97
+
98
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
99
+ GGML_UNUSED(embd_all);
100
+
101
+ // first try simple split
102
+ do {
103
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
104
+
105
+ std::vector<llama_ubatch> ubatches;
106
+
107
+ while (sbatch.n_tokens > 0) {
108
+ auto ubatch = sbatch.split_simple(n_ubatch);
109
+
110
+ ubatches.push_back(ubatch);
111
+ }
112
+
113
+ auto heads_base = kv_base->prepare(ubatches);
114
+ if (heads_base.empty()) {
115
+ break;
116
+ }
117
+
118
+ auto heads_swa = kv_swa->prepare(ubatches);
119
+ if (heads_swa.empty()) {
120
+ break;
121
+ }
122
+
123
+ assert(heads_base.size() == heads_swa.size());
124
+
125
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
126
+ this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
127
+ } while (false);
128
+
129
+ // if it fails, try equal split
130
+ do {
131
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
132
+
133
+ std::vector<llama_ubatch> ubatches;
134
+
135
+ while (sbatch.n_tokens > 0) {
136
+ auto ubatch = sbatch.split_equal(n_ubatch);
137
+
138
+ ubatches.push_back(ubatch);
139
+ }
140
+
141
+ auto heads_base = kv_base->prepare(ubatches);
142
+ if (heads_base.empty()) {
143
+ break;
144
+ }
145
+
146
+ auto heads_swa = kv_swa->prepare(ubatches);
147
+ if (heads_swa.empty()) {
148
+ break;
149
+ }
150
+
151
+ assert(heads_base.size() == heads_swa.size());
152
+
153
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
154
+ this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
155
+ } while (false);
156
+
157
+ // TODO: if we fail again, we should attempt different splitting strategies
158
+ // but to do that properly, we first have to refactor the batches to be more flexible
159
+
160
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
161
+ }
162
+
163
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
164
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
165
+ }
166
+
167
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
168
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
169
+ }
170
+
171
+ bool llama_kv_cache_unified_iswa::get_can_shift() const {
172
+ return kv_base->get_size() == kv_swa->get_size();
173
+ }
174
+
175
+ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
176
+ kv_base->state_write(io, seq_id);
177
+ kv_swa ->state_write(io, seq_id);
178
+ }
179
+
180
+ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
181
+ kv_base->state_read(io, seq_id);
182
+ kv_swa ->state_read(io, seq_id);
183
+ }
184
+
185
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
186
+ return kv_base.get();
187
+ }
188
+
189
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
190
+ return kv_swa.get();
191
+ }
192
+
193
+ //
194
+ // llama_kv_cache_unified_iswa_state
195
+ //
196
+
197
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
198
+
199
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
200
+ llama_kv_cache_unified_iswa * kv) :
201
+ state_base(kv->get_base()->init_full()),
202
+ state_swa (kv->get_swa ()->init_full()),
203
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
204
+ }
205
+
206
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
207
+ llama_kv_cache_unified_iswa * kv,
208
+ llama_context * lctx,
209
+ bool optimize) :
210
+ state_base(kv->get_base()->init_update(lctx, optimize)),
211
+ state_swa (kv->get_swa ()->init_update(lctx, optimize)),
212
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
213
+ }
214
+
215
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
216
+ llama_kv_cache_unified_iswa * kv,
217
+ llama_sbatch sbatch,
218
+ std::vector<uint32_t> heads_base,
219
+ std::vector<uint32_t> heads_swa,
220
+ std::vector<llama_ubatch> ubatches) :
221
+ sbatch(std::move(sbatch)),
222
+ ubatches(std::move(ubatches)),
223
+ // note: here we copy the ubatches. not sure if this is ideal
224
+ state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
225
+ state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
226
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
227
+ }
228
+
229
+ llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
230
+
231
+ bool llama_kv_cache_unified_iswa_state::next() {
232
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
233
+
234
+ state_base->next();
235
+ state_swa ->next();
236
+
237
+ if (++i_next >= ubatches.size()) {
238
+ return false;
239
+ }
240
+
241
+ return true;
242
+ }
243
+
244
+ bool llama_kv_cache_unified_iswa_state::apply() {
245
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
246
+
247
+ bool res = true;
248
+
249
+ res = res & state_base->apply();
250
+ res = res & state_swa ->apply();
251
+
252
+ return res;
253
+ }
254
+
255
+ std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
256
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
257
+
258
+ return sbatch.out_ids;
259
+ }
260
+
261
+ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
262
+ return status;
263
+ }
264
+
265
+ const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
266
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
267
+
268
+ return ubatches[i_next];
269
+ }
270
+
271
+ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
272
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
273
+
274
+ return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
275
+ }
276
+
277
+ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
278
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
279
+
280
+ return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
281
+ }