@novastera-oss/llamarn 0.2.5 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225) hide show
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +140 -38
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +48 -67
  13. package/cpp/LlamaCppModel.h +8 -3
  14. package/cpp/PureCppImpl.cpp +1 -1
  15. package/cpp/PureCppImpl.h +2 -2
  16. package/cpp/build-info.cpp +2 -2
  17. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  18. package/cpp/llama.cpp/Makefile +2 -2
  19. package/cpp/llama.cpp/README.md +33 -13
  20. package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
  21. package/cpp/llama.cpp/common/arg.cpp +38 -12
  22. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  23. package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
  24. package/cpp/llama.cpp/common/chat-parser.h +4 -1
  25. package/cpp/llama.cpp/common/chat.cpp +16 -13
  26. package/cpp/llama.cpp/common/chat.h +1 -1
  27. package/cpp/llama.cpp/common/common.cpp +52 -40
  28. package/cpp/llama.cpp/common/common.h +5 -2
  29. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  30. package/cpp/llama.cpp/common/json-partial.h +2 -1
  31. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  32. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  33. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  34. package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  37. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  38. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
  39. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  41. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  79. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  82. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  83. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  112. package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
  113. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  114. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  115. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  116. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  117. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  118. package/cpp/llama.cpp/include/llama.h +140 -38
  119. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  120. package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
  121. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  122. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  123. package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
  124. package/cpp/llama.cpp/src/llama-batch.h +47 -17
  125. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  126. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  127. package/cpp/llama.cpp/src/llama-context.cpp +488 -313
  128. package/cpp/llama.cpp/src/llama-context.h +38 -17
  129. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  130. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  131. package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
  132. package/cpp/llama.cpp/src/llama-graph.h +109 -52
  133. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  134. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
  139. package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  141. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  142. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
  144. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  145. package/cpp/llama.cpp/src/llama-memory.h +89 -4
  146. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  147. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  148. package/cpp/llama.cpp/src/llama-model.cpp +735 -143
  149. package/cpp/llama.cpp/src/llama-model.h +4 -0
  150. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  151. package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
  152. package/cpp/llama.cpp/src/llama.cpp +11 -7
  153. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  154. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  155. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  156. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  157. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  158. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  159. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  160. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  161. package/cpp/rn-completion.cpp +65 -10
  162. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  163. package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
  164. package/ios/include/chat.h +1 -1
  165. package/ios/include/common/minja/chat-template.hpp +1 -1
  166. package/ios/include/common/minja/minja.hpp +1 -1
  167. package/ios/include/common.h +5 -2
  168. package/ios/include/json-schema-to-grammar.h +4 -4
  169. package/ios/include/llama.h +140 -38
  170. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  171. package/ios/libs/llama.xcframework/Info.plist +20 -20
  172. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
  174. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  175. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
  176. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  177. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  178. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  179. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
  180. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  181. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  182. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  184. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  185. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
  186. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  187. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
  188. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  189. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
  190. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  191. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  192. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
  193. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  194. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  195. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  196. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
  197. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  198. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
  199. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
  202. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
  203. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  204. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  205. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  206. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  207. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
  208. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  209. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
  210. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  211. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  212. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
  213. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
  214. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  215. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  216. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  217. package/package.json +1 -2
  218. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  219. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  221. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
  222. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
  223. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  224. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  225. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -3,7 +3,11 @@
3
3
  #include "llama-impl.h"
4
4
  #include "llama-batch.h"
5
5
  #include "llama-cparams.h"
6
- #include "llama-kv-cache.h"
6
+
7
+ #include "llama-kv-cache-unified.h"
8
+ #include "llama-kv-cache-unified-iswa.h"
9
+ #include "llama-memory-hybrid.h"
10
+ #include "llama-memory-recurrent.h"
7
11
 
8
12
  #include <cassert>
9
13
  #include <cmath>
@@ -83,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
83
87
 
84
88
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
85
89
  if (pos_bucket) {
86
- kv_self->set_input_pos_bucket(pos_bucket, ubatch);
90
+ kv_state->set_input_pos_bucket(pos_bucket, ubatch);
87
91
  }
88
92
  }
89
93
 
@@ -136,6 +140,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
136
140
 
137
141
  std::vector<uint64_t> sum(n_tokens, 0);
