@novastera-oss/llamarn 0.2.6 → 0.2.9

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (253):
  1. package/android/src/main/cpp/include/llama.h +141 -38
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +58 -24
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +37 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +53 -40
  26. package/cpp/llama.cpp/common/common.h +6 -2
  27. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  28. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  29. package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
  30. package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
  31. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  32. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  33. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  34. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
  35. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  38. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  88. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  90. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  91. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
  93. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
  94. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
  97. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  105. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
  112. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
  113. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  115. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  117. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
  139. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  140. package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
  141. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
  142. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
  143. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  144. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  145. package/cpp/llama.cpp/include/llama.h +141 -38
  146. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  147. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  148. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  149. package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
  150. package/cpp/llama.cpp/src/llama-arch.h +25 -1
  151. package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
  152. package/cpp/llama.cpp/src/llama-batch.h +110 -57
  153. package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
  154. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  155. package/cpp/llama.cpp/src/llama-context.cpp +360 -266
  156. package/cpp/llama.cpp/src/llama-context.h +27 -23
  157. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  158. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  159. package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
  160. package/cpp/llama.cpp/src/llama-graph.h +126 -58
  161. package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
  162. package/cpp/llama.cpp/src/llama-hparams.h +16 -2
  163. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
  164. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
  165. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
  166. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
  167. package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
  168. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  169. package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
  170. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
  171. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
  172. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  173. package/cpp/llama.cpp/src/llama-memory.h +73 -36
  174. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  175. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  176. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  177. package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
  178. package/cpp/llama.cpp/src/llama-model.h +26 -0
  179. package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
  180. package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
  181. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  182. package/cpp/llama.cpp/src/llama.cpp +11 -7
  183. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  184. package/cpp/rn-completion.cpp +2 -2
  185. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  186. package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
  187. package/ios/include/chat.h +1 -1
  188. package/ios/include/common.h +6 -2
  189. package/ios/include/llama.h +141 -38
  190. package/ios/libs/llama.xcframework/Info.plist +15 -15
  191. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  192. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  193. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  194. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  195. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
  196. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  197. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  198. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  199. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  200. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  201. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  202. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  203. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  204. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  205. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  206. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
  207. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  208. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  209. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
  210. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  211. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  219. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  220. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  221. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  222. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  223. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
  224. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  225. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  226. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  227. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  228. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  231. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  232. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  233. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
  234. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  235. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  236. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
  237. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  238. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  239. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
  240. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
  241. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  242. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  243. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  244. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  245. package/package.json +1 -2
  246. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  247. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  248. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  249. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  250. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  251. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  252. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  253. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -1,6 +1,7 @@
1
- #include "llama-kv-cache-recurrent.h"
1
+ #include "llama-memory-recurrent.h"
2
2
 
3
3
  #include "llama-impl.h"
4
+ #include "llama-io.h"
4
5
  #include "llama-batch.h"
5
6
  #include "llama-model.h"
6
7
 
@@ -11,27 +12,28 @@
11
12
  #include <stdexcept>
12
13
 
13
14
  //
14
- // llama_kv_cache_recurrent
15
+ // llama_memory_recurrent
15
16
  //
16
17
 
17
- llama_kv_cache_recurrent::llama_kv_cache_recurrent(
18
- const llama_model & model,
19
- ggml_type type_k,
20
- ggml_type type_v,
21
- bool offload,
22
- uint32_t kv_size,
23
- uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
18
+ llama_memory_recurrent::llama_memory_recurrent(
19
+ const llama_model & model,
20
+ layer_filter_cb && filter,
21
+ ggml_type type_r,
22
+ ggml_type type_s,
23
+ bool offload,
24
+ uint32_t mem_size,
25
+ uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
24
26
  const int32_t n_layer = hparams.n_layer;
25
27
 
26
- LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
27
- __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
28
+ LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
29
+ __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
28
30
 
29
31
  head = 0;
30
- size = kv_size;
32
+ size = mem_size;
31
33
  used = 0;
32
34
 
33
35
  cells.clear();
34
- cells.resize(kv_size);
36
+ cells.resize(mem_size);
35
37
 
36
38
  // create a context for each buffer type
37
39
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -58,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
58
60
  return it->second;
59
61
  };
60
62
 
61
- k_l.reserve(n_layer);
62
- v_l.reserve(n_layer);
63
+ r_l.resize(n_layer);
64
+ s_l.resize(n_layer);
63
65
 
64
66
  for (int i = 0; i < n_layer; i++) {
65
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
66
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
67
+ if (filter && !filter(i)) {
68
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
69
+ continue;
70
+ }
67
71
 
68
72
  const char * dev_name = "CPU";
69
73
 
@@ -83,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
83
87
  throw std::runtime_error("failed to create ggml context for kv cache");
84
88
  }
85
89
 
86
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
87
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
88
- ggml_format_name(k, "cache_k_l%d", i);
89
- ggml_format_name(v, "cache_v_l%d", i);
90
- k_l.push_back(k);
91
- v_l.push_back(v);
90
+ ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
91
+ ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
92
+ ggml_format_name(r, "cache_r_l%d", i);
93
+ ggml_format_name(s, "cache_s_l%d", i);
94
+ r_l[i] = r;
95
+ s_l[i] = s;
92
96
  }
93
97
 