138
142
 
143
+ // TODO: fix indexing [UBATCH_IDX]
139
144
  for (int s = 0; s < n_seqs; ++s) {
140
145
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
141
146
 
@@ -153,6 +158,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
153
158
  }
154
159
  }
155
160
 
161
+ // TODO: fix indexing [UBATCH_IDX]
156
162
  for (int s = 0; s < n_seqs; ++s) {
157
163
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
158
164
 
@@ -177,6 +183,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
177
183
  uint32_t * data = (uint32_t *) cls->data;
178
184
  memset(cls->data, 0, n_tokens * ggml_element_size(cls));
179
185
 
186
+ // TODO: fix indexing [UBATCH_IDX]
180
187
  for (int s = 0; s < n_seqs; ++s) {
181
188
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
182
189
 
@@ -207,6 +214,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
207
214
  std::vector<int> last_pos(n_tokens, -1);
208
215
  std::vector<int> last_row(n_tokens, -1);
209
216
 
217
+ // TODO: fix indexing [UBATCH_IDX]
210
218
  for (int s = 0; s < n_seqs; ++s) {
211
219
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
212
220
 
@@ -231,34 +239,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
231
239
  }
232
240
  }
233
241
 
234
- void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
242
+ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
235
243
  GGML_UNUSED(ubatch);
236
244
 
237
- const int64_t n_kv = kv_self->n;
245
+ const int64_t n_rs = mem_state->get_n_rs();
238
246
 
239
247
  if (s_copy) {
240
248
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
241
249
  int32_t * data = (int32_t *) s_copy->data;
242
250
 
243
251
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
244
- for (uint32_t i = 0; i < n_kv; ++i) {
245
- data[i] = kv_self->s_copy(i);
246
- }
247
- }
248
- }
249
-
250
- void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
251
- GGML_UNUSED(ubatch);
252
-
253
- const int64_t n_kv = kv_self->n;
254
-
255
- if (s_mask) {
256
- GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
257
- float * data = (float *) s_mask->data;
258
-
259
- // clear unused states
260
- for (int i = 0; i < n_kv; ++i) {
261
- data[i] = kv_self->s_mask(i);
252
+ for (uint32_t i = 0; i < n_rs; ++i) {
253
+ data[i] = mem_state->s_copy(i);
262
254
  }
263
255
  }
264
256
  }
@@ -296,6 +288,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
296
288
  const int32_t ti = s0*n_seq_tokens + i;
297
289
  float f = -INFINITY;
298
290
 
291
+ // TODO: fix indexing [UBATCH_IDX]
299
292
  for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
300
293
  if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
301
294
  if (hparams.use_alibi) {
@@ -335,6 +328,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
335
328
  const int32_t ti = s0*n_seq_tokens + i;
336
329
  float f = -INFINITY;
337
330
 
331
+ // TODO: fix indexing [UBATCH_IDX]
338
332
  for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
339
333
  if (ubatch->seq_id[s0][s] == seq_id) {
340
334
  if (hparams.use_alibi) {
@@ -362,17 +356,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
362
356
 
363
357
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
364
358
  if (self_kq_mask) {
365
- kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
359
+ kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
366
360
  }
367
361
  }
368
362
 
369
363
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
370
364
  if (self_kq_mask) {
371
- kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
365
+ kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
372
366
  }
373
367
 
374
368
  if (self_kq_mask_swa) {
375
- kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
369
+ kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
376
370
  }
377
371
  }
378
372
 
@@ -390,6 +384,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
390
384
  for (int j = 0; j < n_tokens; ++j) {
391
385
  for (int i = 0; i < n_enc; ++i) {
392
386
  float f = -INFINITY;
387
+ // TODO: fix indexing [UBATCH_IDX]
393
388
  for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
394
389
  const llama_seq_id seq_id = ubatch->seq_id[j][s];
395
390
  if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
@@ -409,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
409
404
  }
410
405
  }
411
406
 
407
+ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
408
+ if (self_kq_mask) {
409
+ mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
410
+ }
411
+
412
+ const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
413
+
414
+ if (s_copy) {
415
+ GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
416
+ int32_t * data = (int32_t *) s_copy->data;
417
+
418
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
419
+ for (uint32_t i = 0; i < n_rs; ++i) {
420
+ data[i] = mem_state->get_state_recr()->s_copy(i);
421
+ }
422
+ }
423
+ }
424
+
412
425
  //