94
98
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -106,32 +110,35 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
106
110
  }
107
111
 
108
112
  {
109
- const size_t memory_size_k = size_k_bytes();
110
- const size_t memory_size_v = size_v_bytes();
113
+ const size_t memory_size_r = size_r_bytes();
114
+ const size_t memory_size_s = size_s_bytes();
111
115
 
112
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
113
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
114
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
115
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
116
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
117
+ (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
118
+ ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
119
+ ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
116
120
  }
117
121
  }
118
122
 
119
- void llama_kv_cache_recurrent::clear() {
123
+ void llama_memory_recurrent::clear(bool data) {
120
124
  for (int32_t i = 0; i < (int32_t) size; ++i) {
121
125
  cells[i].pos = -1;
122
126
  cells[i].seq_id.clear();
123
127
  cells[i].src = -1;
124
128
  cells[i].tail = -1;
125
129
  }
130
+
126
131
  head = 0;
127
132
  used = 0;
128
133
 
129
- for (auto & buf : bufs) {
130
- ggml_backend_buffer_clear(buf.get(), 0);
134
+ if (data) {
135
+ for (auto & buf : bufs) {
136
+ ggml_backend_buffer_clear(buf.get(), 0);
137
+ }
131
138
  }
132
139
  }
133
140
 
134
- bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
141
+ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
135
142
  uint32_t new_head = size;
136
143
 
137
144
  if (p0 < 0) {
@@ -150,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
150
157
  if (0 <= seq_id) {
151
158
  int32_t & tail_id = cells[seq_id].tail;
152
159
  if (tail_id >= 0) {
153
- const kv_cell & cell = cells[tail_id];
160
+ const auto & cell = cells[tail_id];
154
161
  // partial intersection is invalid
155
162
  if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
156
163
  return false;
@@ -198,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
198
205
  return true;
199
206
  }
200
207
 
201
- void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
208
+ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
202
209
  if (seq_id_src == seq_id_dst) {
203
210
  return;
204
211
  }
@@ -212,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
212
219
  }
213
220
 
214
221
  if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
215
- kv_cell & tail_src = cells[seq_id_src];
216
- kv_cell & tail_dst = cells[seq_id_dst];
222
+ auto & tail_src = cells[seq_id_src];
223
+ auto & tail_dst = cells[seq_id_dst];
217
224
  if (tail_dst.tail >= 0) {
218
225
  // clear destination seq_id if it wasn't empty
219
- kv_cell & cell_dst = cells[tail_dst.tail];
226
+ auto & cell_dst = cells[tail_dst.tail];
220
227
 
221
228
  cell_dst.seq_id.erase(seq_id_dst);
222
229
  tail_dst.tail = -1;
@@ -227,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
227
234
  }
228
235
  }
229
236
  if (tail_src.tail >= 0) {
230
- kv_cell & cell_src = cells[tail_src.tail];
237
+ auto & cell_src = cells[tail_src.tail];
231
238
 
232
239
  cell_src.seq_id.insert(seq_id_dst);
233
240
  tail_dst.tail = tail_src.tail;
@@ -235,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
235
242
  }
236
243
  }
237
244
 
238
- void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
245
+ void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
239
246
  uint32_t new_head = size;
240
247
 
241
248
  for (uint32_t i = 0; i < size; ++i) {
@@ -267,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
267
274
  }
268
275
  }
269
276
 
270
- void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
277
+ void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
271
278
  if (shift == 0) {
272
279
  return;
273
280
  }
@@ -289,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
289
296
  if (0 <= seq_id && seq_id < (int64_t) size) {
290
297
  const int32_t tail_id = cells[seq_id].tail;
291
298
  if (tail_id >= 0) {
292
- kv_cell & cell = cells[tail_id];
299
+ auto & cell = cells[tail_id];
293
300
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
294
301
  cell.pos += shift;
295
302
  }
@@ -297,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
297
304
  }
298
305
  }
299
306
 
300
- void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
307
+ void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
301
308
  if (d == 1) {
302
309
  return;
303
310
  }
@@ -319,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
319
326
  if (0 <= seq_id && seq_id < (int64_t) size) {
320
327
  const int32_t tail_id = cells[seq_id].tail;
321
328
  if (tail_id >= 0) {
322
- kv_cell & cell = cells[tail_id];
329
+ auto & cell = cells[tail_id];
323
330
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
324
331
  cell.pos /= d;
325
332
  }
@@ -327,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
327
334
  }
328
335
  }
329
336
 
330
- llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
337
+ llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
331
338
  llama_pos result = std::numeric_limits<llama_pos>::max();
332
339
 
333
340
  for (uint32_t i = 0; i < size; ++i) {
@@ -343,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
343
350
  return result;
344
351
  }
345
352
 
346
- llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
353
+ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
347
354
  llama_pos result = -1;
348
355
 
349
356
  for (uint32_t i = 0; i < size; ++i) {
@@ -355,38 +362,50 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
355
362
  return result;
356
363
  }
357
364
 
358
- llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
359
- GGML_UNUSED(embd_pooled);
365
+ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
366
+ do {
367
+ balloc.split_reset();
360
368
 
361
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
369
+ std::vector<llama_ubatch> ubatches;
370
+ while (true) {
371
+ llama_ubatch ubatch;
362
372
 
363
- std::vector<llama_ubatch> ubatches;
373
+ if (embd_all) {
374
+ // if all tokens are output, split by sequence
375
+ ubatch = balloc.split_seq(n_ubatch);
376
+ } else {
377
+ ubatch = balloc.split_equal(n_ubatch);
378
+ }
364
379
 
365
- while (sbatch.n_tokens > 0) {
366
- llama_ubatch ubatch;
380
+ if (ubatch.n_tokens == 0) {
381
+ break;
382
+ }
367
383
 
368
- if (embd_pooled) {
369
- // Pooled embeddings cannot be split across ubatches (yet)
370
- ubatch = sbatch.split_seq(n_ubatch);
371
- } else {
372
- ubatch = sbatch.split_equal(n_ubatch);
384
+ ubatches.push_back(std::move(ubatch)); // NOLINT
373
385
  }
374
386
 
375
- ubatches.push_back(ubatch);
376
- }
387
+ if (!prepare(ubatches)) {
388
+ break;
389
+ }
377
390
 
378
- if (!prepare(ubatches)) {
379
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
380
- }
391
+ return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
392
+ } while (false);
393
+
394
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
395
+ }
381
396
 
382
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
397
+ llama_memory_context_ptr llama_memory_recurrent::init_full() {
398
+ return std::make_unique<llama_memory_recurrent_context>(this);
383
399
  }
384
400
 
385
- llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
386
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
401
+ llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
402
+ GGML_UNUSED(lctx);
403
+ GGML_UNUSED(optimize);
404
+
405
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
387
406
  }
388
407
 
389
- bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
408
+ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
390
409
  // simply remember the full state because it is very small for this type of cache
391
410
  // TODO: optimize
392
411
  auto org_cells = cells;
@@ -395,21 +414,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
395
414
 
396
415
  bool success = true;
397
416
 
398
- // TODO: here we have to verify that all ubatches can fit in the cells
399
- // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
400
- // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
401
- //
402
- // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
403
- //
404
- // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
405
- //
406
- GGML_UNUSED(ubatches);
407
- //for (const auto & ubatch : ubatches) {
408
- // if (!find_slot(ubatch)) {
409
- // success = false;
410
- // break;
411
- // }
412
- //}
417
+ for (const auto & ubatch : ubatches) {
418
+ if (!find_slot(ubatch)) {
419
+ success = false;
420
+ break;
421
+ }
422
+ }
413
423
 
414
424
  // restore the original state
415
425
  cells = std::move(org_cells);
@@ -419,26 +429,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
419
429
  return success;
420
430
  }
421
431
 
422
- bool llama_kv_cache_recurrent::update(llama_context & lctx) {
423
- GGML_UNUSED(lctx);
424
- // noop
425
- return false;
426
- }
427
-
428
- void llama_kv_cache_recurrent::defrag_sched(float thold) {
429
- GGML_UNUSED(thold);
430
- // noop
431
- }
432
-
433
- bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
434
- const uint32_t n_tokens = ubatch.n_tokens;
435
- const uint32_t n_seqs = ubatch.n_seqs;
436
-
432
+ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
437
433
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
434
+ const uint32_t n_seqs = ubatch.n_seqs;
438
435
 
439
436
  // if we have enough unused cells before the current head ->
440
437
  // better to start searching from the beginning of the cache, hoping to fill it