413
426
  // llm_graph_context
414
427
  //
@@ -448,7 +461,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
448
461
  backend_cpu (params.backend_cpu),
449
462
  cvec (params.cvec),
450
463
  loras (params.loras),
451
- memory (params.memory),
464
+ mstate (params.mstate),
452
465
  cross (params.cross),
453
466
  cb_func (params.cb),
454
467
  res (std::make_unique<llm_graph_result>()) {
@@ -647,6 +660,7 @@ ggml_tensor * llm_graph_context::build_ffn(
647
660
  {
648
661
  // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
649
662
  int64_t split_point = cur->ne[0] / 2;
663
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
650
664
  ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
651
665
  ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
652
666
 
@@ -656,6 +670,20 @@ ggml_tensor * llm_graph_context::build_ffn(
656
670
  cur = ggml_mul(ctx0, x0, x1);
657
671
  cb(cur, "ffn_mul", il);
658
672
  } break;
673
+ case LLM_FFN_GEGLU:
674
+ {
675
+ // Split into two equal parts
676
+ int64_t split_point = cur->ne[0] / 2;
677
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
678
+ ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
679
+ ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
680
+
681
+ x0 = ggml_gelu(ctx0, x0);
682
+ cb(x0, "ffn_gelu", il);
683
+
684
+ cur = ggml_mul(ctx0, x0, x1);
685
+ cb(cur, "ffn_geglu", il);
686
+ } break;
659
687
  }
660
688
 
661
689
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -766,9 +794,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
766
794
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
767
795
 
768
796
  if (weight_before_ffn) {
769
- // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
770
- ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
771
- repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
797
+ // repeat cur to [n_embd, n_expert_used, n_tokens]
798
+ ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
772
799
  cur = ggml_mul(ctx0, repeated, weights);
773
800
  cb(cur, "ffn_moe_weighted", il);
774
801
  }
@@ -953,40 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
953
980
  return cur;
954
981
  }
955
982
 
956
- ggml_tensor * llm_graph_context::build_inp_s_copy() const {
957
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
958
-
959
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
960
-
961
- const auto n_kv = kv_self->n;
962
-
963
- auto & cur = inp->s_copy;
964
-
965
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
966
- ggml_set_input(cur);
967
-
968
- res->add_input(std::move(inp));
969
-
970
- return cur;
971
- }
972
-
973
- ggml_tensor * llm_graph_context::build_inp_s_mask() const {
974
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
975
-
976
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
977
-
978
- const auto n_kv = kv_self->n;
979
-
980
- auto & cur = inp->s_mask;
981
-
982
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
983
- ggml_set_input(cur);
984
-
985
- res->add_input(std::move(inp));
986
-
987
- return cur;
988
- }
989
-
990
983
  ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
991
984
  auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
992
985
 
@@ -1025,11 +1018,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1025
1018
  }
1026
1019
 
1027
1020
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1028
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1021
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1029
1022
 
1030
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1023
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
1031
1024
 
1032
- const auto n_kv = kv_self->get_n();
1025
+ const auto n_kv = kv_state->get_n_kv();
1033
1026
 
1034
1027
  auto & cur = inp->pos_bucket;
1035
1028
 
@@ -1056,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
1056
1049
  return pos_bias;
1057
1050
  }
1058
1051
 