441
- if (head > used + 2*n_tokens) {
438
+ if (head > used + 2*n_seqs) {
442
439
  head = 0;
443
440
  }
444
441
 
@@ -454,9 +451,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
454
451
 
455
452
  // everything should fit if all seq_ids are smaller than the max
456
453
  for (uint32_t s = 0; s < n_seqs; ++s) {
457
- const uint32_t n_seq_id = ubatch.n_seq_id[s];
454
+ const uint32_t i = s*n_seq_tokens; // first token of sequence set s
455
+ const uint32_t n_seq_id = ubatch.n_seq_id[i];
456
+
458
457
  for (uint32_t j = 0; j < n_seq_id; ++j) {
459
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
458
+ const llama_seq_id seq_id = ubatch.seq_id[i][j];
460
459
 
461
460
  if (seq_id < 0 || (uint32_t) seq_id >= size) {
462
461
  // too big seq_id
@@ -465,9 +464,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
465
464
  return false;
466
465
  }
467
466
  if (j > 0) {
468
- kv_cell & seq = cells[seq_id];
467
+ auto & seq = cells[seq_id];
469
468
  if (seq.tail >= 0) {
470
- kv_cell & cell = cells[seq.tail];
469
+ auto & cell = cells[seq.tail];
471
470
  // clear cells from seq_ids that become shared
472
471
  // (should not normally happen, but let's handle it anyway)
473
472
  cell.seq_id.erase(seq_id);
@@ -487,7 +486,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
487
486
  std::vector<int32_t> tails_verif;
488
487
  tails_verif.assign(size, -1);
489
488
  for (uint32_t i = 0; i < size; ++i) {
490
- kv_cell & cell = cells[i];
489
+ auto & cell = cells[i];
491
490
  for (llama_seq_id seq_id : cell.seq_id) {
492
491
  if (tails_verif[seq_id] != -1) {
493
492
  LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -508,42 +507,43 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
508
507
 
509
508
  for (uint32_t i = 0; i < size; ++i) {
510
509
  if (next_empty_cell >= size) { next_empty_cell -= size; }
511
- kv_cell & cell = cells[next_empty_cell];
510
+ auto & cell = cells[next_empty_cell];
512
511
  if (cell.is_empty()) { break; }
513
512
  next_empty_cell += 1;
514
513
  }
515
514
 
516
515
  // find usable cell range
517
516
  for (uint32_t s = 0; s < n_seqs; ++s) {
518
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
519
- kv_cell & seq_meta = cells[seq_id];
517
+ const uint32_t i = s*n_seq_tokens;
518
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
519
+ auto & seq_meta = cells[seq_id];
520
520
  bool has_cell = false;
521
521
  if (seq_meta.tail >= 0) {
522
- kv_cell & cell = cells[seq_meta.tail];
522
+ auto & cell = cells[seq_meta.tail];
523
523
  GGML_ASSERT(cell.has_seq_id(seq_id));
524
524
  // does this seq_id "own" the cell?
525
525
  if (cell.seq_id.size() == 1) { has_cell = true; }
526
526
  }
527
527
  if (!has_cell) {
528
- kv_cell & empty_cell = cells[next_empty_cell];
528
+ auto & empty_cell = cells[next_empty_cell];
529
529
  GGML_ASSERT(empty_cell.is_empty());
530
530
  // copy old tail into the empty cell
531
531
  if (seq_meta.tail >= 0) {
532
- kv_cell & orig_cell = cells[seq_meta.tail];
532
+ auto & orig_cell = cells[seq_meta.tail];
533
533
  empty_cell.pos = orig_cell.pos;
534
534
  empty_cell.src = orig_cell.src;
535
535
  orig_cell.seq_id.erase(seq_id);
536
536
  empty_cell.seq_id.insert(seq_id); // will be overwritten
537
+ GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
537
538
  }
538
539
  seq_meta.tail = next_empty_cell;
539
540
  // find next empty cell
540
541
  if (s + 1 < n_seqs) {
541
- next_empty_cell += 1;
542
- for (uint32_t i = 0; i < size; ++i) {
542
+ for (uint32_t j = 0; j < size; ++j) {
543
+ next_empty_cell += 1;
543
544
  if (next_empty_cell >= size) { next_empty_cell -= size; }
544
- kv_cell & cell = cells[next_empty_cell];
545
+ auto & cell = cells[next_empty_cell];
545
546
  if (cell.is_empty()) { break; }
546
- next_empty_cell += 1;
547
547
  }
548
548
  }
549
549
  }
@@ -553,102 +553,99 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
553
553
 
554
554
  // gather and re-order
555
555
  for (uint32_t s = 0; s < n_seqs; ++s) {
556
- int32_t dst_id = s + min;
557
- int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
556
+ const uint32_t i = s*n_seq_tokens;
557
+ const int32_t dst_id = s + min;
558
+ const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
558
559
  if (dst_id != src_id) {
559
- kv_cell & dst_cell = cells[dst_id];
560
- kv_cell & src_cell = cells[src_id];
560
+ auto & dst_cell = cells[dst_id];
561
+ auto & src_cell = cells[src_id];
561
562
 
562
563
  std::swap(dst_cell.pos, src_cell.pos);
563
564
  std::swap(dst_cell.src, src_cell.src);
564
565
  std::swap(dst_cell.seq_id, src_cell.seq_id);
565
566
 
566
- // swap tails (assuming they NEVER overlap)
567
- for (const llama_seq_id seq_id : src_cell.seq_id) {
568
- cells[seq_id].tail = src_id;
569
- }
570
- for (const llama_seq_id seq_id : dst_cell.seq_id) {
571
- cells[seq_id].tail = dst_id;
567
+ // swap tails
568
+ for (uint32_t j = 0; j < size; ++j) {
569
+ int32_t & tail = cells[j].tail;
570
+ if (tail == src_id) {
571
+ tail = dst_id;
572
+ } else if (tail == dst_id) {
573
+ tail = src_id;
574
+ }
572
575
  }
573
576
  }
574
577
  }
575
578
 
576
579
  // update the pos of the used seqs
577
580
  for (uint32_t s = 0; s < n_seqs; ++s) {
578
- const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
579
- int32_t cell_id = s + min;
580
- kv_cell & cell = cells[cell_id];
581
+ const uint32_t i = s*n_seq_tokens;
582
+ const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
583
+ const int32_t cell_id = s + min;
584
+ auto & cell = cells[cell_id];
581
585
 
582
586
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
583
587
  // What should happen when the pos backtracks or skips a value?
584
588
  // Clearing the state mid-batch would require special-casing which isn't done.
585
589
  LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
586
- __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
590
+ __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
587
591
  }
588
592
  cell.pos = last_pos;
589
593
  cell.seq_id.clear();
590
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
591
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
594
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
595
+ const llama_seq_id seq_id = ubatch.seq_id[i][j];
592
596
  cell.seq_id.insert(seq_id);
593
597
  cells[seq_id].tail = cell_id;
594
598
  }
595
599
  }
596
600
 
601
+ // Find first cell without src refs, to use as the zero-ed state
602
+ {
603
+ // TODO: bake-in src refcounts in the cell metadata
604
+ std::vector<int32_t> refcounts(size, 0);
605
+ for (size_t i = 0; i < size; ++i) {
606
+ const int32_t src = cells[i].src;
607
+ if (src >= 0) {
608
+ refcounts[src] += 1;
609
+ }
610
+ }
611
+
612
+ rs_z = -1;
613
+ for (int i = min; i <= max; ++i) {
614
+ if (refcounts[i] == 0) {
615
+ rs_z = i;
616
+ break;
617
+ }
618
+ }
619
+
620
+ for (int i = min; i <= max; ++i) {
621
+ if (cells[i].src < 0) {
622
+ GGML_ASSERT(rs_z >= 0);
623
+ cells[i].src0 = rs_z;
624
+ } else {
625
+ // Stage the source ids for all used cells to allow correct seq_* behavior
626
+ // and still make these values available when setting the inputs
627
+ cells[i].src0 = cells[i].src;
628
+ }
629
+ cells[i].src = i; // avoid moving or clearing twice
630
+ }
631
+ }
632
+
597
633
  // allow getting the range of used cells, from head to head + n
598
634
  head = min;
599
635
  n = max - min + 1;
600
636
  used = std::count_if(cells.begin(), cells.end(),
601
- [](const kv_cell & cell){ return !cell.is_empty(); });
637
+ [](const mem_cell & cell){ return !cell.is_empty(); });
602
638
 
603
639
  // sanity check
604
640
  return n >= n_seqs;
605
641
  }
606
642
 
607
- bool llama_kv_cache_recurrent::get_can_shift() const {
608
- return false;
609
- }
610
-
611
- int32_t llama_kv_cache_recurrent::s_copy(int i) const {
612
- const uint32_t cell_id = i + head;
613
-
614
- //////////////////////////////////////////////
615
- // TODO: this should not mutate the KV cache !
616
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
617
-
618
- // prevent out-of-bound sources
619
- if (cell.src < 0 || (uint32_t) cell.src >= size) {
620
- cell.src = cell_id;
621
- }
622
-
623
- int32_t res = cell.src;
624
-
625
- // TODO: do not mutate the KV cache
626
- // ensure copy only happens once
627
- if (cell.src != (int32_t) cell_id) {
628
- cell.src = cell_id;
629
- }
630
-
631
- return res;
632
- }
633
-
634
- float llama_kv_cache_recurrent::s_mask(int i) const {
635
- const uint32_t cell_id = i + head;
636
-
637
- //////////////////////////////////////////////
638
- // TODO: this should not mutate the KV cache !
639
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
640
-
641
- float res = (float) (cell.src >= 0);
642
-
643
- // only clear once
644
- if (cell.src < 0) {
645
- cell.src = cell_id;
646
- }
647
-
648
- return res;
643
+ bool llama_memory_recurrent::get_can_shift() const {
644
+ // shifting the pos is trivial for recurrent models
645
+ return true;
649
646
  }
650
647
 
651
- size_t llama_kv_cache_recurrent::total_size() const {
648
+ size_t llama_memory_recurrent::total_size() const {
652
649
  size_t size = 0;
653
650
  for (const auto & buf : bufs) {
654
651
  size += ggml_backend_buffer_get_size(buf.get());
@@ -657,27 +654,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
657
654
  return size;
658
655
  }
659
656
 
660
- size_t llama_kv_cache_recurrent::size_k_bytes() const {
661
- size_t size_k_bytes = 0;
657
+ size_t llama_memory_recurrent::size_r_bytes() const {
658
+ size_t size_r_bytes = 0;
662
659
 
663
- for (const auto & k : k_l) {
664
- size_k_bytes += ggml_nbytes(k);
660
+ for (const auto & r : r_l) {
661
+ if (r != nullptr) {
662
+ size_r_bytes += ggml_nbytes(r);
663
+ }
665
664
  }
666
665
 
667
- return size_k_bytes;
666
+ return size_r_bytes;
668
667
  }
669
668
 
670
- size_t llama_kv_cache_recurrent::size_v_bytes() const {
671
- size_t size_v_bytes = 0;
669
+ size_t llama_memory_recurrent::size_s_bytes() const {
670
+ size_t size_s_bytes = 0;
672
671
 
673
- for (const auto & v : v_l) {
674
- size_v_bytes += ggml_nbytes(v);
672
+ for (const auto & s : s_l) {
673
+ if (s != nullptr) {
674
+ size_s_bytes += ggml_nbytes(s);
675
+ }
675
676
  }
676
677
 
677
- return size_v_bytes;
678
+ return size_s_bytes;
678
679
  }
679
680
 
680
- void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
681
+ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
681
682
  std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
682
683
  uint32_t cell_count = 0;
683
684
 
@@ -715,7 +716,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
715
716
  state_write_data(io, cell_ranges);
716
717
  }
717
718
 
718
- void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
719
+ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
719
720
  uint32_t cell_count;
720
721
  io.read_to(&cell_count, sizeof(cell_count));
721
722
 
@@ -726,7 +727,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
726
727
 
727
728
  if (!res) {
728
729
  if (seq_id == -1) {
729
- clear();
730
+ clear(true);
730
731
  } else {
731
732
  seq_rm(seq_id, -1, -1);
732
733
  }
@@ -734,7 +735,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
734
735
  }
735
736
  }
736
737
 
737
- void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
738
+ void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
738
739
  for (const auto & range : cell_ranges) {
739
740
  for (uint32_t i = range.first; i < range.second; ++i) {
740
741
  const auto & cell = cells[i];
@@ -753,98 +754,93 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
753
754
  }
754
755
  }
755
756
 
756
- void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
757
- const uint32_t v_trans = 0;
757
+ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
758
+ const uint32_t s_trans = 0;
758
759
  const uint32_t n_layer = hparams.n_layer;
759
760
 
760
- io.write(&v_trans, sizeof(v_trans));
761
- io.write(&n_layer, sizeof(n_layer));
761
+ io.write(&s_trans, sizeof(s_trans));
762
+ io.write(&n_layer, sizeof(n_layer));
762
763
 
763
764
  std::vector<uint8_t> tmp_buf;
764
765
 
765
766
  // Iterate and write all the keys first, each row is a cell
766
767
  // Get whole range at a time
767
768
  for (uint32_t il = 0; il < n_layer; ++il) {
768
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
769
769
 
770
770
  // Write key type
771
- const int32_t k_type_i = (int32_t)k_l[il]->type;
772
- io.write(&k_type_i, sizeof(k_type_i));
771
+ const int32_t r_type_i = (int32_t)r_l[il]->type;
772
+ io.write(&r_type_i, sizeof(r_type_i));
773
773
 
774
774
  // Write row size of key
775
- const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
776
- io.write(&k_size_row, sizeof(k_size_row));
775
+ const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
776
+ io.write(&r_size_row, sizeof(r_size_row));
777
777
 
778
778
  // Read each range of cells of k_size length each into tmp_buf and write out
779
779
  for (const auto & range : cell_ranges) {
780
780
  const size_t range_size = range.second - range.first;
781
- const size_t buf_size = range_size * k_size_row;
782
- io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
781
+ const size_t buf_size = range_size * r_size_row;
782
+ io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
783
783
  }
784
784
  }
785
785
 
786
- if (!v_trans) {
786
+ if (!s_trans) {
787
787
  for (uint32_t il = 0; il < n_layer; ++il) {
788
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
789
788
 
790
789
  // Write value type
791
- const int32_t v_type_i = (int32_t)v_l[il]->type;
792
- io.write(&v_type_i, sizeof(v_type_i));
790
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
791
+ io.write(&s_type_i, sizeof(s_type_i));
793
792
 
794
793
  // Write row size of value
795
- const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
796
- io.write(&v_size_row, sizeof(v_size_row));
794
+ const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
795
+ io.write(&s_size_row, sizeof(s_size_row));
797
796
 
798
- // Read each range of cells of v_size length each into tmp_buf and write out
797
+ // Read each range of cells of s_size length each into tmp_buf and write out
799
798
  for (const auto & range : cell_ranges) {
800
799
  const size_t range_size = range.second - range.first;
801
- const size_t buf_size = range_size * v_size_row;
802
- io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
800
+ const size_t buf_size = range_size * s_size_row;
801
+ io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
803
802
  }
804
803
  }
805
804
  } else {
806
805
  // When v is transposed, we also need the element size and get the element ranges from each row
807
- const uint32_t kv_size = size;
806
+ const uint32_t mem_size = size;
808
807
  for (uint32_t il = 0; il < n_layer; ++il) {
809
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
808
+ const uint32_t n_embd_s = hparams.n_embd_s();
810
809
 
811
810
  // Write value type
812
- const int32_t v_type_i = (int32_t)v_l[il]->type;
813
- io.write(&v_type_i, sizeof(v_type_i));
811
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
812
+ io.write(&s_type_i, sizeof(s_type_i));
814
813
 
815
814
  // Write element size
816
- const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
817
- io.write(&v_size_el, sizeof(v_size_el));
815
+ const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
816
+ io.write(&s_size_el, sizeof(s_size_el));
818
817
 
819
818
  // Write GQA embedding size
820
- io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
819
+ io.write(&n_embd_s, sizeof(n_embd_s));
821
820
 
822
821
  // For each row, we get the element values of each cell
823
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
822
+ for (uint32_t j = 0; j < n_embd_s; ++j) {
824
823
  // Read each range of cells of v_size_el length each into tmp_buf and write out
825
824
  for (const auto & range : cell_ranges) {
826
825
  const size_t range_size = range.second - range.first;
827
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
828
- const size_t buf_size = range_size * v_size_el;
829
- io.write_tensor(v_l[il], src_offset, buf_size);
826
+ const size_t src_offset = (range.first + j * mem_size) * s_size_el;
827
+ const size_t buf_size = range_size * s_size_el;
828
+ io.write_tensor(s_l[il], src_offset, buf_size);
830
829
  }
831
830
  }
832
831
  }
833
832
  }
834
833
  }
835
834
 
836
- bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
835
+ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
837
836
  if (dest_seq_id != -1) {
838
837
  // single sequence
839
838
 
840
839
  seq_rm(dest_seq_id, -1, -1);
841
840
 
842
- llama_sbatch sbatch;
843
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
841
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
844
842
 
845
- batch.n_tokens = cell_count;
846
- batch.n_seq_tokens = cell_count;
847
- batch.n_seqs = 1;
843
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
848
844
 
849
845
  for (uint32_t i = 0; i < cell_count; ++i) {
850
846
  llama_pos pos;
@@ -858,12 +854,12 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
858
854
  return false;
859
855
  }
860
856
 
861
- batch.pos[i] = pos;
857
+ ubatch.pos[i] = pos;
862
858
  }
863
- batch.n_seq_id[0] = 1;
864
- batch.seq_id[0] = &dest_seq_id;
859
+ ubatch.n_seq_id[0] = 1;
860
+ ubatch.seq_id[0] = &dest_seq_id;
865
861
 
866
- if (!find_slot(batch)) {
862
+ if (!find_slot(ubatch)) {
867
863
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
868
864
  return false;
869
865
  }
@@ -871,8 +867,8 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
871
867
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
872
868
  // Assume that this is one contiguous block of cells
873
869
  GGML_ASSERT(head + cell_count <= size);
874
- GGML_ASSERT(cells[head].pos == batch.pos[0]);
875
- GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
870
+ GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
871
+ GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
876
872
  GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
877
873
  GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
878
874
  } else {
@@ -883,10 +879,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
883
879
  return false;
884
880
  }
885
881
 
886
- clear();
882
+ clear(true);
887
883
 
888
884
  for (uint32_t i = 0; i < cell_count; ++i) {
889
- kv_cell & cell = cells[i];
885
+ auto & cell = cells[i];
890
886
 
891
887
  llama_pos pos;
892
888
  uint32_t n_seq_id;
@@ -900,7 +896,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
900
896
  llama_seq_id seq_id;
901
897
  io.read_to(&seq_id, sizeof(seq_id));
902
898
 
903
- // TODO: llama_kv_cache_recurrent should have a notion of max sequences
899
+ // TODO: llama_memory_recurrent should have a notion of max sequences
904
900
  //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
905
901
  if (seq_id < 0) {
906
902
  //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -932,10 +928,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
932
928
  return true;
933
929
  }
934
930
 
935
- bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
936
- uint32_t v_trans;
931
+ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
932
+ uint32_t s_trans;
937
933
  uint32_t n_layer;
938
- io.read_to(&v_trans, sizeof(v_trans));
934
+ io.read_to(&s_trans, sizeof(s_trans));
939
935
  io.read_to(&n_layer, sizeof(n_layer));
940
936
 
941
937
  if (n_layer != hparams.n_layer) {
@@ -946,102 +942,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
946
942
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
947
943
  return false;
948
944
  }
949
- if (false != (bool) v_trans) {
950
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
945
+ if (false != (bool) s_trans) {
946
+ LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
951
947
  return false;
952
948
  }
953
949
 
954
950
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
955
951
  for (uint32_t il = 0; il < n_layer; ++il) {
956
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
957
952
 
958
953
  // Read type of key
959
- int32_t k_type_i_ref;
960
- io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
961
- const int32_t k_type_i = (int32_t) k_l[il]->type;
962
- if (k_type_i != k_type_i_ref) {
963
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
954
+ int32_t r_type_i_ref;
955
+ io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
956
+ const int32_t r_type_i = (int32_t) r_l[il]->type;
957
+ if (r_type_i != r_type_i_ref) {
958
+ LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
964
959
  return false;
965
960
  }
966
961
 
967
962
  // Read row size of key
968
- uint64_t k_size_row_ref;
969
- io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
970
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
971
- if (k_size_row != k_size_row_ref) {
972
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
963
+ uint64_t r_size_row_ref;
964
+ io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
965
+ const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
966
+ if (r_size_row != r_size_row_ref) {
967
+ LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
973
968
  return false;
974
969
  }
975
970
 
976
971
  if (cell_count) {
977
972
  // Read and set the keys for the whole cell range
978
- ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
973
+ ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
979
974
  }
980
975
  }
981
976
 
982
- if (!v_trans) {
977
+ if (!s_trans) {
983
978
  for (uint32_t il = 0; il < n_layer; ++il) {
984
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
985
979
 
986
980
  // Read type of value
987
- int32_t v_type_i_ref;
988
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
989
- const int32_t v_type_i = (int32_t)v_l[il]->type;
990
- if (v_type_i != v_type_i_ref) {
991
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
981
+ int32_t s_type_i_ref;
982
+ io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
983
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
984
+ if (s_type_i != s_type_i_ref) {
985
+ LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
992
986
  return false;
993
987
  }
994
988
 
995
989
  // Read row size of value
996
- uint64_t v_size_row_ref;
997
- io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
998
- const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
999
- if (v_size_row != v_size_row_ref) {
1000
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
990
+ uint64_t s_size_row_ref;
991
+ io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
992
+ const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
993
+ if (s_size_row != s_size_row_ref) {
994
+ LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
1001
995
  return false;
1002
996
  }
1003
997
 
1004
998
  if (cell_count) {
1005
999
  // Read and set the values for the whole cell range
1006
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1000
+ ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
1007
1001
  }
1008
1002
  }
1009
1003
  } else {
1010
1004
  // For each layer, read the values for each cell (transposed)
1011
1005
  for (uint32_t il = 0; il < n_layer; ++il) {
1012
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1006
+ const uint32_t n_embd_s = hparams.n_embd_s();
1013
1007
 
1014
1008
  // Read type of value
1015
- int32_t v_type_i_ref;
1016
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1017
- const int32_t v_type_i = (int32_t)v_l[il]->type;
1018
- if (v_type_i != v_type_i_ref) {
1019
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1009
+ int32_t s_type_i_ref;
1010
+ io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
1011
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
1012
+ if (s_type_i != s_type_i_ref) {
1013
+ LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
1020
1014
  return false;
1021
1015
  }
1022
1016
 
1023
1017
  // Read element size of value
1024
- uint32_t v_size_el_ref;
1025
- io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1026
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
1027
- if (v_size_el != v_size_el_ref) {
1028
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1018
+ uint32_t s_size_el_ref;
1019
+ io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
1020
+ const size_t s_size_el = ggml_type_size(s_l[il]->type);
1021
+ if (s_size_el != s_size_el_ref) {
1022
+ LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
1029
1023
  return false;
1030
1024
  }
1031
1025
 
1032
- // Read GQA embedding size
1033
- uint32_t n_embd_v_gqa_ref;
1034
- io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1035
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1036
- LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1026
+ // Read state embedding size
1027
+ uint32_t n_embd_s_ref;
1028
+ io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
1029
+ if (n_embd_s != n_embd_s_ref) {
1030
+ LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
1037
1031
  return false;
1038
1032
  }
1039
1033
 
1040
1034
  if (cell_count) {
1041
1035
  // For each row in the transposed matrix, read the values for the whole cell range
1042
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1043
- const size_t dst_offset = (head + j * size) * v_size_el;
1044
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1036
+ for (uint32_t j = 0; j < n_embd_s; ++j) {
1037
+ const size_t dst_offset = (head + j * size) * s_size_el;
1038
+ ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
1045
1039
  }
1046
1040
  }
1047
1041
  }
@@ -1051,25 +1045,22 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
1051
1045
  }
1052
1046
 
1053
1047
  //
1054
- // llama_kv_cache_recurrent_state
1048
+ // llama_memory_recurrent_context
1055
1049
  //
1056
1050
 
1057
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
1051
+ llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
1058
1052
 
1059
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
1060
- llama_memory_status status,
1061
- llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
1053
+ llama_memory_recurrent_context::llama_memory_recurrent_context(
1054
+ llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
1062
1055
  }
1063
1056
 
1064
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
1065
- llama_memory_status status,
1066
- llama_kv_cache_recurrent * kv,
1067
- llama_sbatch sbatch,
1068
- std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
1057
+ llama_memory_recurrent_context::llama_memory_recurrent_context(
1058
+ llama_memory_recurrent * mem,
1059
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
1069
1060
 
1070
- llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
1061
+ llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
1071
1062
 
1072
- bool llama_kv_cache_recurrent_state::next() {
1063
+ bool llama_memory_recurrent_context::next() {
1073
1064
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1074
1065
 
1075
1066
  if (++i_next >= ubatches.size()) {
@@ -1079,54 +1070,48 @@ bool llama_kv_cache_recurrent_state::next() {
1079
1070
  return true;
1080
1071
  }
1081
1072
 
1082
- bool llama_kv_cache_recurrent_state::apply() {
1073
+ bool llama_memory_recurrent_context::apply() {
1083
1074
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1084
1075
 
1085
- kv->find_slot(ubatches[i_next]);
1076
+ mem->find_slot(ubatches[i_next]);
1086
1077
 
1087
1078
  return true;
1088
1079
  }
1089
1080
 
1090
- std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
1091
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1092
-
1093
- return sbatch.out_ids;
1094
- }
1095
-
1096
- llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
1081
+ llama_memory_status llama_memory_recurrent_context::get_status() const {
1097
1082
  return status;
1098
1083
  }
1099
1084
 
1100
- const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
1085
+ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
1101
1086
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1102
1087
 
1103
1088
  return ubatches[i_next];
1104
1089
  }
1105
1090
 
1106
- uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
1107
- return is_full ? kv->size : kv->n;
1091
+ uint32_t llama_memory_recurrent_context::get_n_rs() const {
1092
+ return is_full ? mem->size : mem->n;
1108
1093
  }
1109
1094
 
1110
- uint32_t llama_kv_cache_recurrent_state::get_head() const {
1111
- return is_full ? 0 : kv->head;
1095
+ uint32_t llama_memory_recurrent_context::get_head() const {
1096
+ return is_full ? 0 : mem->head;
1112
1097
  }
1113
1098
 
1114
- uint32_t llama_kv_cache_recurrent_state::get_size() const {
1115
- return kv->size;
1099
+ int32_t llama_memory_recurrent_context::get_rs_z() const {
1100
+ return is_full ? 0 : mem->rs_z;
1116
1101
  }
1117
1102
 
1118
- ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
1119
- return kv->k_l[il];
1103
+ uint32_t llama_memory_recurrent_context::get_size() const {
1104
+ return mem->size;
1120
1105
  }
1121
1106
 
1122
- ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
1123
- return kv->v_l[il];
1107
+ ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
1108
+ return mem->r_l[il];
1124
1109
  }
1125
1110
 
1126
- int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
1127
- return kv->s_copy(i);
1111
+ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
1112
+ return mem->s_l[il];
1128
1113
  }
1129
1114
 
1130
- float llama_kv_cache_recurrent_state::s_mask(int i) const {
1131
- return kv->s_mask(i);
1115
+ int32_t llama_memory_recurrent_context::s_copy(int i) const {
1116
+ return mem->cells[i + mem->head].src0;
1132
1117
  }