1052
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1053
+ const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
1054
+
1055
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
1056
+
1057
+ {
1058
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
1059
+
1060
+ const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
1061
+
1062
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1063
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1064
+ ggml_set_input(inp->self_kq_mask);
1065
+
1066
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1067
+ }
1068
+
1069
+ {
1070
+ const auto n_rs = mem_state->get_state_recr()->get_n_rs();
1071
+
1072
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1073
+ ggml_set_input(inp->s_copy);
1074
+ }
1075
+
1076
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1077
+ }
1078
+
1059
1079
  ggml_tensor * llm_graph_context::build_attn_mha(
1060
1080
  ggml_cgraph * gf,
1061
1081
  ggml_tensor * q,
@@ -1231,14 +1251,14 @@ ggml_tensor * llm_graph_context::build_attn(
1231
1251
  }
1232
1252
 
1233
1253
  llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1234
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1254
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1235
1255
 
1236
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1256
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
1237
1257
 
1238
1258
  {
1239
1259
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1240
1260
 
1241
- const auto n_kv = kv_self->get_n();
1261
+ const auto n_kv = kv_state->get_n_kv();
1242
1262
 
1243
1263
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1244
1264
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,19 +1288,19 @@ ggml_tensor * llm_graph_context::build_attn(
1268
1288
  ggml_build_forward_expand(gf, k_cur);
1269
1289
  ggml_build_forward_expand(gf, v_cur);
1270
1290
 
1271
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1291
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1272
1292
 
1273
1293
  // store to KV cache
1274
1294
  {
1275
- ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
1276
- ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
1295
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1296
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1277
1297
  }
1278
1298
 
1279
1299
  const auto & kq_mask = inp->get_kq_mask();
1280
1300
 
1281
1301
  ggml_tensor * q = q_cur;
1282
- ggml_tensor * k = kv_self->get_k(ctx0, il);
1283
- ggml_tensor * v = kv_self->get_v(ctx0, il);
1302
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1303
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1284
1304
 
1285
1305
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1286
1306
  cb(cur, "kqv_out", il);
@@ -1300,36 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
1300
1320
  return cur;
1301
1321
  }
1302
1322
 
1303
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1304
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1305
-
1306
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
1307
-
1308
- {
1309
- const auto n_kv = kv_self->get_kv_base()->get_n();
1310
-
1311
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1312
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1313
- ggml_set_input(inp->self_kq_mask);
1314
-
1315
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1316
- }
1317
-
1318
- {
1319
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1320
-
1321
- const auto n_kv = kv_self->get_kv_swa()->get_n();
1322
-
1323
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1324
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1325
- ggml_set_input(inp->self_kq_mask_swa);
1326
-
1327
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1328
- }
1329
-
1330
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1331
- }
1332
-
1333
1323
  ggml_tensor * llm_graph_context::build_attn(
1334
1324
  llm_graph_input_attn_kv_unified_iswa * inp,
1335
1325
  ggml_cgraph * gf,
@@ -1348,23 +1338,23 @@ ggml_tensor * llm_graph_context::build_attn(
1348
1338
  ggml_build_forward_expand(gf, k_cur);
1349
1339
  ggml_build_forward_expand(gf, v_cur);
1350
1340
 
1351
- const bool is_swa = hparams.is_swa(il);
1341
+ const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1352
1342
 
1353
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1343
+ const bool is_swa = hparams.is_swa(il);
1354
1344
 
1355
- const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
1345
+ const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
1356
1346
 
1357
1347
  // store to KV cache
1358
1348
  {
1359
- ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
1360
- ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
1349
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1350
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1361
1351
  }
1362
1352
 
1363
1353
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1364
1354
 
1365
1355
  ggml_tensor * q = q_cur;
1366
- ggml_tensor * k = kv->get_k(ctx0, il);
1367
- ggml_tensor * v = kv->get_v(ctx0, il);
1356
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1357
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1368
1358
 
1369
1359
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1370
1360
  cb(cur, "kqv_out", il);
@@ -1439,56 +1429,182 @@ ggml_tensor * llm_graph_context::build_attn(
1439
1429
  return cur;
1440
1430
  }
1441
1431
 
1442
- ggml_tensor * llm_graph_context::build_copy_mask_state(
1443
- ggml_cgraph * gf,
1444
- ggml_tensor * s,
1445
- ggml_tensor * state_copy,
1446
- ggml_tensor * state_mask,
1447
- int32_t n_state,
1448
- int32_t n_seqs) const {
1449
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1432
+ ggml_tensor * llm_graph_context::build_attn(
1433
+ llm_graph_input_mem_hybrid * inp,
1434
+ ggml_cgraph * gf,
1435
+ ggml_tensor * wo,
1436
+ ggml_tensor * wo_b,
1437
+ ggml_tensor * q_cur,
1438
+ ggml_tensor * k_cur,
1439
+ ggml_tensor * v_cur,
1440
+ ggml_tensor * kq_b,
1441
+ ggml_tensor * v_mla,
1442
+ float kq_scale,
1443
+ int il) const {
1444
+ // these nodes are added to the graph together so that they are not reordered
1445
+ // by doing so, the number of splits in the graph is reduced
1446
+ ggml_build_forward_expand(gf, q_cur);
1447
+ ggml_build_forward_expand(gf, k_cur);
1448
+ ggml_build_forward_expand(gf, v_cur);
1450
1449
 
1451
- const auto n_kv = kv_self->n;
1452
- const auto kv_head = kv_self->head;
1450
+ const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
1453
1451
 
1454
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
1452
+ // store to KV cache
1453
+ {
1454
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1455
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1456
+ }
1455
1457
 
1456
- // copy states
1457
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1458
- // this shrinks the tensors's ne[1] to n_kv
1459
- states = ggml_get_rows(ctx0, states, state_copy);
1458
+ const auto & kq_mask = inp->get_kq_mask();
1459
+
1460
+ ggml_tensor * q = q_cur;
1461
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1462
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1463
+
1464
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1465
+ cb(cur, "kqv_out", il);
1466
+
1467
+ if (wo) {
1468
+ cur = build_lora_mm(wo, cur);
1469
+ if (arch == LLM_ARCH_GLM4) {
1470
+ // GLM4 seems to have numerical issues with half-precision accumulators
1471
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1472
+ }
1473
+ }
1474
+
1475
+ if (wo_b) {
1476
+ cur = ggml_add(ctx0, cur, wo_b);
1477
+ }
1478
+
1479
+ return cur;
1480
+ }
1481
+
1482
+ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1483
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1460
1484
 
1461
- // clear states of sequences which are starting at the beginning of this batch
1462
- // FIXME: zero-out NANs?
1463
- states = ggml_mul(ctx0, states, state_mask);
1485
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1464
1486
 
1465
- // copy states which won't be changed further (between n_seqs and n_kv)
1487
+ {
1488
+ const auto n_kv = kv_state->get_base()->get_n_kv();
1489
+
1490
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1491
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1492
+ ggml_set_input(inp->self_kq_mask);
1493
+
1494
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1495
+ }
1496
+
1497
+ {
1498
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1499
+
1500
+ const auto n_kv = kv_state->get_swa()->get_n_kv();
1501
+
1502
+ inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1503
+ //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1504
+ ggml_set_input(inp->self_kq_mask_swa);
1505
+
1506
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1507
+ }
1508
+
1509
+ return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1510
+ }
1511
+
1512
+ ggml_tensor * llm_graph_context::build_rs(
1513
+ ggml_cgraph * gf,
1514
+ ggml_tensor * s,
1515
+ ggml_tensor * state_copy,
1516
+ int32_t state_size,
1517
+ int32_t n_seqs,
1518
+ uint32_t n_kv,
1519
+ uint32_t kv_head,
1520
+ uint32_t kv_size,
1521
+ int32_t rs_zero,
1522
+ bool avoid_copies) const {
1523
+
1524
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1525
+
1526
+ // Clear a single state which will then be copied to the other cleared states.
1527
+ // Note that this is a no-op when the view is zero-sized.
1528
+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1529
+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1530
+
1531
+ ggml_tensor * output_states;
1532
+
1533
+ if (!avoid_copies) {
1534
+ // copy states
1535
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1536
+ // {state_size, kv_size} -> {state_size, n_seqs}
1537
+ output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1538
+ ggml_build_forward_expand(gf, output_states);
1539
+ } else {
1540
+ // FIXME: make the gathering operation happen before the copy below
1541
+ // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1542
+ output_states = states;
1543
+ }
1544
+
1545
+ // copy extra states which won't be changed further (between n_seqs and n_kv)
1546
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
1466
1547
  ggml_build_forward_expand(gf,
1467
1548
  ggml_cpy(ctx0,
1468
- ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
1469
- ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
1549
+ states_extra,
1550
+ ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
1551
+
1552
+ return output_states;
1553
+ }
1470
1554
 
1471
- // the part of the states that will be used and modified
1472
- return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
1555
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1556
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1557
+
1558
+ auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
1559
+
1560
+ const auto n_rs = kv_state->get_n_rs();
1561
+
1562
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1563
+ ggml_set_input(inp->s_copy);
1564
+
1565
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
1566
+ }
1567
+
1568
+ ggml_tensor * llm_graph_context::build_rs(
1569
+ llm_graph_input_rs * inp,
1570
+ ggml_cgraph * gf,
1571
+ ggml_tensor * s,
1572
+ int32_t state_size,
1573
+ int32_t n_seqs,
1574
+ bool avoid_copies) const {
1575
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1576
+
1577
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1578
+ }
1579
+
1580
+ ggml_tensor * llm_graph_context::build_rs(
1581
+ llm_graph_input_mem_hybrid * inp,
1582
+ ggml_cgraph * gf,
1583
+ ggml_tensor * s,
1584
+ int32_t state_size,
1585
+ int32_t n_seqs,
1586
+ bool avoid_copies) const {
1587
+ const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
1588
+
1589
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1473
1590
  }
1474
1591
 
1475
1592
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1476
- ggml_cgraph * gf,
1477
- ggml_tensor * state_copy,
1478
- ggml_tensor * state_mask,
1479
- const llama_ubatch & ubatch,
1593
+ llm_graph_input_rs * inp,
1594
+ ggml_cgraph * gf,
1595
+ const llama_ubatch & ubatch,
1480
1596
  int il) const {
1481
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1597
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1482
1598
 
1483
1599
  const auto token_shift_count = hparams.token_shift_count;
1484
1600
 
1485
1601
  const int64_t n_seqs = ubatch.n_seqs;
1486
1602
 
1487
- ggml_tensor * token_shift_all = kv_self->k_l[il];
1603
+ ggml_tensor * token_shift_all = kv_state->get_r_l(il);
1488
1604
 
1489
- ggml_tensor * token_shift = build_copy_mask_state(
1490
- gf, token_shift_all, state_copy, state_mask,
1491
- hparams.n_embd_k_s(), n_seqs);
1605
+ ggml_tensor * token_shift = build_rs(
1606
+ inp, gf, token_shift_all,
1607
+ hparams.n_embd_r(), n_seqs);
1492
1608
 
1493
1609
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
1494
1610
 
@@ -1499,19 +1615,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1499
1615
  ggml_tensor * token_shift,
1500
1616
  const llama_ubatch & ubatch,
1501
1617
  int il) const {
1502
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1618
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1503
1619
 
1504
1620
  const auto token_shift_count = hparams.token_shift_count;
1505
1621
  const auto n_embd = hparams.n_embd;
1506
1622
 
1507
1623
  const int64_t n_seqs = ubatch.n_seqs;
1508
1624
 
1509
- const auto kv_head = kv_self->head;
1625
+ const auto kv_head = kv_state->get_head();
1510
1626
 
1511
1627
  return ggml_cpy(
1512
1628
  ctx0,
1513
1629
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1514
- ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
1630
+ ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
1515
1631
  );
1516
1632
  }
1517
1633
 
@@ -1562,23 +1678,30 @@ void llm_graph_context::build_pooling(
1562
1678
  ggml_tensor * inp_cls = build_inp_cls();
1563
1679
  inp = ggml_get_rows(ctx0, inp, inp_cls);
1564
1680
 
1565
- if (cls != nullptr && cls_b != nullptr) {
1681
+ if (cls) {
1566
1682
  // classification head
1567
1683
  // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1568
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1684
+ cur = ggml_mul_mat(ctx0, cls, inp);
1685
+ if (cls_b) {
1686
+ cur = ggml_add(ctx0, cur, cls_b);
1687
+ }
1569
1688
  cur = ggml_tanh(ctx0, cur);
1570
1689
 
1571
1690
  // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1572
1691
  // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1573
1692
  if (cls_out) {
1574
- GGML_ASSERT(cls_out_b != nullptr);
1575
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1693
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
1694
+ if (cls_out_b) {
1695
+ cur = ggml_add(ctx0, cur, cls_out_b);
1696
+ }
1576
1697
  }
1577
1698
  } else if (cls_out) {
1578
1699
  // Single layer classification head (direct projection)
1579
1700
  // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1580
- GGML_ASSERT(cls_out_b != nullptr);
1581
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
1701
+ cur = ggml_mul_mat(ctx0, cls_out, inp);
1702
+ if (cls_out_b) {
1703
+ cur = ggml_add(ctx0, cur, cls_out_b);
1704
+ }
1582
1705
  } else {
1583
1706
  GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
1584
1707
  }